我正在使用多个编码器/解码器(每种语言各一个)进行多语言机器翻译设置,在火车时刻,我将批量从单个源提供给单个目标,我在编码器和解码器之间切换a tf.case ,取决于批次给出的 lang_srclang_tgt .
我的问题是,当我使用5种语言时,我达到了GPU的12GB内存限制 .
我不确定tensorflow是如何工作的,但我认为它为每个分支的激活或渐变分配GPU内存,但这在我的情况下是不必要的,因为对于任何批处理只有一条路径 .
有没有办法优化它?