首页 文章

使用word2vec作为张量流输入的LSTM的可变句长

提问于
浏览
3

我正在使用word2vec作为输入构建LSTM模型 . 我正在使用tensorflow框架 . 我已经完成了单词嵌入部分,但我遇到了LSTM部分 .

这里的问题是我有不同的句子长度,这意味着我必须做填充或使用指定序列长度的dynamic_rnn . 我和他们两个都在挣扎 .

  • 填充 . 填充的混乱部分是我填充时 . 我的模型就像

word_matrix = model.wv.syn0
X = tf.placeholder(tf.int32,shape)
data = tf.placeholder(tf.float32,shape)
data = tf.nn.embedding_lookup(word_matrix,X)

然后,我将word_matrix的单词索引序列送入X.我担心如果我将0填充到送入X的序列中,那么我会错误地继续输入不必要的输入(在这种情况下为word_matrix [0]) .

所以,我想知道0填充的正确方法是什么 . 如果你让我知道如何用tensorflow实现它会很棒 .

  • dynamic_rnn为此,我已经声明了一个包含所有句子长度的列表,并在最后将这些句子与X和y一起提供 . 在这种情况下,我不能批量输入输入 . 然后,我遇到了这个错误(ValueError:as_list()没有在未知的TensorShape上定义 . ),在我看来,sequence_length参数只接受列表? (我的想法可能完全不正确) .

以下是我的代码 .

X = tf.placeholder(tf.int32)
labels = tf.placeholder(tf.int32, [None, numClasses])
length = tf.placeholder(tf.int32)

data = tf.placeholder(tf.float32, [None, None, numDimensions])
data = tf.nn.embedding_lookup(word_matrix, X)

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits, state_is_tuple=True)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.25)
initial_state=lstmCell.zero_state(batchSize, tf.float32)
value, _ = tf.nn.dynamic_rnn(lstmCell, data, sequence_length=length,
                             initial_state=initial_state, dtype=tf.float32)

我非常挣扎这部分,所以任何帮助都会非常感激 .

先感谢您 .

1 回答

  • 2

    Tensorflow不支持可变长度Tensor . 因此,当您声明Tensor时,list / numpy数组应该具有统一的形状 .

    • 从你的第一部分,我的理解是你已经能够在序列长度的最后一步填充零 . 理想的情况应该是什么 . 以下是它应该如何查找批量大小为4, max 序列长度10和50个隐藏单位 - >

    [4,10,50] 将是整个批次的大小,但在内部,当您尝试可视化填充时,它可能会像这样形状 - >

    `[[5+5pad,50],[10,50],[8+2pad,50],[9+1pad,50]`
    

    每个垫将表示序列长度为1,隐藏状态大小为50张 . 所有的东西都只有零 . 请查看this questionthis one以了解有关如何手动填充的更多信息 .

    • 您将使用动态rnn,原因是您不想在填充序列上计算它 . tf.nn.dynamic_rnn api将通过传递 sequence_length 参数来确保 .

    对于上面的示例,对于上面的示例,该参数将为: [5,10,8,9] . 您可以通过对每个批处理组件的非零实体求和来计算它 . 一种简单的计算方法是:

    data_mask = tf.cast(data, tf.bool)
    data_len = tf.reduce_sum(tf.cast(data_mask, tf.int32), axis=1)
    

    并在 tf.nn.dynamic_rnn api中传递:

    tf.nn.dynamic_rnn(lstmCell, data, sequence_length=data_len, initial_state=initial_state)
    

相关问题