首先我在这里找到另一个问题No broadcasting for tf.matmul in TensorFlow
但这个问题并没有解决我的问题 .
我的问题是一批矩阵乘以另一批向量 .
x=tf.placeholder(tf.float32,shape=[10,1000,3,4])
y=tf.placeholder(tf.float32,shape=[1000,4])
x是一批矩阵 . 有10 * 1000个矩阵 . 每个矩阵都是有形的[3,4]
y是一批向量 . 有1000个向量 . 每个向量都是形状[4]
Dim 1 of x and dim 0 of y are the same. (这是1000)
如果tf.matmul支持广播,我可以写
y=tf.reshape(y,[1,1000,4,1])
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3])
但是tf.matmul不支持广播
如果我使用上面引用的问题的方法
x=tf.reshape(x,[10*1000*3,4])
y=tf.transpose(y,perm=[1,0]) #[4,1000]
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3,1000])
结果是形状[10,1000,3,1000],而不是[10,1000,3] .
我不知道如何删除多余的1000
如何获得与支持广播的tf.matmul相同的结果?
2 回答
我自己解决了 .
如here所示,您可以使用函数来解决: