我在将一些代码从tensorflow移植到pytorch时遇到了一些麻烦 .
所以我有一个尺寸为10x30的矩阵,代表10个例子,每个都有30个特征 . 然后,我有另一个尺寸为10x5的矩阵,其中包含第一个矩阵中每个示例的5个最接近的示例的索引 . 我想使用第二个矩阵中包含的索引“聚集”第一个矩阵中每个示例的5个壁橱示例,让我得到一个形状为10x5x30的3d张量 .
在tensorflow中,这是通过 tf.gather(matrix1, matrix2)
完成的 . 有谁知道我怎么能在pytorch做到这一点?
1 回答
这个怎么样?
它使用带有整数数组的索引技巧 .