首页 文章

Keras警告:Epoch包含的不仅仅是`samples_per_epoch`样本

提问于
浏览
3

我有大约6200个训练图像,我想使用 keras.preprocessing.image.ImageDataGenerator 类的 flow(X, y) 方法以下列方式扩充小数据集:

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow(X_train , y_train)

validation_generator = test_datagen.flow(X_val , y_val)

history = model.fit_generator(
        train_generator,
        samples_per_epoch=1920,
        nb_epoch=10,
        verbose=1,
        validation_data=validation_generator,
        nb_val_samples=800)

其中 X_train / y_train 包含大约6000个训练图像和标签,并且 X_val / y_val 验证数据和模型是增强的VGG16模型 .

文件说

flow(X,y):获取numpy数据和标签数组,并生成批量的扩充/规范化数据 . 在无限循环中无限期地产生批次 .

对于具有10个时期的训练设置,每个时期1920个样本和32个batch_size,我得到以下训练跟踪:

1920/1920 [==============================] - 3525s - loss: 3.9101 - val_loss: 0.0269
Epoch 2/10
1920/1920 [==============================] - 3609s - loss: 1.0245 - val_loss: 0.0229
Epoch 3/10
1920/1920 [==============================] - 3201s - loss: 0.7620 - val_loss: 0.0161
Epoch 4/10
1916/1920 [============================>.] - ETA: 4s - loss: 0.5978 C:\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\training.py:1537: UserWarning: Epoch comprised more than `samples_per_epoch` samples, which might affect learning results. Set `samples_per_epoch` correctly to avoid this warning.
  warnings.warn('Epoch comprised more than

为什么生成器无法生成无限批次,如文档所述?

1 回答

  • 4

    所以基本上在 KerasImageGenerator 类实现中存在一个小错误 . 有什么好处 - 除了这个恼人的警告之外没有任何错误发生 . 所以要澄清:

    • flowflow_from_directory 实际上在无限循环中产生样本 . 您可以通过测试以下代码轻松检查(警告 - 它将冻结您的 Python ):
    for x, y in train_generator:
        x = None
    
    • 您提到的警告在 fit_generator 方法中被提升 . 它基本上检查在一个纪元中处理的样本数量是否小于或等于 samples_per_epoch . 在你的情况下 - samples_per_epoch 可以被_1856506整除 - 如果Keras的实现是正确的 - 这个警告应该永远不会被提出......但是......

    • ..是的,为什么要提出这个警告呢?它's a little bit tricky. If you went deeper into implementation of a generator you would notice that generator is getting batches in a following manner : if you have let' s说 - 10个例子和 batch_size = 3 然后:

    • 它会先洗掉这十个例子的顺序,

    • 然后它将需要3个第一个洗牌的例子,然后是接下来的三个,依此类推,

    • 在第3批之后 - 当只剩下1个例子时 - 它将返回一个批次..只有一个样本 .

    不要问我为什么 - 这就是生成器的实现方式 . 好的是,它几乎不影响培训过程 .

    所以 - 总结一下 - 您可以忽略此警告,也可以将传递给发生器的样本数量整除 batch_size . 我知道它很麻烦,我希望它能在下一个版本中修复 .

相关问题