首页 文章

张量流中的广播点积

提问于
浏览
0

在tensorflow中,我有以下问题 .

我有一个形状[batch_size,dim_a,dim_b]的张量 m 和形状[batch_size,dim_b]的矩阵 u .

M = tf.constant(shape=[batch_size, sequence_size, embed_dim])
U = tf.constant(shape=[batch_size, embed_dim])

我要实现的是我的批次的每个索引的[i,dim_a,dim_b] x [i,dim_b]的点积 .

P[i] = tf.matmul(M[i, :, :], tf.expand_dims(U[i, :], 1)) for each i.

基本上,在批轴上广泛地点击点积 . 这是可能的,我该如何实现?

1 回答

  • 2

    这可以通过tf.einsum()来实现:

    import tensorflow as tf
    import numpy as np
    
    batch_size = 2
    sequence_size = 3
    embed_dim = 4
    
    M = tf.constant(range(batch_size * sequence_size * embed_dim), shape=[batch_size, sequence_size, embed_dim])
    U = tf.constant(range(batch_size, embed_dim), shape=[batch_size, embed_dim])
    
    prod = tf.einsum('bse,be->bs', M, U)
    
    with tf.Session():
      print "M"
      print M.eval()
      print
      print "U"
      print U.eval()
      print
      print "einsum result"
      print prod.eval()
      print
    
      print "numpy, example 0"
      print np.matmul(M.eval()[0], U.eval()[0])
      print
      print "numpy, example 1"
      print np.matmul(M.eval()[1], U.eval()[1])
    

    输出:

    M
    [[[ 0  1  2  3]
      [ 4  5  6  7]
      [ 8  9 10 11]]
    
     [[12 13 14 15]
      [16 17 18 19]
      [20 21 22 23]]]
    
    U
    [[2 3 3 3]
     [3 3 3 3]]
    
    einsum result
    [[ 18  62 106]
     [162 210 258]]
    
    numpy, example 0
    [ 18  62 106]
    
    numpy, example 1
    [162 210 258]
    

相关问题