首页 文章

Tensorflow:tf.get_variable如何工作?

提问于
浏览
9

我从这个question读到了关于tf.get_variable的内容,还有一些来自tensorflow网站上的文档 . 但是,我仍然不清楚,无法在网上找到答案 .

tf.get_variable如何工作?例如:

var1 = tf.Variable(3.,dtype=float64)
var2 = tf.get_variable("var1",[],dtype=tf.float64)

这是否意味着 var2another 变量,初始化类似于 var1 ?或 var2var1 的别名(我试过,它似乎没有)?

var1var2 如何相关?

当我们得到的变量不存在时,如何构造变量?

1 回答

  • 17

    tf.get_variable(name) 在tensorflow图中创建一个名为 name 的新变量(如果 name 已存在于当前作用域中,则添加_) .

    在您的示例中,您正在创建名为 var1python 变量 .

    ** Tensorflow图中该变量的名称不是** var1 ,而是 Variable:0 .

    您定义的每个节点都有自己可以指定的名称,或者让tensorflow给出一个默认(并且始终不同)的名称 . 您可以看到 name 值访问python变量的 name 属性 . (即 print(var1.name) ) .

    在第二行,您要定义 Python variable var2 ,其名称 in the tensorflow graphvar1 .

    剧本

    import tensorflow as tf
    
    var1 = tf.Variable(3.,dtype=tf.float64)
    print(var1.name)
    var2 = tf.get_variable("var1",[],dtype=tf.float64)
    print(var2.name)
    

    事实上印刷品:

    Variable:0
    var1:0
    

    相反,如果你想在张量流图中定义一个名为 var1 的变量(节点),然后获得对该节点的引用,那么你只需使用 tf.get_variable("var1") ,因为它将创建一个新的不同变量valled var1_1 .

    这个脚本

    var1 = tf.Variable(3.,dtype=tf.float64, name="var1")
    print(var1.name)
    var2 = tf.get_variable("var1",[],dtype=tf.float64)
    print(var2.name)
    

    打印:

    var1:0
    var1_1:0
    

    如果要创建对节点 var1 的引用,首先要:

    • 必须用 tf.get_variable 替换 tf.Variable . 使用 tf.Variable 创建的变量无法共享,而后者则可以 .

    • 知道 var1scope 是什么,并在声明引用时允许该范围的 reuse .

    查看代码是更好的理解方式

    import tensorflow as tf
    
    #var1 = tf.Variable(3.,dtype=tf.float64, name="var1")
    var1 = tf.get_variable(initializer=tf.constant_initializer(3.), dtype=tf.float64, name="var1", shape=())
    current_scope = tf.contrib.framework.get_name_scope()
    print(var1.name)
    with tf.variable_scope(current_scope, reuse=True):
        var2 = tf.get_variable("var1",[],dtype=tf.float64)
        print(var2.name)
    

    输出:

    var1:0
    var1:0
    

相关问题