首页 文章

从Keras的3-d张量中收集2-d张量列表

提问于
浏览
2

我有一个名为 main_decoder 形状的3-d Tensor (None,9,256)

我想提取9个形状的张量 (None,256)

我尝试过使用Keras gather ,以下是模式代码片段:

for i in range(0,9):
    sub_decoder_input = Lambda(lambda main_decoder:gather(main_decoder,(i)), name='lambda'+str(i))(main_decoder)

结果是9个λ层的形状 (9,256)

如何修改它以便我可以获得或收集9个形状的张量 (None,256)

谢谢 .

1 回答

  • 3

    您可以将3D张量切片为9个2D张量,并从 Lambda 层返回张量列表 .

    main_decoder = Input(shape=(9, 256))
    sub_decoder_input = Lambda(lambda x: [x[:, i, :] for i in range(9)])(main_decoder)
    
    print(sub_decoder_input)
    [<tf.Tensor 'lambda_1/strided_slice:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_1:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_2:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_3:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_4:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_5:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_6:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_7:0' shape=(?, 256) dtype=float32>,
     <tf.Tensor 'lambda_1/strided_slice_8:0' shape=(?, 256) dtype=float32>]
    

相关问题