首页 文章

与Keras表现不佳的Python生成器

提问于
浏览
0

我试图使用CNN和Keras进行简单的图像分类练习 .

我有一个列表存储图像的方向(train_glob)和另一个列表与相应的分类标签一个热编码(dummy_y) .

函数load_one()将路径和一些参数作为参数进行调整大小和扩充,并将变换后的图像作为numpy数组返回 .

当我通过.fit()以批处理模式运行代码时,创建一个包含所有图像的单个文件,称为batch_features,我在5个时期之后实现了0.7的正确精度 .

当我尝试使用python生成器复制结果来提供数据并使用.fit_generator()进行训练时出现问题,性能结果实际上很差,实际上我希望它们稍微好一点,因为据我所知,更多正在馈送数据 .

与批处理功能不同,在生成器中,y随机地改变图像的亮度并在数据上循环更多次,因此理论上如果我正确理解发生器如何工作,我希望结果更好 .

这是我的发电机功能

def generate_arrays_from_file(paths,cat_list, batch_size = 128):
    number = 0
    max_len = len(paths)
    while True:
        batch_features = np.zeros((batch_size, 128, 64, 3),np.uint8)
        batch_labels = np.zeros((batch_size,cat_list.shape[1]),np.uint8)
        for i in range(number*batch_size, number*batch_size + batch_size):
            #choose random index in features
            #index= np.random.choice(len(paths))
            batch_features[i % batch_size] = load_one(paths[i], final_size=(64,128), augment = True)
            batch_labels[i % batch_size] = cat_list[i]
        batch_features = normalize_data(batch_features)
        yield batch_features, batch_labels
        number += 1
        if number*batch_size + batch_size > max_len:
            number = 0

这是对发电机的keras调用

mod.fit_generator(generate_arrays_from_file(train_glob, dummy_y, 256),
        samples_per_epoch=16368, nb_epoch=10)

这是通过发电机的正确方法吗?

谢谢

1 回答

  • 1

    为了符合您的准确性,您希望输入相同的数据 . 由于您在没有生成器的情况下对图像进行了一些转换,因此准确度不匹配是正常的 .

    如果你认为发电机是问题,你可以很容易地测试它 .

    启动一个python shell,导入你的包,制作一个生成器并获得一些样本,看看它们是否符合你的预期 .

    # say you save the generator in mygenerator.py
    
    $ python3
    Python 3.5.2 (default, Nov 17 2016, 17:05:23) 
    [GCC 5.4.0 20160609] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import mygenerator
    
    # initialise paths, cat_list here:
    >>> paths = [...]
    >>> cat_list = [...]
    
    # use a small batch_size to be able to see the results 
    >>> g = mygenerator.generate_arrays_from_file(paths, cat_list, batch_size = 2)
    >>> batch = g.__next__()
    # now check if batch is what you expect
    

    要保存图像或显示图像(来自this tutorial):

    # Save:
    from scipy import misc
    misc.imsave('face.png', image_array) # uses the Image module (PIL)
    
    # Display:
    import matplotlib.pyplot as plt
    plt.imshow(image_array)
    plt.show()
    

    More about accuracy and data augmentation

    • 如果您在不同的数据集上测试两个模型(一个使用生成器训练,一个使用预加载的整个数据),精度将明显不同 . 尝试对两个模型使用完全相同的测试和训练数据,完全关闭增强,你应该看到类似的精度(相同数量的纪元,batch_sizes等) . 如果您不使用上述方法来修复生成器 .

    • 如果只有很少的数据点,模型将非常快速地过度拟合(因此具有高训练精度) . 数据增强有助于减少过度拟合并使模型更好地概括 . 这也意味着随着数据的变化,在极少数时期之后训练的准确性将会降低 .

    • 请注意,很容易让图像处理(数据增加)出错并且没有意识到这一点 . 错误地裁剪,你得到一个黑色的图像 . 放大太多你只会得到噪音 . 混淆x和y,你得到一个完全错误的图像 . 等等......测试你的发电机,看看它输出的图像是否符合预期,标签是否匹配 .

    • 亮度 . 如果改变输入图像的亮度,则会使模型与亮度无关 . 您不会改进旋转和缩放等其他内容的泛化 . 确保不要过度改变亮度:不要让图像完全变白或完全变黑 - 如果发生这种情况,它将解释精度的大幅下降 .

    • VMRuiz的评论中所指出的,如果您有分类数据(您这样做),请使用 keras.preprocessing.image.ImageDataGeneratordocs) . 它会为你节省很多时间 . A very good example on Keras blogcode here) . 如果您对自己的图像处理感兴趣,请查看ImageDataGenerator source code .

相关问题