首页 文章

了解张量流和变量共享中的variable_scope和name_scope

提问于
浏览
0

我喜欢做fowolling操作:给出4个张量 abcd 和一个权重变量 w ,计算 W*aW*bW*cW*d 但是在不同的子图中 . 代码,我有如下:

def forward(inputs):
  w = tf.get_variable("weights", ...)
  return tf.matmult(w, inputs)

with tf.name_scope("group_1"):
  a = tf.placeholder(...)
  b = tf.placeholder(...)
  c = tf.placeholder(...)

  aa = forward(a)
  bb = forward(b)
  cc = forward(c)

with tf.name_scope("group_2):
  d = tf.placeholder(...)

  tf.get_variable_scope().reuse_variable()
  dd = forward(d)

这个例子似乎运行但我不确定变量 W 是否被重用,特别是在 group_1 当我添加 tf.get_variable_scope.reuse_variable() 时,我得到一个错误,说没有可共享的变量 . 当我在tensorboard中可视化图形时,我在 group_1 子图中有几个 weigths_* .

1 回答

  • 1

    以下代码执行您想要的操作:

    import tensorflow as tf
    
    def forward(inputs):
        init = tf.random_normal_initializer()
        w = tf.get_variable("weights", shape=(3,2), initializer=init)
        return tf.matmul(w, inputs)
    
    with tf.name_scope("group_1"):
        a = tf.placeholder(tf.float32, shape=(2, 3), name="a")
        b = tf.placeholder(tf.float32, shape=(2, 3), name="b")
        c = tf.placeholder(tf.float32, shape=(2, 3), name="c")
        with tf.variable_scope("foo", reuse=False):
            aa = forward(a)
        with tf.variable_scope("foo", reuse=True):
            bb = forward(b)
            cc = forward(c)
    
    with tf.name_scope("group_2"):
        d = tf.placeholder(tf.float32, shape=(2, 3), name="d")
        with tf.variable_scope("foo", reuse=True):
            dd = forward(d)
    
    init = tf.initialize_all_variables()
    
    with tf.Session() as sess:
        sess.run(init)
        print(bb.eval(feed_dict={b: np.array([[1,2,3],[4,5,6]])}))
        for var in tf.all_variables():
            print(var.name)
            print(var.eval())
    

    需要了解的一些重要事项:

    • name_scope() 影响所有操作 except variables created with get_variable() .

    • 要在范围中放置变量,需要使用 variable_scope() . 例如,占位符 abc 实际上名为 "group_1/a""group_1/b""group_1/c""group_1/d" ,但 weights 变量名为 "foo/weights" . 所以 get_variable("weights") 在名称范围 "group_1" 和变量范围 "foo" 实际上寻找 "foo/weights" .

    如果您不确定存在哪些变量以及它们的命名方式,则 all_variables() 函数很有用 .

相关问题