我想知道是否有办法在TensorFlow中执行以下操作,使用 gather_nd
或类似的东西 .
我有两个张量:
-
values
,形状[128, 100]
, -
indices
,形状[128, 3]
,
其中 indices
的每一行都包含沿 values
第二维的索引(对于同一行) . 我想使用 indices
索引 values
. 例如,我想要这样做(使用松散表示法来表示张量):
values = [[0, 0, 0, 1, 1, 0, 1],
[1, 1, 0, 0, 1, 0, 0]]
indices = [[2, 3, 6],
[0, 2, 3]]
batched_gather(values, indices) = [[0, 1, 1], [1, 0, 0]]
此操作将遍历 values
和 indices
的每一行,并使用 indices
行在 values
行上执行收集 .
在TensorFlow中有一种简单的方法吗?
谢谢!
1 回答
不确定这是否符合"simple",但您可以使用
gather_nd
:说明:想法是构造索引向量,例如
[0, 1]
,意思是"the value in the 0th row and 1st column" .列索引已在函数的
indices
参数中给出 .行索引是从0到例如0的简单进展 . 128(在您的示例中),但是根据每行的列索引数重复(平铺)(在您的示例中为3;如果此数字是固定的,则可以硬编码而不是使用
tf.shape
) .然后堆叠行索引和列索引以产生索引向量 . 在您的示例中,结果索引将是
和
gather_nd
产生所需的结果 .