由于一些张力问题,我正试图以其他方式实施推特 .

在tf.tensordot op中,在3d * 2d matmul的过程中有一个批量大小的修改 .

M = tf.random_normal((batch_size, n, m))  # (3,6,9)
N = tf.random_normal((m, p)) # (9,9)

MT = tf.reshape(M, [batch_size*n, m]) # (18,9)
MTN = tf.matmul(M_T, N) # (18,9)

MN = tf.reshape(MTN, [batch_size, n, p]) # (3,6,9)

但我想要3d * 2d matmul而不改变批量大小 . 有办法吗?