首页 文章

Batched Gather / GatherND

提问于
浏览
1

我想知道是否有办法在TensorFlow中执行以下操作,使用 gather_nd 或类似的东西 .

我有两个张量:

  • values ,形状 [128, 100]

  • indices ,形状 [128, 3]

其中 indices 的每一行都包含沿 values 第二维的索引(对于同一行) . 我想使用 indices 索引 values . 例如,我想要这样做(使用松散表示法来表示张量):

values  = [[0, 0, 0, 1, 1, 0, 1], 
           [1, 1, 0, 0, 1, 0, 0]]
indices = [[2, 3, 6], 
           [0, 2, 3]]
batched_gather(values, indices) = [[0, 1, 1], [1, 0, 0]]

此操作将遍历 valuesindices 的每一行,并使用 indices 行在 values 行上执行收集 .

在TensorFlow中有一种简单的方法吗?

谢谢!

1 回答

  • 1

    不确定这是否符合"simple",但您可以使用 gather_nd

    def batched_gather(values, indices):
        row_indices = tf.range(0, tf.shape(values)[0])[:, tf.newaxis]
        row_indices = tf.tile(row_indices, [1, tf.shape(indices)[-1]])
        indices = tf.stack([row_indices, indices], axis=-1)
        return tf.gather_nd(values, indices)
    

    说明:想法是构造索引向量,例如 [0, 1] ,意思是"the value in the 0th row and 1st column" .
    列索引已在函数的 indices 参数中给出 .
    行索引是从0到例如0的简单进展 . 128(在您的示例中),但是根据每行的列索引数重复(平铺)(在您的示例中为3;如果此数字是固定的,则可以硬编码而不是使用 tf.shape ) .
    然后堆叠行索引和列索引以产生索引向量 . 在您的示例中,结果索引将是

    array([[[0, 2],
            [0, 3],
            [0, 6]],
    
           [[1, 0],
            [1, 2],
            [1, 3]]])
    

    gather_nd 产生所需的结果 .

相关问题