首页 文章

tensorflow相当于torch.gather

提问于
浏览
0

我有一个形状张量 (16, 4096, 3) . 我有另一个形状指数 (16, 32768, 3) 的张量 . 我试图收集 dim=1 的值 . 这最初是在pytorch中使用gather function完成的,如下所示 -

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

请注意,输出 b 的大小与 idx 的大小相同 . 但是,当我应用张量流的 gather 函数时,我得到一个完全不同的输出 . 发现输出维度不匹配,如下所示 -

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

我也试过使用 tf.gather_nd 但是徒劳无功 . 见下文-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

为什么我会得到不同形式的张量?我想获得与pytorch计算的相同形状的张量 .

In other words, I want to know the tensorflow equivalent of torch.gather.

1 回答

  • 1

    对于2D情况,有一种方法可以做到:

    # a.shape (16L, 10L)
    # idx.shape (16L,1)
    idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
    b = tf.gather_nd(a,idx)
    

    但是,对于ND情况,这种方法可能非常复杂

相关问题