首页 文章

没有为4D 3D张量的张量流中的tf.matmul广播

提问于
浏览
3

首先我在这里找到另一个问题No broadcasting for tf.matmul in TensorFlow
但这个问题并没有解决我的问题 .

我的问题是一批矩阵乘以另一批向量 .

x=tf.placeholder(tf.float32,shape=[10,1000,3,4])
y=tf.placeholder(tf.float32,shape=[1000,4])

x是一批矩阵 . 有10 * 1000个矩阵 . 每个矩阵都是有形的[3,4]
y是一批向量 . 有1000个向量 . 每个向量都是形状[4]
Dim 1 of x and dim 0 of y are the same. (这是1000)
如果tf.matmul支持广播,我可以写

y=tf.reshape(y,[1,1000,4,1])
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3])

但是tf.matmul不支持广播
如果我使用上面引用的问题的方法

x=tf.reshape(x,[10*1000*3,4])
y=tf.transpose(y,perm=[1,0]) #[4,1000]
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3,1000])

结果是形状[10,1000,3,1000],而不是[10,1000,3] .
我不知道如何删除多余的1000
如何获得与支持广播的tf.matmul相同的结果?

2 回答

  • 3

    我自己解决了 .

    x=tf.transpose(x,perm=[1,0,2,3]) #[1000,10,3,4]
    x=tf.reshape(x,[1000,30,4])
    y=tf.reshape(y,[1000,4,1])
    result=tf.matmul(x,y) #[1000,30,1]
    result=tf.reshape(result,[1000,10,3])
    result=tf.transpose(result,perm=[1,0,2]) #[10,1000,3]
    
  • 0

    here所示,您可以使用函数来解决:

    def broadcast_matmul(A, B):
      "Compute A @ B, broadcasting over the first `N-2` ranks"
      with tf.variable_scope("broadcast_matmul"):
        return tf.reduce_sum(A[..., tf.newaxis] * B[..., tf.newaxis, :, :],
                             axis=-2)
    

相关问题