首页 文章

将类信息添加到keras网络

提问于
浏览
13

我试图找出如何使用Generative Adversarial Networks的数据集标签信息 . 我试图使用can be found here的条件GAN的以下实现 . 我的数据集包含两个不同的图像域(真实对象和草图),具有公共类信息(椅子,树,橙等) . 我选择了这种实现,它只考虑两个不同的域作为对应的不同"classes"(列车样本 X 对应于真实图像,而目标样本 y 对应于草图图像) .

有没有办法修改我的代码并考虑我整个架构中的类信息(椅子,树等)?我希望我的鉴别器能够预测我生成的图像来自生成器是否属于特定类,而不仅仅是它们是否真实 . 实际上,在当前架构中,系统学会在所有情况下创建类似的草图 .

Update: 鉴别器返回一个大小为 1x7x7 的张量,然后 y_truey_pred 在计算损失之前通过展平层:

def discriminator_loss(y_true, y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])),K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])) ]) ), axis=-1)

以及鉴别器对发电机的损失功能:

def discriminator_on_generator_loss(y_true,y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

此外,我对输出1层的鉴别器模型的修改:

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))

现在鉴别器输出1层 . 如何相应修改上述损失函数?我应该有7而不是1,因为 n_classes = 6 一类用于预测真假对吗?

2 回答

  • 3

    建议的解决方案

    重用repository you shared中的代码,这里有一些建议的修改来训练分类器沿你的生成器和鉴别器(他们的架构和其他损失保持不变):

    from keras import backend as K
    from keras.models import Sequential
    from keras.layers.core import Dense, Dropout, Activation, Flatten
    from keras.layers.convolutional import Convolution2D, MaxPooling2D
    
    def lenet_classifier_model(nb_classes):
        # Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
        # Replace with your favorite classifier...
        model = Sequential()
        model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        model.add(Dense(180, activation='relu', init='he_normal'))
        model.add(Dropout(0.5))
        model.add(Dense(100, activation='relu', init='he_normal'))
        model.add(Dropout(0.5))
        model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
    
    def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
        inputs = Input((IN_CH, img_cols, img_rows))
        x_generator = generator(inputs)
    
        merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
        discriminator.trainable = False
        x_discriminator = discriminator(merged)
    
        classifier.trainable = False
        x_classifier = classifier(x_generator)
    
        model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
    
        return model
    
    
    def train(BATCH_SIZE):
        (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
        discriminator = discriminator_model()
        generator = generator_model()
        classifier = lenet_classifier_model(6)
        generator.summary()
        discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
            generator, discriminator, classifier)
        d_optim = Adagrad(lr=0.005)
        g_optim = Adagrad(lr=0.005)
        generator.compile(loss='mse', optimizer="rmsprop")
        discriminator_and_classifier_on_generator.compile(
            loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
            optimizer="rmsprop")
        discriminator.trainable = True
        discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
        classifier.trainable = True
        classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")
    
        for epoch in range(100):
            print("Epoch is", epoch)
            print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
            for index in range(int(X_train.shape[0] / BATCH_SIZE)):
                image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
                label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here
    
                generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
                if index % 20 == 0:
                    image = combine_images(generated_images)
                    image = image * 127.5 + 127.5
                    image = np.swapaxes(image, 0, 2)
                    cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                    # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")
    
                # Training D:
                real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                            axis=1)
                fake_pairs = np.concatenate(
                    (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
                X = np.concatenate((real_pairs, fake_pairs))
                y = np.zeros((20, 1, 64, 64))  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
                d_loss = discriminator.train_on_batch(X, y)
                print("batch %d d_loss : %f" % (index, d_loss))
                discriminator.trainable = False
    
                # Training C:
                c_loss = classifier.train_on_batch(image_batch, label_batch)
                print("batch %d c_loss : %f" % (index, c_loss))
                classifier.trainable = False
    
                # Train G:
                g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                    X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
                    [image_batch, np.ones((10, 1, 64, 64)), label_batch])
                discriminator.trainable = True
                classifier.trainable = True
                print("batch %d g_loss : %f" % (index, g_loss[1]))
                if index % 20 == 0:
                    generator.save_weights('generator', True)
                    discriminator.save_weights('discriminator', True)
    

    理论细节

    我认为对于有条件的GAN如何工作以及这些方案中的鉴别者角色存在一些误解 .

    鉴别者的角色

    在GAN训练[4]的最小 - 最大游戏中,鉴别器 D 正在与生成器 G (您实际关心的网络)进行比赛,以便在 D 的审查下, G 在输出真实结果方面变得更好 .

    为此, D 经过培训,可以分析来自 G 样本的实际样本;而 G 经过培训,可以通过在目标分布后生成逼真的结果/结果来欺骗 D .

    注意:在条件GAN的情况下,即GAN将输入样本从一个域A(例如实际图片)映射到另一个域B(例如草图),D通常被馈送堆叠在一起的样本对并且必须区分“真实的“对(来自B的相应目标样本的输入样本)和”伪“对(来自G的相应输出的输入样本)[1,2]

    针对 D 训练条件生成器(而不是简单地单独训练 G ,仅使用L1 / L2丢失,例如DAE)提高了 G 的采样能力,迫使其输出清晰,逼真的结果,而不是试图平均分布 .

    即使鉴别器可以有多个子网络来覆盖其他任务(参见下面的段落), D 应该保留至少一个子网络/输出来覆盖其主要任务: telling real samples from generated ones apart . 要求 D 同时回退进一步的语义信息(例如类)可能会干扰这个主要目的 .

    注意:D输出通常不是简单的标量/布尔值 . 通常有一个鉴别器(例如PatchGAN [1,2])返回概率矩阵,评估从其输入得到的真实补丁是多少 .


    有条件的GAN

    以无人监督的方式训练传统GAN以从随机噪声向量生成逼真数据(例如图像)作为输入 . [4]

    如前所述,条件GAN具有进一步的输入条件 . 沿着/而不是噪声向量,它们从域 A 输入样本并从域 B 返回相应的样本 . A 可以是完全不同的形式,例如 B = sketch imageA = discrete label ; B = volumetric dataA = RGB image 等[3]

    这样的GAN也可以通过多个输入来调节,例如, A = real image + discrete labelB = sketch image . 引入这种方法的着名作品是 InfoGAN [5] . 它介绍了如何在多个连续或离散输入上调整GAN(例如 A = digit class + writing typeB = handwritten digit image ), using a more advanced discriminator which has for 2nd task to force G to maximize the mutual-information between its conditioning inputs and its corresponding outputs .


    最大化cGAN的相互信息

    InfoGAN鉴别器有2个头/子网络来完成它的2个任务[5]:

    • 一个头 D1 做传统的真实/生成歧视 - G 必须最小化这个结果,即它必须欺骗 D1 ,以便它不能分辨真实形式生成的数据;

    • 另一个头 D2 (也称为 Q 网络)试图回归输入 A 信息 - G 必须最大化此结果,即它必须输出"show"的数据请求语义信息(参见 G 条件输入及其输出之间的互信息最大化) .

    您可以在此处找到Keras实现,例如:https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan .

    一些工作正在使用类似的方案来改进对GAN生成的控制,通过使用提供的标签并最大化这些输入和输出之间的相互信息[6,7] . 基本思路总是一样的:

    • 训练 G 以生成域 B 的元素,给出域的一些输入 A ;

    • 训练 D 以区分"real" / "fake"结果 - G 必须尽量减少这种情况;

    • 训练 Q (例如分类器;可以与 D 共享图层)估计 B 样本的原始 A 输入 - G 必须最大化此值 .


    结束

    在您的情况下,您似乎有以下培训数据:

    • 真实图片 Ia

    • 相应的草图图片 Ib

    • 对应的类标签 c

    并且你想训练一个生成器 G ,这样给定一个图像 Ia 及其类标签 c ,它会输出一个合适的草图图像 Ib' .

    总而言之,这是您拥有的大量信息,您可以监控您在条件图像和条件标签上的培训......灵感来自上述方法[1,2,5,6,7],这里是一个使用所有这些信息来训练你的条件的可能方式 G
    Network G:

    • 输入: Ia c

    • 输出: Ib'

    • 架构:最新的(例如U-Net,ResNet,......)

    • 损失: Ib'Ib 之间的L1 / L2损失, -D 损失, Q 损失

    Network D:

    • 输入: Ia Ib (真实配对), Ia Ib' (假配对)

    • 输出:"fakeness"标量/矩阵

    • 架构:最新的(例如PatchGAN)

    • 损失:"fakeness"估计的交叉熵

    Network Q:

    • 输入: Ib (实际样本,用于训练 Q ), Ib' (假样本,当通过 G 进行反向传播时)

    • 输出: c' (估计等级)

    • 架构:最新的(例如LeNet,ResNet,VGG,......)

    • 损失: cc' 之间的交叉熵

    Training Phase:

    • 火车 D 上一批真正的对 Ia Ib 然后对一批假对 Ia Ib' ;

    • 在一批实际样本上训练 Q Ib ;

    • 修复 DQ 权重;

    • 训练 G ,将其生成的输出 Ib' 传递给 DQ 以通过它们进行反向传播 .

    注意:这是一个非常粗略的架构描述 . 我建议通过文献([1,5,6,7]作为一个良好的开端)来获得更多的细节,也许是一个更精细的解决方案 .


    参考文献

  • 7

    您应该修改鉴别器模型,要么具有两个输出,要么具有“n_classes 1”输出 .

    警告:我没有在你的鉴别器的定义中看到它输出'true / false',我看到它输出图像......

    它应该包含 GlobalMaxPooling2DGlobalAveragePooling2D .
    最后和一个或多个 Dense 层进行分类 .

    如果告诉真/假,最后一个密集应该有1个单位 .
    否则 n_classes + 1 单位 .

    所以,你的鉴别者的结局应该是这样的

    ...GlobalMaxPooling2D()...
    ...Dense(someHidden,...)...
    ...Dense(n_classes+1,...)...
    

    鉴别器现在将输出 n_classes 加上"true/fake"符号(您将无法使用"categorical")或甚至"fake class"(然后您将其他类归零并使用分类)

    您生成的草图应该与目标一起传递给鉴别器,该目标将是假类与其他类的串联 .

    选项1 - 使用“真/假”标志 . (不要使用“categorical_crossentropy”)

    #true sketches into discriminator:
    fakeClass = np.zeros((total_samples,))
    sketchClass = originalClasses
    
    targetClassTrue = np.concatenate([fakeClass,sketchClass], axis=-1)
    
    #fake sketches into discriminator:
    fakeClass = np.ones((total_fake_sketches))
    sketchClass = originalClasses
    
    targetClassFake = np.concatenate([fakeClass,sketchClass], axis=-1)
    

    选项2 - 使用“假类”(可以使用“categorical_crossentropy”):

    #true sketches into discriminator:
    fakeClass = np.zeros((total_samples,))
    sketchClass = originalClasses
    
    targetClassTrue = np.concatenate([fakeClass,sketchClass], axis=-1)
    
    #fake sketches into discriminator:
    fakeClass = np.ones((total_fake_sketches))
    sketchClass = np.zeros((total_fake_sketches, n_classes))
    
    targetClassFake = np.concatenate([fakeClass,sketchClass], axis=-1)
    

    现在将所有内容连接到一个目标数组(分别对应于输入草图)

    更新了培训方法

    对于此培训方法,您的损失函数应为以下之一:

    • discriminator.compile(loss='binary_crossentropy', optimizer=....)

    • discriminator.compile(loss='categorical_crossentropy', optimizer=...)

    码:

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))
    
        for index in range(int(X_train.shape[0]/BATCH_SIZE)):
    
            #names:
                #images -> initial images, not changed    
                #sketches -> generated + true sketches    
                #classes -> your classification for the images    
                #isGenerated -> the output of your discriminator telling whether the passed sketches are fake
    
            batchSlice = slice(index*BATCH_SIZE,(index+1)*BATCH_SIZE)
            trueImages = X_train[batchSlice]
    
            trueSketches = Y_train[batchSlice] 
            trueClasses = originalClasses[batchSlice]
            trueIsGenerated = np.zeros((len(trueImages),)) #discriminator telling whether the sketch is fake or true (generated images = 1)
            trueEndTargets = np.concatenate([trueIsGenerated,trueClasses],axis=1)
    
            fakeSketches = generator.predict(trueImages)
            fakeClasses = originalClasses[batchSlize]             #if option 1 -> telling class + isGenerated - use "binary_crossentropy"
            fakeClasses = np.zeros((len(fakeSketches),n_classes)) #if option 2 -> telling if generated is an individual class - use "categorical_crossentropy"    
            fakeIsGenerated = np.ones((len(fakeSketches),))
            fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)
    
            allSketches = np.concatenate([trueSketches,fakeSketches],axis=0)            
            allEndTargets = np.concatenate([trueEndTargets,fakeEndTargets],axis=0)
    
            d_loss = discriminator.train_on_batch(allSketches, allEndTargets)
    
            pred_temp = discriminator.predict(allSketches)
            #print(np.shape(pred_temp))
            print("batch %d d_loss : %f" % (index, d_loss))
    
            ##WARNING## In previous keras versions, "trainable" only takes effect if you compile the models. 
                #you should have the "discriminator" and the "discriminator_on_generator" with these set at the creation of the models and never change it again   
    
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
            discriminator.trainable = True
    
    
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)
    

    正确编译模型

    当您创建“discriminator”和“discriminator_on_generator”时:

    discriminator.trainable = True
    for l in discriminator.layers:
        l.trainable = True
    
    
    discriminator.compile(.....)
    
    for l in discriminator_on_generator.layer[firstDiscriminatorLayer:]:
        l.trainable = False
    
    discriminator_on_generator.compile(....)
    

相关问题