我想增加两个张量,这是我得到的:
-
A
张量形状(20, 96, 110)
-
B
张量形状(20, 16, 110)
第一个索引是批量大小 . 我想要做的主要是从 B
- (20, 1, 110)
中获取每个张量,例如,我希望将每个张量乘以 A
张量 (20, n, 110)
. 所以产品将在最后:张量 AB
哪个形状是 (20, 96 * 16, 110)
.
所以我想通过 B
广播将每个张量乘以 A
. PyTorch中有一种方法吗?
1 回答
使用torch.einsum后跟torch.reshape:
例: