首页 文章

在图形创建时跟踪张量形状

提问于
浏览
4

在某些情况下,tensorflow似乎能够在图形创建时检查张量的值,而在其他情况下,这会失败 .

>>> shape = [constant([2])[0], 3]
>>> reshape([1,2,3,4,5,6], shape)
<tf.Tensor 'Reshape_13:0' shape=(2, 3) dtype=int32>
>>> zeros(shape)
<tf.Tensor 'zeros_2:0' shape=(?, 3) dtype=float32>

在上面的例子中,reshape()可以看到传入的张量因为形状的值为2,结果输出的形状为(2,3)但是零()不能,静态形状是(?,3) ) . 差异的原因是什么?

我的同事发布了Determining tensor shapes at time of graph creation in TensorFlow,这是基于相同的基础问题,但他问的是如何最好地使用tensorflow来解决这类问题的问题,而我的问题是为什么tensorflow会以这种方式运行 . 这是一个错误吗?

1 回答

  • 2

    TD; DR:

    • tf.reshape 可以推断出输出的形状但 tf.zeros 不能;

    • shape 支持两个函数的整数(如 static/definite )和张量(为 dynamic/indefinite ) .


    代码更具体,更清晰:

    shape = [tf.constant([2])[0], tf.constant([3])[0]]
    print(tf.reshape([1,2,3,4,5,6], shape))  
    # Tensor("Reshape:0", shape=(?, ?), dtype=int32)
    print(tf.zeros(shape))  
    # Tensor("zeros:0", shape=(?, ?), dtype=float32)
    

    还有这个:

    shape = [tf.constant([5])[0], 3]
    print tf.reshape([1,2,3,4,5,6], shape)  
    # Tensor("Reshape:0", shape=(2, 3), dtype=int32)
    # This will cause an InvalidArgumentError at running time!
    

    当使用 Tensor (如 tf.constant([2])[0] )作为 shape 来创建另一个 Tensor (如 tf.zeros(shape) )时,图形创建时形状总是不确定的 . 但是, tf.reshape() 是不同的 . 它可以使用输入的形状和给定形状(静态部分)推断输出的形状 .

    在您的代码中, 3 是一个静态整数,并给出了输入的形状( [6] );形状 (2, 3) 实际上是通过推断获得的,而不是提供的 . 这可以在代码的第二部分中证明 . 虽然我给了一个 tf.constant([5]) ,但形状并没有改变 . (图表创建时没有错误,但在运行时出现错误!)

相关问题