首页 文章

Tensorflow - tf.variable_scope,GAN的重用参数

提问于
浏览
0

我正在尝试为项目构建一个GAN,我真的很想了解这个变量在tensorflow的variable_scope中的共享是如何工作的 .

为了构建GAN,我有一个发生器网络和两个鉴别器网络:一个鉴别器被馈送真实图像,一个鉴别器被馈送由生成器创建的伪图像 . 重要的是,用真实图像馈送的鉴别器和用假图像馈送的鉴别器需要共享相同的权重 . 为了做到这一点,我需要分享权重 .

我有一个鉴别器和生成器定义,让我们说:

def discriminator(images, reuse=False):
    with variable_scope("discriminator", reuse=reuse):

        #.... layer definitions, not important here
        #....
        logits = tf.layers.dense(X, 1)
        logits = tf.identity(logits, name="logits")
        out = tf.sigmoid(logits, name="out")
        # 14x14x64
    return logits, out

def generator(input_z, reuse=False):
    with variable_scope("generator", reuse=reuse):

        #.. not so important 
        out = tf.tanh(logits)

    return out

现在调用生成器和鉴别器函数:

g_model = generator(input_z)
d_model_real, d_logits_real = discriminator(input_real)

#Here , reuse=True should produce the weight sharing between d_model_real, d_logits_real
#and d_model_fake and d_logits_fake.. why?
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

为什么 second 调用中的reuse = True语句会产生权重的共享?据我所知,您需要决定在第一次调用中重用变量,以便稍后在程序中的某个地方使用它们 .

如果有人可以向我解释这一点,我会非常高兴,我找不到这个主题的好消息来源,这对我来说似乎真的很混乱和复杂 . 谢谢!

1 回答

  • 2

    在引擎盖下,使用 tf.get_variable() 创建变量 .

    此函数将使用范围为变量名称添加前缀,并在创建新范围之前检查它是否存在 .

    例如,如果您在范围 "fc" 并调用 tf.get_variable("w", [10,10]) ,则变量名称将为 "fc/w:0" .

    现在当你再次执行此操作时,如果 reuse=True ,则范围将再次为 "fc" ,并且get_variable将重用变量 "fc/w:0" .

    但是,如果 reuse=False ,您将收到错误,因为变量 "fc/w:0" 已经存在,提示您使用其他名称或使用 reuse=True .

    例:

    In [1]: import tensorflow as tf
    
    In [2]: with tf.variable_scope("fc"):
       ...:      v = tf.get_variable("w", [10,10])
       ...:
    
    In [3]: v
    Out[3]: <tf.Variable 'fc/w:0' shape=(10, 10) dtype=float32_ref>
    
    In [4]: with tf.variable_scope("fc"):
       ...:      v = tf.get_variable("w", [10,10])
       ...:
    ValueError: Variable fc/w already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
    
    In [5]: with tf.variable_scope("fc", reuse=True):
       ...:      v = tf.get_variable("w", [10,10])
       ...:
    
    In [6]: v
    Out[6]: <tf.Variable 'fc/w:0' shape=(10, 10) dtype=float32_ref>
    

    请注意,您可以仅实例化一个鉴别器,而不是共享权重 . 然后,您可以使用placeholder_with_default决定使用实际数据或生成的数据来提供它 .

相关问题