我从这个question读到了关于tf.get_variable的内容,还有一些来自tensorflow网站上的文档 . 但是,我仍然不清楚,无法在网上找到答案 .
tf.get_variable如何工作?例如:
var1 = tf.Variable(3.,dtype=float64)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
这是否意味着 var2 是 another 变量,初始化类似于 var1 ?或 var2 是 var1 的别名(我试过,它似乎没有)?
var1 和 var2 如何相关?
当我们得到的变量不存在时,如何构造变量?
1 回答
tf.get_variable(name)
在tensorflow图中创建一个名为name
的新变量(如果name
已存在于当前作用域中,则添加_) .在您的示例中,您正在创建名为
var1
的 python 变量 .** Tensorflow图中该变量的名称不是**
var1
,而是Variable:0
.您定义的每个节点都有自己可以指定的名称,或者让tensorflow给出一个默认(并且始终不同)的名称 . 您可以看到
name
值访问python变量的name
属性 . (即print(var1.name)
) .在第二行,您要定义 Python variable
var2
,其名称 in the tensorflow graph 是var1
.剧本
事实上印刷品:
相反,如果你想在张量流图中定义一个名为
var1
的变量(节点),然后获得对该节点的引用,那么你只需使用tf.get_variable("var1")
,因为它将创建一个新的不同变量valledvar1_1
.这个脚本
打印:
如果要创建对节点
var1
的引用,首先要:必须用
tf.get_variable
替换tf.Variable
. 使用tf.Variable
创建的变量无法共享,而后者则可以 .知道
var1
的scope
是什么,并在声明引用时允许该范围的reuse
.查看代码是更好的理解方式
输出: