我有一个名为 main_decoder 形状的3-d Tensor (None,9,256)
我想提取9个形状的张量 (None,256)
我尝试过使用Keras gather ,以下是模式代码片段:
for i in range(0,9):
sub_decoder_input = Lambda(lambda main_decoder:gather(main_decoder,(i)), name='lambda'+str(i))(main_decoder)
结果是9个λ层的形状 (9,256)
如何修改它以便我可以获得或收集9个形状的张量 (None,256)
谢谢 .
1 回答
您可以将3D张量切片为9个2D张量,并从
Lambda
层返回张量列表 .