首页 文章

Seq2Seq模型学习仅在几次迭代后输出EOS令牌(<\ s>)

提问于
浏览
4

我正在使用NMT创建一个在Cornell Movie Dialogs Corpus上训练的聊天机器人 .

我的代码部分来自https://github.com/bshao001/ChatLearnerhttps://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot

在训练期间,我打印从批处理中提供给解码器的随机输出答案以及我的模型预测的相应答案,以观察学习进度 .

My issue: 仅经过大约4次训练后,模型学习为每个时间步输出EOS令牌( <\s> ) . 它总是将其输出作为其响应(使用logg的argmax确定),即使训练仍在继续 . 偶尔,很少,模型输出一系列时期作为答案 .

我还在训练期间打印了前10个logit值(不仅仅是argmax),看看是否有正确的单词在那里,但它似乎预测了词汇中最常见的单词(例如i,you,?,. ) . 即使这些前10个单词在培训期间也没有太大变化 .

我已经确保正确计算编码器和解码器的输入序列长度,并相应地添加了SOS( <s> )和EOS(也用于填充)令牌 . 我还在损失计算中执行 masking .

这是一个示例输出:

Training iteration 1:

Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s> 
<\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration 
administration winston winston winston magazines magazines magazines 
magazines

...

Training iteration 4:

Decoder Input: <s> i guess i had it coming . let us call it settled . 
<\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> 
<\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>

经过几次迭代后,它只能预测EOS(很少有一些时期)

我不确定是什么原因引起了这个问题,并且已经停留了一段时间 . 任何帮助将不胜感激!

Update: 我让它训练超过十万次迭代,它仍然只输出EOS(偶尔出现) . 经过几次迭代后,训练损失也不会减少(从一开始就保持在47左右)

1 回答

  • 0

    最近我也在研究seq2seq模型 . 我以前遇到过你的问题,在我的情况下,我通过改变损失函数来解决它 .

    你说你使用了面具,所以我猜你像我一样使用 tf.contrib.seq2seq.sequence_loss .

    我改为 tf.nn.softmax_cross_entropy_with_logits ,它正常工作(和更高的计算成本) .

    (编辑05/10/2018 . 对不起,我需要编辑,因为我发现我的代码中存在一个令人震惊的错误)

    tf.contrib.seq2seq.sequence_loss 可以正常工作,如果 logitstargetsmask 的形状是正确的 . 正如官方文件中所定义:tf.contrib.seq2seq.sequence_loss

    loss=tf.contrib.seq2seq.sequence_loss(logits=decoder_logits,
                                          targets=decoder_targets,
                                          weights=masks) 
    
    #logits:  [batch_size, sequence_length, num_decoder_symbols]  
    #targets: [batch_size, sequence_length] 
    #weights: [batch_size, sequence_length]
    

    嗯,即使形状不符合,它仍然可以工作 . 但结果可能很奇怪(很多#EOS #PAD等等) .

    由于 decoder_outputsdecoder_targets 可能具有所需的相同形状(在我的情况下,我的 decoder_targets 具有 [sequence_length, batch_size] 形状) . 因此,尝试使用 tf.transpose 帮助您重塑张量 .

相关问题