我正在使用多个编码器/解码器(每种语言各一个)进行多语言机器翻译设置,在火车时刻,我将批量从单个源提供给单个目标,我在编码器和解码器之间切换a tf.case ,取决于批次给出的 lang_src 和 lang_tgt .我的问题是,当我使用5种语言时,我达到了GPU的12GB内存限制 .我不确定tensorflow是如何工作的,但我认为它为每个分支的激活或渐变分配GPU内存,但这在我的情况下是不必要的,因为对于任何批处理只有一条路径 .有没有办法优化它?
tf.case
lang_src
lang_tgt