首页 文章

重塑Pytorch张量

提问于
浏览
0

我在Pytorch有一个大小为 (24, 2, 224, 224) 的张量 .

  • 24 =批量大小

  • 2 =表示前景和背景的矩阵

  • 224 =图像高度尺寸

  • 224 =图像宽度尺寸

这是执行二进制分段的CNN的输出 . 在2个矩阵的每个单元格中,存储该像素为前景或背景的概率:对于每个坐标,为 [n][0][h][w] + [n][1][h][w] = 1

我想将它重塑成一个大小为 (24, 1, 224, 224) 的张量 . 根据概率较高的矩阵,新图层中的值应为 01 .

我怎样才能做到这一点?我应该使用哪种功能?

1 回答

  • 1

    使用torch.argmax()(对于PyTorch 0.4):

    prediction = torch.argmax(tensor, dim=1) # with 'dim' the considered dimension 
    prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)
    

    如果PyTorch版本低于0.4.0,可以使用tensor.max(),它返回最大值及其索引(但不能在索引值上区分):

    _, prediction = tensor.max(dim=1)
    prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)
    

相关问题