我有一个3d张量,我需要在第二维中的某些位置保留向量,并将剩余的向量归零 . 位置指定为1d数组 . 我认为最好的方法是将张量乘以二进制掩码 .
这是一个简单的Numpy版本:
A.shape: (b, n, m)
indices.shape: (b)
mask = np.zeros(A.shape)
for i in range(b):
mask[i][indices[i]] = 1
result = A*mask
因此,对于A中的每个nxm矩阵,我需要保留由索引指定的行,并将其余部分清零 .
我正在尝试使用tf.scatter_nd op在TensorFlow中执行此操作,但我无法弄清楚索引的正确形状:
shape = tf.constant([3,5,4])
A = tf.random_normal(shape)
indices = tf.constant([2,1,4]) #???
updates = tf.ones((3,4))
mask = tf.scatter_nd(indices, updates, shape)
result = A*mask
1 回答
这是一种方法,创建一个掩码并使用
tf.where
:打印:
(该示例使用了急切执行,但相同的操作将与会话中的图形执行一起使用)