首页 文章

打印tensorflow变量seq2seq

提问于
浏览
0

我试图在tensorflow seq2seq代码中打印张量attns . Seq2Seq.py

我试过了:

tf.Print(attns, [attns])

但它什么都不打印 .

我试过了

sess = tf.Session() 
sess.run(attns) or attns.eval()

我抛出这种情况:InvalidArgumentError:你必须为占位符张量提供一个值

我也尝试过使用sess.run()

sess = tf.get_default_session()
aa = sess.run(attns)

在这种情况下,sess对象是None .

1 回答

  • 1

    tf.Print 不是"classic"操作指令,因为它们不是在符号的基于图形的代码中执行的 . 相反,需要的是计算图中的一个特定节点,每当您计算该节点时,该节点就会被触发 .

    这正是tf.Print所做的 . 它通过创建标识操作在任何其他节点周围创建“包装器”节点,该标识操作在触发时打印张量列表的值 .

    this print functioninput_ (或者在您的情况下为 attns )的第一个参数是包装节点, data (或者在您的情况下为 [attns] )是要打印的张量列表 .

    你想要做的是添加这一行:

    attns = tf.Print(attns, [attns])
    

    这里, attnsattns 上被分配了一个打印包装器标识操作 - 因此张量 attns 具有完全相同的行为,除了在计算它时,它还将打印 [attns] .

相关问题