首页 文章

张量的中间维度的分散操作

提问于
浏览
-1

我有一个3d张量,我需要在第二维中的某些位置保留向量,并将剩余的向量归零 . 位置指定为1d数组 . 我认为最好的方法是将张量乘以二进制掩码 .

这是一个简单的Numpy版本:

A.shape: (b, n, m) 
indices.shape: (b)

mask = np.zeros(A.shape)
for i in range(b):
  mask[i][indices[i]] = 1
result = A*mask

因此,对于A中的每个nxm矩阵,我需要保留由索引指定的行,并将其余部分清零 .

我正在尝试使用tf.scatter_nd op在TensorFlow中执行此操作,但我无法弄清楚索引的正确形状:

shape = tf.constant([3,5,4])
A = tf.random_normal(shape)       
indices = tf.constant([2,1,4])   #???   
updates = tf.ones((3,4))           
mask = tf.scatter_nd(indices, updates, shape) 
result = A*mask

1 回答

  • 0

    这是一种方法,创建一个掩码并使用 tf.where

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    tfe.enable_eager_execution()
    
    shape = tf.constant([3,5,4])
    A = tf.random_normal(shape)
    
    array_shape = tf.shape(A)
    indices = tf.constant([2,1,4])
    non_zero_indices = tf.stack((tf.range(array_shape[0]), indices), axis=1)
    should_keep_row = tf.scatter_nd(non_zero_indices, tf.ones_like(indices),
                                    shape=[array_shape[0], array_shape[1]])
    print("should_keep_row", should_keep_row)
    masked = tf.where(tf.cast(tf.tile(should_keep_row[:, :, None],
                                      [1, 1, array_shape[2]]), tf.bool),
                       A,
                       tf.zeros_like(A))
    print("masked", masked)
    

    打印:

    should_keep_row tf.Tensor(
    [[0 0 1 0 0]
     [0 1 0 0 0]
     [0 0 0 0 1]], shape=(3, 5), dtype=int32)
    masked tf.Tensor(
    [[[ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]
      [ 0.02036316 -0.07163608 -3.16707373  1.31406844]
      [ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]]
    
     [[ 0.          0.          0.          0.        ]
      [-0.76696759 -0.28313264  0.87965059 -1.28844094]
      [ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]]
    
     [[ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]
      [ 0.          0.          0.          0.        ]
      [ 1.03188455  0.44305769  0.71291149  1.59758031]]], shape=(3, 5, 4), dtype=float32)
    

    (该示例使用了急切执行,但相同的操作将与会话中的图形执行一起使用)

相关问题