我试图在tensorflow seq2seq代码中打印张量attns . Seq2Seq.py
我试过了:
tf.Print(attns, [attns])
但它什么都不打印 .
我试过了
sess = tf.Session()
sess.run(attns) or attns.eval()
我抛出这种情况:InvalidArgumentError:你必须为占位符张量提供一个值
我也尝试过使用sess.run()
sess = tf.get_default_session()
aa = sess.run(attns)
在这种情况下,sess对象是None .
1 回答
tf.Print
不是"classic"操作指令,因为它们不是在符号的基于图形的代码中执行的 . 相反,需要的是计算图中的一个特定节点,每当您计算该节点时,该节点就会被触发 .这正是tf.Print所做的 . 它通过创建标识操作在任何其他节点周围创建“包装器”节点,该标识操作在触发时打印张量列表的值 .
this print function,
input_
(或者在您的情况下为attns
)的第一个参数是包装节点,data
(或者在您的情况下为[attns]
)是要打印的张量列表 .你想要做的是添加这一行:
这里,
attns
在attns
上被分配了一个打印包装器标识操作 - 因此张量attns
具有完全相同的行为,除了在计算它时,它还将打印[attns]
.