我试图使用CNN和Keras进行简单的图像分类练习 .
我有一个列表存储图像的方向(train_glob)和另一个列表与相应的分类标签一个热编码(dummy_y) .
函数load_one()将路径和一些参数作为参数进行调整大小和扩充,并将变换后的图像作为numpy数组返回 .
当我通过.fit()以批处理模式运行代码时,创建一个包含所有图像的单个文件,称为batch_features,我在5个时期之后实现了0.7的正确精度 .
当我尝试使用python生成器复制结果来提供数据并使用.fit_generator()进行训练时出现问题,性能结果实际上很差,实际上我希望它们稍微好一点,因为据我所知,更多正在馈送数据 .
与批处理功能不同,在生成器中,y随机地改变图像的亮度并在数据上循环更多次,因此理论上如果我正确理解发生器如何工作,我希望结果更好 .
这是我的发电机功能
def generate_arrays_from_file(paths,cat_list, batch_size = 128):
number = 0
max_len = len(paths)
while True:
batch_features = np.zeros((batch_size, 128, 64, 3),np.uint8)
batch_labels = np.zeros((batch_size,cat_list.shape[1]),np.uint8)
for i in range(number*batch_size, number*batch_size + batch_size):
#choose random index in features
#index= np.random.choice(len(paths))
batch_features[i % batch_size] = load_one(paths[i], final_size=(64,128), augment = True)
batch_labels[i % batch_size] = cat_list[i]
batch_features = normalize_data(batch_features)
yield batch_features, batch_labels
number += 1
if number*batch_size + batch_size > max_len:
number = 0
这是对发电机的keras调用
mod.fit_generator(generate_arrays_from_file(train_glob, dummy_y, 256),
samples_per_epoch=16368, nb_epoch=10)
这是通过发电机的正确方法吗?
谢谢
1 回答
为了符合您的准确性,您希望输入相同的数据 . 由于您在没有生成器的情况下对图像进行了一些转换,因此准确度不匹配是正常的 .
如果你认为发电机是问题,你可以很容易地测试它 .
启动一个python shell,导入你的包,制作一个生成器并获得一些样本,看看它们是否符合你的预期 .
要保存图像或显示图像(来自this tutorial):
More about accuracy and data augmentation
如果您在不同的数据集上测试两个模型(一个使用生成器训练,一个使用预加载的整个数据),精度将明显不同 . 尝试对两个模型使用完全相同的测试和训练数据,完全关闭增强,你应该看到类似的精度(相同数量的纪元,batch_sizes等) . 如果您不使用上述方法来修复生成器 .
如果只有很少的数据点,模型将非常快速地过度拟合(因此具有高训练精度) . 数据增强有助于减少过度拟合并使模型更好地概括 . 这也意味着随着数据的变化,在极少数时期之后训练的准确性将会降低 .
请注意,很容易让图像处理(数据增加)出错并且没有意识到这一点 . 错误地裁剪,你得到一个黑色的图像 . 放大太多你只会得到噪音 . 混淆x和y,你得到一个完全错误的图像 . 等等......测试你的发电机,看看它输出的图像是否符合预期,标签是否匹配 .
亮度 . 如果改变输入图像的亮度,则会使模型与亮度无关 . 您不会改进旋转和缩放等其他内容的泛化 . 确保不要过度改变亮度:不要让图像完全变白或完全变黑 - 如果发生这种情况,它将解释精度的大幅下降 .
如VMRuiz的评论中所指出的,如果您有分类数据(您这样做),请使用
keras.preprocessing.image.ImageDataGenerator
(docs) . 它会为你节省很多时间 . A very good example on Keras blog(code here) . 如果您对自己的图像处理感兴趣,请查看ImageDataGenerator source code .