我有一些由 input_x
表示的数据 . 它是一个未知大小的张量(应该是批量输入),每个项目的大小为 n
. input_x
经历 tf.nn.embedding_lookup
,因此 embed
现在具有维度 [?, n, m]
,其中 m
是嵌入大小, ?
指未知批量大小 .
这在这里描述:
input_x = tf.placeholder(tf.int32, [None, n], name="input_x")
embed = tf.nn.embedding_lookup(W, input_x)
我现在正试图将输入数据中的每个样本(现在通过嵌入维度扩展)乘以矩阵变量 U
,我似乎无法得到如何做到这一点 .
我首先尝试使用 tf.matmul
,但由于形状不匹配而导致错误 . 然后,我通过扩展 U
的维度并应用 batch_matmul
(我也尝试了 tf.nn.math_ops.
中的函数,结果是相同的)尝试了以下内容:
U = tf.Variable( ... )
U1 = tf.expand_dims(U,0)
h=tf.batch_matmul(embed, U1)
这会传递初始编译,但是当应用实际数据时,我会收到以下错误:
In[0].dim(0) and In[1].dim(0) must be the same: [64,58,128] vs [1,128,128]
我也知道为什么会这样 - 我复制了 U
的维度,它现在是 1
,但是小批量大小 64
不合适 .
如何正确地对张量矩阵输入进行矩阵乘法(对于未知的批量大小)?
5 回答
matmul operation仅适用于矩阵(2D张量) . 以下是执行此操作的两种主要方法,均假设
U
是2D张量 .embed
到2D张量中并将它们中的每一个分别与U
相乘 . 这可能是最容易使用tf.scan()这样做:embed
重塑为2D张量,这样乘法可以用单个matmul
完成,如下所示:其中
c
是U
中的列数 . 最后一次重塑将确保h
是一个3D张量,其中第0维对应于批次,就像原始x_input
和embed
一样 .以前的答案已经过时了 . 目前tf.matmul()支持等级> 2的张量:
此外
tf.batch_matmul()
已被删除,tf.matmul()
是进行批量乘法的正确方法 . 从以下代码可以理解主要思想:现在你将收到一个形状
(batch_size, n, k)
的张量 . 这是这里发生的事情 . 假设您有batch_size
矩阵nxm
和batch_size
矩阵mxk
. 现在,对于它们中的每一对,你计算nxm X mxk
,它给你一个nxk
矩阵 . 你将拥有batch_size
.请注意,这样的事情也是有效的:
并会给你一个形状
(a, b, n, k)
1.我想将一批矩阵与一批相同长度的矩阵相乘,成对
2.我想将一批矩阵与一批相同长度的矢量相乘,成对
我们通过添加和删除维度
v
来回到案例1 .3.我想将一个矩阵与一批矩阵相乘
在这种情况下,我们不能简单地将
1
的批量维度添加到单个矩阵中,因为tf.matmul
不在批量维度中广播 .3.1.单个矩阵位于右侧
在这种情况下,我们可以使用简单的重塑将矩阵批处理视为单个大矩阵 .
3.2.单个矩阵位于左侧
这种情况比较复杂 . 我们可以通过转置矩阵来回到案例3.1 .
然而,换位可能是一项昂贵的操作,并且在这里它在整批矩阵上完成两次 . 简单地复制
M
以匹配批量维度可能更好:分析将告诉哪个选项对于给定的问题/硬件组合更有效 .
4.我想将一个矩阵与一批向量相乘
这看起来类似于情况3.2,因为单个矩阵在左边,但它实际上更简单,因为转置矢量本质上是一个无操作 . 我们最终得到了
einsum怎么样?
之前的所有乘法都可以用tf.einsum瑞士军刀编写 . 例如,3.2的第一个解决方案可以简单地写成
但请注意,
einsum
最终是relying on tranpose and matmul用于计算 .因此,尽管
einsum
是一种非常方便的写矩阵乘法的方法,但它隐藏了下面操作的复杂性 - 例如,猜测einsum
表达式将转置数据的次数并不简单,因此操作的成本会很高 . . 此外,它可能隐藏了同一操作可能存在多种替代方案的事实(参见案例3.2),并且可能不一定选择更好的选项 .出于这个原因,我个人会使用上面那些明确的公式来更好地传达它们各自的复杂性虽然如果你知道自己在做什么,并且喜欢
einsum
语法的简单性,那么请务必去寻找它 .正如@Stryke所说,有两种方法可以达到这个目的:1 . 扫描,和2.重塑
tf.scan需要lambda函数,通常用于递归操作 . 这里有一些例子:https://rdipietro.github.io/tensorflow-scan-examples/
我个人更喜欢重塑,因为它更直观 . 如果您试图通过2D张量矩阵(如Cijl = Aijk * Bkl)将3D张量中的每个矩阵进行矩阵乘法,则可以使用简单的重塑形式进行矩阵乘法 .
似乎在TensorFlow 1.11.0中docs为
tf.matmul
错误地说它适用于秩> = 2 .相反,我发现最好的清洁替代方案是使用
tf.tensordot(a, b, (-1, 0))
(docs) .此函数以其通用形式
tf.tensordot(a, b, axis)
获取数组a
的任何轴和数组b
的任何轴的点积 . 提供axis
作为(-1, 0)
获得两个数组的标准点积 .