首页 文章

Pytorch广播两个张量的产品

提问于
浏览
0

我想增加两个张量,这是我得到的:

  • A 张量形状 (20, 96, 110)

  • B 张量形状 (20, 16, 110)

第一个索引是批量大小 . 我想要做的主要是从 B - (20, 1, 110) 中获取每个张量,例如,我希望将每个张量乘以 A 张量 (20, n, 110) . 所以产品将在最后:张量 AB 哪个形状是 (20, 96 * 16, 110) .

所以我想通过 B 广播将每个张量乘以 A . PyTorch中有一种方法吗?

1 回答

  • 1

    使用torch.einsum后跟torch.reshape

    AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])
    

    例:

    import numpy as np
    import torch
    
    # A of shape (2, 3, 2):
    A = torch.from_numpy(np.array([[[1, 1], [2, 2], [3, 3]], 
                                   [[4, 4], [5, 5], [6, 6]]]))
    # B of shape (2, 2, 2):
    B = torch.from_numpy(np.array([[[1, 1], [10, 10]], 
                                   [[2, 2], [20, 20]]]))
    
    # AB of shape (2, 3*2, 2):
    AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])
    # tensor([[[ 1, 1], [ 10, 10], [  2,  2], [ 20,   20], [ 3,   3], [ 30,  30]],
    #         [[ 8, 8], [ 80, 80], [ 10, 10], [ 100, 100], [ 12, 12], [ 120, 120]]])
    

相关问题