首页 文章

如何在Tensorflow中用二维张量索引三维张量?

提问于
浏览
0

我试图使用二维张量来索引Tensorflow中的三维张量 . 例如,我有 x 形状 [2, 3, 4]

[[[ 0,  1,  2,  3],
  [ 4,  5,  6,  7],
  [ 8,  9, 10, 11]],

 [[12, 13, 14, 15],
  [16, 17, 18, 19],
  [20, 21, 22, 23]]]

并且我想用另一个形状 [2, 3] 的张量 y 索引它,其中 y 的每个元素索引 x 的最后一个维度 . 例如,如果我们有 y 喜欢:

[[0, 2, 3],
 [1, 0, 2]]

输出形状 [2, 3]

[[0, 6, 11],
 [13, 16, 22]]

1 回答

  • 1

    使用 tf.meshgrid 创建索引,然后使用 tf.gather_nd 提取元素:

    # create a list of indices for except the last axis
    idx_except_last = tf.meshgrid(*[tf.range(s) for s in x.shape[:-1]], indexing='ij')
    
    # concatenate with last axis indices
    idx = tf.stack(idx_except_last + [y], axis=-1)
    
    # gather elements based on the indices
    tf.gather_nd(x, idx).eval()
    
    # array([[ 0,  6, 11],
    #        [13, 16, 22]])
    

相关问题