LIUPENG BLOG
Liupeng
Feb 22, 2020
It takes 8 minutes to read this article.

model.fit()

model.fit()

数据集上的迭代

fit(
x=None, 
y=None, 
batch_size=None, 
epochs=1, 
verbose=1, 
callbacks=None, 
validation_split=0.0, 
validation_data=None, 
shuffle=True, 
class_weight=None, 
sample_weight=None, 
initial_epoch=0, 
steps_per_epoch=None, 
validation_steps=None, 
validation_freq=1, 
max_queue_size=10, 
workers=1, 
use_multiprocessing=False
)

参数

  • x:输入数据。它可能是:

    • Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。
    • 如果模型已命名输入,则dict将输入名称映射到相应的数组/张量。
    • 生成器或keras.utils.Sequence返回 (inputs, targets)(inputs, targets, sample weights)
    • 如果从框架本机张量(例如TensorFlow数据张量)馈送,则为None(默认)。
  • y:目标数据。像输入数据一样x,它可以是Numpy数组,框架本机张量,Numpy数组列表(如果模型具有多个输出)或如果从框架本机张量馈送(例如TensorFlow)则为None(默认)数据张量)。如果命名了模型中的输出层,则还可以传递将输出名称映射到Numpy数组的字典。如果不指定,则不应指定x生成器或keras.utils.Sequence实例 y(因为将从获取目标x)。

  • batch_size:整数或None。每个梯度更新的样本数。如果未指定,batch_size则默认为32。batch_size如果数据是以符号张量,生成器或Sequence实例的形式(因为它们生成批处理),则不要指定。

  • epochs:整数。训练模型的时期数。时期是整个x和所y 提供数据的迭代。请注意,在与结合initial_epochepochs应理解为“最后时期”。不会针对给出的多次迭代训练模型epochs,而只是对epochs到达索引的纪元进行训练。

  • verbose:整数。0、1或2。详细模式。0 =静音,1 =进度条,2 =每个时期一行。

  • callbackskeras.callbacks.Callback实例列表。在培训和验证期间应用的回调列表(如果)。请参阅回调

  • validation_split:在0到1之间浮动。用作验证数据的训练数据的分数。模型将分开训练数据的这一部分,不对其进行训练,并且将在每个时期结束时评估此数据的损失和任何模型度量。在改组之前,从x和中y提供的数据中的最后一个样本中选择验证数据。当x是生成器或 Sequence实例时,不支持此参数。

  • validate_data:在每个时期结束时用于评估损失和任何模型指标的数据。该模型将不会根据此数据进行训练。 validation_data将覆盖validation_splitvalidation_data可能是:(x_val, y_val)-Numpy数组或张量的元组(x_val, y_val, val_sample_weights)-Numpy数组的元组-数据集或数据集迭代器

    对于前两种情况,batch_size必须提供。对于最后一种情况,validation_steps必须提供。

  • shuffle:布尔值(是否在每个时期之前都将训练数据洗牌)或str(用于“批处理”)。“批处理”是处理HDF5数据限制的特殊选项;它会按批处理大小的块进行混洗。当没有任何效果steps_per_epoch是不是None

  • class_weight:可选的字典,将类索引(整数)映射到权重(浮点)值,用于对损失函数加权(仅在训练过程中)。这有助于告诉模型“更多关注”来自代表性不足的类的样本。

  • sample_weight:训练样本的可选Numpy权重数组,用于对损失函数加权(仅在训练过程中)。您可以传递长度与输入样本相同的平坦(1D)Numpy数组(权重和样本之间的1:1映射),或者对于时间数据,可以传递带有shape的2D数组 (samples, sequence_length)以应用每个样品的每个时间步均具有不同的权重。在这种情况下,你应该确保指定 sample_weight_mode="temporal"compile()。当xgenerator或Sequenceinstance提供sample_weights作为的第三个元素时,不支持此参数x

  • initial_epoch:整数。开始训练的时期(用于恢复以前的训练运行)。

  • steps_per_epoch:整数或None。声明一个纪元完成并开始下一个纪元之前的总步数(一批样品)。在使用TensorFlow数据张量等输入张量进行训练时,默认None值等于数据集中的样本数除以批次大小;如果无法确定,则默认为1。

  • validation_steps:仅在steps_per_epoch 指定时才相关。停止之前要验证的步骤总数(样本批次)。

  • validation_steps:仅在validation_data提供时是相关的,并且是生成器。在每个时期结束时执行验证时,在停止之前要绘制的步骤总数(样本批次)。

  • validation_freq:仅在提供验证数据时才相关。整数或列表/元组/集合。如果是整数,则指定在执行新的验证运行之前要运行多少个训练时期,例如 validation_freq=2,每2个时期运行一次验证。如果列表,元组或集合指定要在其上运行验证的时期,例如validation_freq=[1, 2, 10]在第一个,第二个和第十个时期的末尾运行验证。

  • max_queue_size:整数。keras.utils.Sequence 仅用于生成器或输入。生成器队列的最大大小。如果未指定,max_queue_size则默认为10。

  • workers:整数。keras.utils.Sequence仅用于生成器或输入。使用基于进程的线程时,要启动的最大进程数。如果未指定,workers 则默认为1。如果为0,将在主线程上执行生成器。

  • use_multiprocessing:布尔值。keras.utils.Sequence仅用于生成器或 输入。如果为True,请使用基于进程的线程。如果未指定,use_multiprocessing则默认为 False。请注意,由于此实现依赖于多处理,因此不应将不可拾取的参数传递给生成器,因为它们无法轻易传递给子进程。

  • ** kwargs:用于向后兼容。

返回值

一个History对象。它的History.history属性是在连续时期训练损失值和度量值以及验证损失值和验证度量值(如果适用)的记录​​。

报错信息

  • RuntimeError:如果从未编译过模型。
  • ValueError:如果提供的输入数据与模型期望值不匹配。