首页 文章

如何明确地广播张量以匹配张量流中的另一个形状?

提问于
浏览
12

我有三个张量, A, B and C 在tensorflow, AB 都是形状 (m, n, r)C 是二元张量的形状 (m, n, 1) .

我想根据 C 的值从A或B中选择元素 . 显而易见的工具是 tf.select ,但是它没有广播语义,所以我需要首先将 C 显式广播为与A和B相同的形状 .

这将是我第一次尝试如何做到这一点,但它不喜欢我将张量( tf.shape(A)[2] )混合到形状列表中 .

import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))

C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)

这里的正确方法是什么?

3 回答

  • 12

    EDIT: 在0.12rc0之后的所有版本的TensorFlow中,问题中的代码直接起作用 . TensorFlow会自动将张量和Python数字叠加到张量参数中 . 使用 tf.pack() 的以下解决方案仅在0.12rc0之前的版本中需要 . 请注意, tf.pack() 在TensorFlow 1.0中重命名为tf.stack() .


    您的解决方案非常接近工作 . 您应该替换该行:

    C = tf.tile(C, [1,1,tf.shape(C)[2]])
    

    ......以下内容:

    C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]]))
    

    (问题的原因是TensorFlow不会隐式地将张量和Python文字列表转换为张量.tf.pack()采用张量列表,因此它将转换其输入中的每个元素( 11tf.shape(C)[2] )由于每个元素都是一个标量,因此结果将是一个向量 . )

  • 3

    这是一个肮脏的黑客:

    import tensorflow as tf
    
    def broadcast(tensor, shape):
        return tensor + tf.zeros(shape, dtype=tensor.dtype)
    
    A = tf.random_normal([20, 100, 10])
    B = tf.random_normal([20, 100, 10])
    C = tf.random_normal([20, 100, 1])
    
    C = broadcast(C, A.shape)
    D = tf.select(C, A, B)
    
  • 0
    import tensorflow as tf
    
    def broadcast(tensor, shape):
         """Broadcasts ``x`` to have shape ``shape``.
                                                                       |
         Uses ``tf.Assert`` statements to ensure that the broadcast is
         valid.
    
         First calculates the number of missing dimensions in 
         ``tf.shape(x)`` and left-pads the shape of ``x`` with that many 
         ones. Then identifies the dimensions of ``x`` that require
         tiling and tiles those dimensions appropriately.
    
         Args:
             x (tf.Tensor): The tensor to broadcast.
             shape (Union[tf.TensorShape, tf.Tensor, Sequence[int]]): 
                 The shape to broadcast to.
    
         Returns:
             tf.Tensor: ``x``, reshaped and tiled to have shape ``shape``.
    
         """
         with tf.name_scope('broadcast') as scope:
             shape_x = tf.shape(x)
             rank_x = tf.shape(shape0)[0]
             shape_t = tf.convert_to_tensor(shape, preferred_dtype=tf.int32)
             rank_t = tf.shape(shape1)[0]
    
             with tf.control_dependencies([tf.Assert(
                 rank_t >= rank_x,
                 ['len(shape) must be >= tf.rank(x)', shape_x, shape_t],
                 summarize=255
             )]):
                 missing_dims = tf.ones(tf.stack([rank_t - rank_x], 0), tf.int32)
    
             shape_x_ = tf.concat([missing_dims, shape_x], 0)
             should_tile = tf.equal(shape_x_, 1)
    
             with tf.control_dependencies([tf.Assert(
                 tf.reduce_all(tf.logical_or(tf.equal(shape_x_, shape_t), should_tile),
                 ['cannot broadcast shapes', shape_x, shape_t],
                 summarize=255
             )]):
                 multiples = tf.where(should_tile, shape_t, tf.ones_like(shape_t))
                 out = tf.tile(tf.reshape(x, shape_x_), multiples, name=scope)
    
             try:
                 out.set_shape(shape)
             except:
                 pass
    
             return out
    
    A = tf.random_normal([20, 100, 10])
    B = tf.random_normal([20, 100, 10])
    C = tf.random_normal([20, 100, 1])
    
    C = broadcast(C, A.shape)
    D = tf.select(C, A, B)
    

相关问题