首页 文章

pytorch相当于tf.gather

提问于
浏览
0

我在将一些代码从tensorflow移植到pytorch时遇到了一些麻烦 .

所以我有一个尺寸为10x30的矩阵,代表10个例子,每个都有30个特征 . 然后,我有另一个尺寸为10x5的矩阵,其中包含第一个矩阵中每个示例的5个最接近的示例的索引 . 我想使用第二个矩阵中包含的索引“聚集”第一个矩阵中每个示例的5个壁橱示例,让我得到一个形状为10x5x30的3d张量 .

在tensorflow中,这是通过 tf.gather(matrix1, matrix2) 完成的 . 有谁知道我怎么能在pytorch做到这一点?

1 回答

  • 0

    这个怎么样?

    matrix1 = torch.randn(10, 30)
    matrix2 = torch.randint(high=10, size=(10, 5))
    gathered = matrix1[matrix2]
    

    它使用带有整数数组的索引技巧 .

相关问题