我有一个形状张量 (16, 4096, 3)
. 我有另一个形状指数 (16, 32768, 3)
的张量 . 我试图收集 dim=1
的值 . 这最初是在pytorch中使用gather function完成的,如下所示 -
# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
请注意,输出 b
的大小与 idx
的大小相同 . 但是,当我应用张量流的 gather
函数时,我得到一个完全不同的输出 . 发现输出维度不匹配,如下所示 -
b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
我也试过使用 tf.gather_nd
但是徒劳无功 . 见下文-
b = tf.gather_nd(a, idx)
# b.shape (16, 32768)
为什么我会得到不同形式的张量?我想获得与pytorch计算的相同形状的张量 .
In other words, I want to know the tensorflow equivalent of torch.gather.
1 回答
对于2D情况,有一种方法可以做到:
但是,对于ND情况,这种方法可能非常复杂