我试图找出如何使用Generative Adversarial Networks的数据集标签信息 . 我试图使用can be found here的条件GAN的以下实现 . 我的数据集包含两个不同的图像域(真实对象和草图),具有公共类信息(椅子,树,橙等) . 我选择了这种实现,它只考虑两个不同的域作为对应的不同"classes"(列车样本 X
对应于真实图像,而目标样本 y
对应于草图图像) .
有没有办法修改我的代码并考虑我整个架构中的类信息(椅子,树等)?我希望我的鉴别器能够预测我生成的图像来自生成器是否属于特定类,而不仅仅是它们是否真实 . 实际上,在当前架构中,系统学会在所有情况下创建类似的草图 .
Update: 鉴别器返回一个大小为 1x7x7
的张量,然后 y_true
和 y_pred
在计算损失之前通过展平层:
def discriminator_loss(y_true, y_pred):
BATCH_SIZE=100
return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])),K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])) ]) ), axis=-1)
以及鉴别器对发电机的损失功能:
def discriminator_on_generator_loss(y_true,y_pred):
BATCH_SIZE=100
return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)
此外,我对输出1层的鉴别器模型的修改:
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))
现在鉴别器输出1层 . 如何相应修改上述损失函数?我应该有7而不是1,因为 n_classes = 6
一类用于预测真假对吗?
2 回答
建议的解决方案
重用repository you shared中的代码,这里有一些建议的修改来训练分类器沿你的生成器和鉴别器(他们的架构和其他损失保持不变):
理论细节
我认为对于有条件的GAN如何工作以及这些方案中的鉴别者角色存在一些误解 .
鉴别者的角色
在GAN训练[4]的最小 - 最大游戏中,鉴别器
D
正在与生成器G
(您实际关心的网络)进行比赛,以便在D
的审查下,G
在输出真实结果方面变得更好 .为此,
D
经过培训,可以分析来自G
样本的实际样本;而G
经过培训,可以通过在目标分布后生成逼真的结果/结果来欺骗D
.针对
D
训练条件生成器(而不是简单地单独训练G
,仅使用L1 / L2丢失,例如DAE)提高了G
的采样能力,迫使其输出清晰,逼真的结果,而不是试图平均分布 .即使鉴别器可以有多个子网络来覆盖其他任务(参见下面的段落),
D
应该保留至少一个子网络/输出来覆盖其主要任务: telling real samples from generated ones apart . 要求D
同时回退进一步的语义信息(例如类)可能会干扰这个主要目的 .有条件的GAN
以无人监督的方式训练传统GAN以从随机噪声向量生成逼真数据(例如图像)作为输入 . [4]
如前所述,条件GAN具有进一步的输入条件 . 沿着/而不是噪声向量,它们从域
A
输入样本并从域B
返回相应的样本 .A
可以是完全不同的形式,例如B = sketch image
而A = discrete label
;B = volumetric data
而A = RGB image
等[3]这样的GAN也可以通过多个输入来调节,例如,
A = real image + discrete label
而B = sketch image
. 引入这种方法的着名作品是 InfoGAN [5] . 它介绍了如何在多个连续或离散输入上调整GAN(例如A = digit class + writing type
,B = handwritten digit image
), using a more advanced discriminator which has for 2nd task to force G to maximize the mutual-information between its conditioning inputs and its corresponding outputs .最大化cGAN的相互信息
InfoGAN鉴别器有2个头/子网络来完成它的2个任务[5]:
一个头
D1
做传统的真实/生成歧视 -G
必须最小化这个结果,即它必须欺骗D1
,以便它不能分辨真实形式生成的数据;另一个头
D2
(也称为Q
网络)试图回归输入A
信息 -G
必须最大化此结果,即它必须输出"show"的数据请求语义信息(参见G
条件输入及其输出之间的互信息最大化) .您可以在此处找到Keras实现,例如:https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan .
一些工作正在使用类似的方案来改进对GAN生成的控制,通过使用提供的标签并最大化这些输入和输出之间的相互信息[6,7] . 基本思路总是一样的:
训练
G
以生成域B
的元素,给出域的一些输入A
;训练
D
以区分"real" / "fake"结果 -G
必须尽量减少这种情况;训练
Q
(例如分类器;可以与D
共享图层)估计B
样本的原始A
输入 -G
必须最大化此值 .结束
在您的情况下,您似乎有以下培训数据:
真实图片
Ia
相应的草图图片
Ib
对应的类标签
c
并且你想训练一个生成器
G
,这样给定一个图像Ia
及其类标签c
,它会输出一个合适的草图图像Ib'
.总而言之,这是您拥有的大量信息,您可以监控您在条件图像和条件标签上的培训......灵感来自上述方法[1,2,5,6,7],这里是一个使用所有这些信息来训练你的条件的可能方式
G
:Network G:
输入:
Ia
c
输出:
Ib'
架构:最新的(例如U-Net,ResNet,......)
损失:
Ib'
和Ib
之间的L1 / L2损失,-D
损失,Q
损失Network D:
输入:
Ia
Ib
(真实配对),Ia
Ib'
(假配对)输出:"fakeness"标量/矩阵
架构:最新的(例如PatchGAN)
损失:"fakeness"估计的交叉熵
Network Q:
输入:
Ib
(实际样本,用于训练Q
),Ib'
(假样本,当通过G
进行反向传播时)输出:
c'
(估计等级)架构:最新的(例如LeNet,ResNet,VGG,......)
损失:
c
和c'
之间的交叉熵Training Phase:
火车
D
上一批真正的对Ia
Ib
然后对一批假对Ia
Ib'
;在一批实际样本上训练
Q
Ib
;修复
D
和Q
权重;训练
G
,将其生成的输出Ib'
传递给D
和Q
以通过它们进行反向传播 .参考文献
Isola,Phillip,et al . "Image-to-image translation with conditional adversarial networks." arXiv preprint(2017) . http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
朱俊彦,等 . "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593(2017) . http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
米尔扎,迈赫迪和西蒙奥辛德罗 . "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784(2014) . https://arxiv.org/pdf/1411.1784
Goodfellow,Ian,et al . "Generative adversarial nets."神经信息处理系统的进展 . 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
Chen,Xi,et al . "Infogan: Interpretable representation learning by information maximizing generative adversarial nets."神经信息处理系统的进展 . 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
Lee,Minhyeok和Junhee Seok . "Controllable Generative Adversarial Network." arXiv preprint arXiv:1708.00598(2017) . https://arxiv.org/pdf/1708.00598.pdf
Odena,Augustus,Christopher Olah和Jonathon Shlens . "Conditional image synthesis with auxiliary classifier gans." arXiv preprint arXiv:1610.09585(2016) . http://proceedings.mlr.press/v70/odena17a/odena17a.pdf
您应该修改鉴别器模型,要么具有两个输出,要么具有“n_classes 1”输出 .
警告:我没有在你的鉴别器的定义中看到它输出'true / false',我看到它输出图像......
它应该包含
GlobalMaxPooling2D
或GlobalAveragePooling2D
.最后和一个或多个
Dense
层进行分类 .如果告诉真/假,最后一个密集应该有1个单位 .
否则
n_classes + 1
单位 .所以,你的鉴别者的结局应该是这样的
鉴别器现在将输出
n_classes
加上"true/fake"符号(您将无法使用"categorical")或甚至"fake class"(然后您将其他类归零并使用分类)您生成的草图应该与目标一起传递给鉴别器,该目标将是假类与其他类的串联 .
选项1 - 使用“真/假”标志 . (不要使用“categorical_crossentropy”)
选项2 - 使用“假类”(可以使用“categorical_crossentropy”):
现在将所有内容连接到一个目标数组(分别对应于输入草图)
更新了培训方法
对于此培训方法,您的损失函数应为以下之一:
discriminator.compile(loss='binary_crossentropy', optimizer=....)
discriminator.compile(loss='categorical_crossentropy', optimizer=...)
码:
正确编译模型
当您创建“discriminator”和“discriminator_on_generator”时: