首页 文章

Keras的生成性对抗网络不能像预期的那样工作

提问于
浏览
1

我是Keras机器学习的初学者 . 我正在尝试了解生成性对抗网络(GAN) . 为此,我试图编写一个简单的例子 . 我生成数据具有以下功能:

def genReal(l):
    realX = []
    for i in range(l):
        x = []
        y = []
        for i in np.arange(0.0, 1.0, 0.02):
            x.append(i + np.random.normal(0,0.01))
            y.append(-abs(i-0.5)+0.5+ np.random.normal(0,0.01))

        data = np.array(list(zip(x, y)))
        data = np.reshape(data, (100))
        data.clip(0,1)
        realX.append(data)

    realX = np.array(realX)
    return realX

使用此功能进行格式化的数据与以下示例类似:

enter image description here
现在的目标应该是训练神经网络以生成类似的数据 . 对于GAN,我们需要一个 Generator 网络,我建模如下:

generator = Sequential()
generator.add(Dense(128, input_shape=(100,), activation='relu'))
generator.add(Dropout(rate=0.2))
generator.add(Dense(128, activation='relu'))
generator.add(Dropout(rate=0.2))
generator.add(Dense(100, activation='sigmoid'))
generator.compile(loss='mean_squared_error', optimizer='adam')

一个看起来像这样的鉴别器:

discriminator = Sequential()
discriminator.add(Dense(128, input_shape=(100,), activation='relu'))
discriminator.add(Dropout(rate=0.2))
discriminator.add(Dense(128, activation='relu'))
discriminator.add(Dropout(rate=0.2))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='mean_squared_error', optimizer='adam')

组合模型:

ganInput = Input(shape=(100,))
x = generator(ganInput)
ganOutput = discriminator(x)

GAN = Model(inputs=ganInput, outputs=ganOutput)
GAN.compile(loss='binary_crossentropy', optimizer='adam')

我有一个产生噪音的函数(随机数组)

def noise(l):
   noise = np.array([np.random.uniform(0, 1, size=[l, ])])
   return noise

然后我正在训练模型:

for i in range(1000000):
    fake = generator.predict(noise(100))
    print(i, "==>", discriminator.predict(fake))
    discriminator.train_on_batch(genReal(1), np.array([1]))
    discriminator.train_on_batch(fake, np.array([0]))

    discriminator.trainable = False
    GAN.train_on_batch(noise(100), np.array([1]))
    discriminator.trainable = True

就像你可以看到我已经尝试训练模型为1. Mio迭代 . 但是,生成器输出后来看起来像这样的数据(尽管输入不同):

enter image description here

绝对不是我想要的 . 所以我的问题是:是1. Mio Iterations不够,或者我的程序概念有什么不对

编辑:

这是我绘制数据的函数:

def plotData(data):
    x = np.reshape(data,(50,2))
    x = x.tolist()
    plt.scatter(list(zip(*x))[0],list(zip(*x))[1], c=col)

1 回答

  • 2

    您的实现问题是 discriminator.trainable = False 在编译 discriminator 后没有任何影响 . 因此,当您执行 GAN.train_on_batch 时,所有权重(来自鉴别器和生成器网络)都是可训练的 .

    此问题的解决方案是在编译 discriminator 之后和编译 GAN 之前设置 discriminator.trainable = False

    discriminator.compile(loss='mean_squared_error', optimizer='adam')    
    discriminator.trainable = False
    
    ganInput = Input(shape=(100,))
    x = generator(ganInput)
    ganOutput = discriminator(x)
    
    GAN = Model(inputs=ganInput, outputs=ganOutput)
    GAN.compile(loss='binary_crossentropy', optimizer='adam')
    

    NOTE . 我已经绘制了您的数据,看起来更像是这样:
    Generated data

相关问题