
Randomly augmenting images with Keras


I'm trying to learn the Keras library using the MNIST dataset. MNIST has 60,000 training images and 10,000 validation images.

For both sets, I want to apply augmentation to 30% of the images.
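Here is what 30% works out to for the two sets and for a single batch of 200 images (the batch_size used in the model.fit call further down). The helper split_counts is hypothetical and only illustrates the arithmetic. It rounds to the nearest whole image.

```python
def split_counts(n_images, aug_fraction):
    """Return (augmented, unaugmented) image counts for a given fraction.

    Hypothetical helper for illustration only; round() guards against
    floating-point artifacts when multiplying by a fraction such as 0.3.
    """
    n_aug = round(n_images * aug_fraction)
    return n_aug, n_images - n_aug


# Whole datasets (60k training, 10k validation images).
print(split_counts(60000, 0.3))  # (18000, 42000)
print(split_counts(10000, 0.3))  # (3000, 7000)

# Composition of one batch of 200 images.
print(split_counts(200, 0.3))    # (60, 140)

# Full batches per epoch at batch_size=200.
print(60000 // 200, 10000 // 200)  # 300 50
```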

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)
datagen.fit(training_images)
datagen.fit(validation_images)

This does not augment the images, and I don't know how to use the model.fit_generator method. My current model.fit call looks like this:

model.fit(training_images, training_labels, validation_data=(validation_images, validation_labels), epochs=10, batch_size=200, verbose=2)
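For context, ImageDataGenerator.fit does not augment anything. It only computes dataset statistics (mean, standard deviation, ZCA matrix) and is only needed when featurewise_center, featurewise_std_normalization or zca_whitening is enabled. With flips alone it can be skipped. Augmented images are produced lazily, batch by batch, when you iterate over datagen.flow(...), which is why the data has to reach the model through a generator. As I understand the Keras 2 implementation, horizontal_flip and vertical_flip each flip an image independently with probability 0.5. The NumPy sketch below imitates only that flip step for a batch shaped (batch, height, width, channels). random_flips is a hypothetical helper, not part of Keras.

```python
import numpy as np


def random_flips(batch, rng, horizontal=True, vertical=True):
    """Flip each image left-right and/or top-bottom, each with probability 0.5.

    Hypothetical helper mimicking horizontal_flip / vertical_flip.
    Expects a batch shaped (batch, height, width, channels); the input
    array is not modified.
    """
    out = np.empty_like(batch)
    for i, img in enumerate(batch):
        if horizontal and rng.random() < 0.5:
            img = img[:, ::-1, :]  # left-right: reverse the width axis
        if vertical and rng.random() < 0.5:
            img = img[::-1, :, :]  # top-bottom: reverse the height axis
        out[i] = img
    return out


rng = np.random.default_rng(0)
images = np.arange(18, dtype="float32").reshape(2, 3, 3, 1)
flipped = random_flips(images, rng)
print(flipped.shape)  # (2, 3, 3, 1)
```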

How can I apply augmentation to only some of the images in this dataset?

1 Answer


    I tried defining my own generator like this:

    import numpy
    from keras.preprocessing.image import ImageDataGenerator
    from sklearn.model_selection import train_test_split
    
    def partial_flow(array, flags, generator, aug_percentage, batch_size):
        # array must be 4-D (samples, height, width, channels),
        # e.g. MNIST reshaped to (n, 28, 28, 1).
        # Split the data into a part that will be augmented and a part that won't.
        # train_test_split returns (train_x, test_x, train_y, test_y); the "test"
        # part (aug_percentage of the data) is the one that gets augmented.
        not_aug_array, aug_array, not_aug_flags, aug_flags = train_test_split(
            array,
            flags,
            test_size=aug_percentage)
        # Number of augmented samples in each batch.
        aug_split_size = int(batch_size * aug_percentage)
        # A generator without any augmentation yields the non-augmented data.
        not_augmented_gen = ImageDataGenerator()
        aug_gen = generator.flow(
            x=aug_array,
            y=aug_flags,
            batch_size=aug_split_size)
        not_aug_gen = not_augmented_gen.flow(
            x=not_aug_array,
            y=not_aug_flags,
            batch_size=batch_size - aug_split_size)
        # Yield batches forever; Keras stops after steps_per_epoch batches.
        while True:
            # Get augmented data
            aug_x, aug_y = next(aug_gen)
            # Get non-augmented data
            not_aug_x, not_aug_y = next(not_aug_gen)
            # Concatenate both parts into one batch
            current_x = numpy.concatenate([aug_x, not_aug_x], axis=0)
            current_y = numpy.concatenate([aug_y, not_aug_y], axis=0)
            yield current_x, current_y
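The split-and-concatenate logic of partial_flow can be checked without Keras. The sketch below is a NumPy-only stand-in: partial_flow_np and its augment argument are hypothetical, with augment playing the role of the augmenting ImageDataGenerator. The data is split once into an augmented pool and a plain pool, and each batch takes its first round(batch_size * aug_fraction) samples from the augmented pool. Unlike Keras' flow, it drops a trailing partial batch.

```python
import numpy as np


def partial_flow_np(x, y, augment, aug_fraction, batch_size, seed=0):
    """Yield (x, y) batches in which a fixed share of samples is augmented.

    Hypothetical NumPy-only stand-in: the data is split once into an
    augmented pool and a plain pool; every batch takes aug_size samples
    from the first (passed through `augment`) and the rest from the second.
    """
    rng = np.random.default_rng(seed)
    aug_size = int(round(batch_size * aug_fraction))
    plain_size = batch_size - aug_size
    if not 0 < aug_size < batch_size:
        raise ValueError("both parts of the batch must be non-empty")
    idx = rng.permutation(len(x))
    n_aug = int(round(len(x) * aug_fraction))
    aug_pool, plain_pool = idx[:n_aug], idx[n_aug:]

    def cycle(pool, size):
        # Endlessly yield `size` indices from `pool`, reshuffling each pass
        # and dropping a trailing partial chunk. Assumes len(pool) >= size.
        while True:
            order = rng.permutation(pool)
            for start in range(0, len(order) - size + 1, size):
                yield order[start:start + size]

    aug_batches = cycle(aug_pool, aug_size)
    plain_batches = cycle(plain_pool, plain_size)
    while True:
        a, p = next(aug_batches), next(plain_batches)
        batch_x = np.concatenate([augment(x[a], rng), x[p]], axis=0)
        batch_y = np.concatenate([y[a], y[p]], axis=0)
        yield batch_x, batch_y


# Usage: negation stands in for a real augmentation so augmented samples are easy to spot.
x = np.arange(1, 101, dtype="float32").reshape(100, 1, 1, 1)
y = np.arange(1, 101)
gen = partial_flow_np(x, y, lambda batch, rng: -batch, 0.3, 10)
batch_x, batch_y = next(gen)
print(batch_x.shape, batch_y.shape)  # (10, 1, 1, 1) (10,)
```

With a batch size of 10 and a fraction of 0.3, every batch holds 3 augmented and 7 untouched samples. In practice augment would apply random transforms such as flips.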
    

    Now you can train as follows:

    datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)
    batch_size = 200
    model.fit_generator(partial_flow(training_images, training_labels, datagen, 0.3, batch_size),
                        steps_per_epoch=training_images.shape[0] // batch_size,
                        epochs=10,
                        validation_data=partial_flow(validation_images, validation_labels, datagen, 0.3, batch_size),
                        validation_steps=validation_images.shape[0] // batch_size)
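Another way to read "augment 30% of the images" is to decide per image, on every pass, whether to augment it, instead of fixing one 30% subset up front as partial_flow does. That way every image is sometimes seen augmented and sometimes not, and each batch contains roughly, not exactly, 30% augmented images. The NumPy sketch below shows this alternative technique; bernoulli_augment_flow and its augment argument are hypothetical. Also, on TensorFlow 2.1 or later, fit_generator is deprecated and model.fit accepts a Python generator directly (together with steps_per_epoch).

```python
import numpy as np


def bernoulli_augment_flow(x, y, augment, p, batch_size, seed=0):
    """Yield shuffled (x, y) batches, augmenting each image with probability p.

    Hypothetical sketch: unlike a fixed 30% split, the augmented images are
    redrawn for every batch. `augment(image, rng)` maps one image to one image.
    """
    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(len(x))  # new shuffle on every pass
        for start in range(0, len(x), batch_size):
            idx = order[start:start + batch_size]
            batch_x = x[idx]  # fancy indexing returns a copy; x is untouched
            mask = rng.random(len(idx)) < p  # True -> augment this image
            for i in np.flatnonzero(mask):
                batch_x[i] = augment(x[idx[i]], rng)
            yield batch_x, y[idx]


# Usage: a left-right flip as the augmentation, 30% probability, batches of 200.
x = np.ones((1000, 28, 28, 1), dtype="float32")
y = np.arange(1000)
flow = bernoulli_augment_flow(x, y, lambda img, rng: img[:, ::-1, :], 0.3, 200)
batch_x, batch_y = next(flow)
print(batch_x.shape, batch_y.shape)  # (200, 28, 28, 1) (200,)
```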
    
