在tensorflow中,我有以下问题 .
我有一个形状[batch_size,dim_a,dim_b]的张量 m 和形状[batch_size,dim_b]的矩阵 u .
M = tf.constant(shape=[batch_size, sequence_size, embed_dim])
U = tf.constant(shape=[batch_size, embed_dim])
我要实现的是我的批次的每个索引的[i,dim_a,dim_b] x [i,dim_b]的点积 .
P[i] = tf.matmul(M[i, :, :], tf.expand_dims(U[i, :], 1)) for each i.
基本上,在批轴上广泛地点击点积 . 这是可能的,我该如何实现?
1 回答
这可以通过tf.einsum()来实现:
输出: