首页 文章

Tensorflow - 带批量数据的输入矩阵的matmul

提问于
浏览
28

我有一些由 input_x 表示的数据 . 它是一个未知大小的张量(应该是批量输入),每个项目的大小为 n . input_x 经历 tf.nn.embedding_lookup ,因此 embed 现在具有维度 [?, n, m] ,其中 m 是嵌入大小, ? 指未知批量大小 .

这在这里描述:

input_x = tf.placeholder(tf.int32, [None, n], name="input_x") 
embed = tf.nn.embedding_lookup(W, input_x)

我现在正试图将输入数据中的每个样本(现在通过嵌入维度扩展)乘以矩阵变量 U ,我似乎无法得到如何做到这一点 .

我首先尝试使用 tf.matmul ,但由于形状不匹配而导致错误 . 然后,我通过扩展 U 的维度并应用 batch_matmul (我也尝试了 tf.nn.math_ops. 中的函数,结果是相同的)尝试了以下内容:

U = tf.Variable( ... )    
U1 = tf.expand_dims(U,0)
h=tf.batch_matmul(embed, U1)

这会传递初始编译,但是当应用实际数据时,我会收到以下错误:

In[0].dim(0) and In[1].dim(0) must be the same: [64,58,128] vs [1,128,128]

我也知道为什么会这样 - 我复制了 U 的维度,它现在是 1 ,但是小批量大小 64 不合适 .

如何正确地对张量矩阵输入进行矩阵乘法(对于未知的批量大小)?

5 回答

  • 63

    matmul operation仅适用于矩阵(2D张量) . 以下是执行此操作的两种主要方法,均假设 U 是2D张量 .

    • 切割 embed 到2D张量中并将它们中的每一个分别与 U 相乘 . 这可能是最容易使用tf.scan()这样做:
    h = tf.scan(lambda a, x: tf.matmul(x, U), embed)
    
    • 另一方面,如果效率很重要,最好将 embed 重塑为2D张量,这样乘法可以用单个 matmul 完成,如下所示:
    embed = tf.reshape(embed, [-1, m])
    h = tf.matmul(embed, U)
    h = tf.reshape(h, [-1, n, c])
    

    其中 cU 中的列数 . 最后一次重塑将确保 h 是一个3D张量,其中第0维对应于批次,就像原始 x_inputembed 一样 .

  • 5

    以前的答案已经过时了 . 目前tf.matmul()支持等级> 2的张量:

    输入必须是矩阵(或秩> 2的张量,代表批量矩阵),具有匹配的内部维度,可能在换位之后 .

    此外 tf.batch_matmul() 已被删除, tf.matmul() 是进行批量乘法的正确方法 . 从以下代码可以理解主要思想:

    import tensorflow as tf
    batch_size, n, m, k = 10, 3, 5, 2
    A = tf.Variable(tf.random_normal(shape=(batch_size, n, m)))
    B = tf.Variable(tf.random_normal(shape=(batch_size, m, k)))
    tf.matmul(A, B)
    

    现在你将收到一个形状 (batch_size, n, k) 的张量 . 这是这里发生的事情 . 假设您有 batch_size 矩阵 nxmbatch_size 矩阵 mxk . 现在,对于它们中的每一对,你计算 nxm X mxk ,它给你一个 nxk 矩阵 . 你将拥有 batch_size .

    请注意,这样的事情也是有效的:

    A = tf.Variable(tf.random_normal(shape=(a, b, n, m)))
    B = tf.Variable(tf.random_normal(shape=(a, b, m, k)))
    tf.matmul(A, B)
    

    并会给你一个形状 (a, b, n, k)

  • 0

    1.我想将一批矩阵与一批相同长度的矩阵相乘,成对

    M = tf.random_normal((batch_size, n, m))
    N = tf.random_normal((batch_size, m, p))
    
    # python >= 3.5
    MN = M @ N
    # or the old way,
    MN = tf.matmul(M, N)
    # MN has shape (batch_size, n, p)
    

    2.我想将一批矩阵与一批相同长度的矢量相乘,成对

    我们通过添加和删除维度 v 来回到案例1 .

    M = tf.random_normal((batch_size, n, m))
    v = tf.random_normal((batch_size, m))
    
    Mv = (M @ v[..., None])[..., 0]
    # Mv has shape (batch_size, n)
    

    3.我想将一个矩阵与一批矩阵相乘

    在这种情况下,我们不能简单地将 1 的批量维度添加到单个矩阵中,因为 tf.matmul 不在批量维度中广播 .

    3.1.单个矩阵位于右侧

    在这种情况下,我们可以使用简单的重塑将矩阵批处理视为单个大矩阵 .

    M = tf.random_normal((batch_size, n, m))
    N = tf.random_normal((m, p))
    
    MN = tf.reshape(tf.reshape(M, [-1, m]) @ N, [-1, n, p])
    # MN has shape (batch_size, n, p)
    

    3.2.单个矩阵位于左侧

    这种情况比较复杂 . 我们可以通过转置矩阵来回到案例3.1 .

    MT = tf.matrix_transpose(M)
    NT = tf.matrix_transpose(N)
    NTMT = tf.reshape(tf.reshape(NT, [-1, m]) @ MT, [-1, p, n])
    MN = tf.matrix_transpose(NTMT)
    

    然而,换位可能是一项昂贵的操作,并且在这里它在整批矩阵上完成两次 . 简单地复制 M 以匹配批量维度可能更好:

    MN = tf.tile(M[None], [batch_size, 1, 1]) @ N
    

    分析将告诉哪个选项对于给定的问题/硬件组合更有效 .

    4.我想将一个矩阵与一批向量相乘

    这看起来类似于情况3.2,因为单个矩阵在左边,但它实际上更简单,因为转置矢量本质上是一个无操作 . 我们最终得到了

    M = tf.random_normal((n, m))
    v = tf.random_normal((batch_size, m))
    
    MT = tf.matrix_transpose(M)
    Mv = v @ MT
    

    einsum怎么样?

    之前的所有乘法都可以用tf.einsum瑞士军刀编写 . 例如,3.2的第一个解决方案可以简单地写成

    MN = tf.einsum('nm,bmp->bnp', M, N)
    

    但请注意, einsum 最终是relying on tranpose and matmul用于计算 .

    因此,尽管 einsum 是一种非常方便的写矩阵乘法的方法,但它隐藏了下面操作的复杂性 - 例如,猜测 einsum 表达式将转置数据的次数并不简单,因此操作的成本会很高 . . 此外,它可能隐藏了同一操作可能存在多种替代方案的事实(参见案例3.2),并且可能不一定选择更好的选项 .

    出于这个原因,我个人会使用上面那些明确的公式来更好地传达它们各自的复杂性虽然如果你知道自己在做什么,并且喜欢 einsum 语法的简单性,那么请务必去寻找它 .

  • 14

    正如@Stryke所说,有两种方法可以达到这个目的:1 . 扫描,和2.重塑

    • tf.scan需要lambda函数,通常用于递归操作 . 这里有一些例子:https://rdipietro.github.io/tensorflow-scan-examples/

    • 我个人更喜欢重塑,因为它更直观 . 如果您试图通过2D张量矩阵(如Cijl = Aijk * Bkl)将3D张量中的每个矩阵进行矩阵乘法,则可以使用简单的重塑形式进行矩阵乘法 .

    A' = tf.reshape(Aijk,[i*j,k])
    C' = tf.matmul(A',Bkl)
    C = tf.reshape(C',[i,j,l])
    
  • 4

    似乎在TensorFlow 1.11.0中docstf.matmul 错误地说它适用于秩> = 2 .

    相反,我发现最好的清洁替代方案是使用 tf.tensordot(a, b, (-1, 0))docs) .

    此函数以其通用形式 tf.tensordot(a, b, axis) 获取数组 a 的任何轴和数组 b 的任何轴的点积 . 提供 axis 作为 (-1, 0) 获得两个数组的标准点积 .

相关问题