首页 文章

在Keras合并一个向前的lstm和一个向后的lstm

提问于
浏览
1

我想在Keras中合并前向LSTM和后向LSTM . 后向LSTM的输入数组与前向LSTM的输入数组不同 . 因此,我不能使用keras.layers.Bidirectional .

正向输入是(10,4) . 反向输入为(12,4),在放入模型之前反转 . 我希望在LSTM之后再次将其反转并将其与前锋合并 .

简化模型如下 .

from lambdawithmask import Lambda as MaskLambda

def reverse_func(x, mask=None):
    return tf.reverse(x, [False, True, False])

forward = Sequential()
backward = Sequential()
model = Sequential()

forward.add(LSTM(input_shape = (10, 4), output_dim = 4, return_sequences = True))
backward.add(LSTM(input_shape = (12, 4), output_dim = 4, return_sequences = True))
backward.add(MaskLambda(function=reverse_func, mask_function=reverse_func))
model.add(Merge([forward, backward], mode = "concat", concat_axis = 1))

当我运行它时,错误消息是:传递给'ConcatV2'Op的'values'的列表中的张量具有不匹配的类型[bool,float32] .

谁能帮助我?我使用Keras(2.0.5)在Python 3.5.2中编码,后端是tensorflow(1.2.1) .

2 回答

  • 1

    首先,如果您有两个不同的输入,则不能使用Sequential模型 . 您必须使用功能API模型:

    from keras.models import Model
    

    两个第一个模型可以是顺序的,没有问题,但结必须是常规模型 . 当它是关于连接时,我也使用函数方法(创建图层,然后传递输入):

    junction = Concatenate(axis=1)([forward.output,backward.output])
    

    为什么axis = 1?您只能将具有相同形状的事物连接起来 . 由于你有10和12,它们是不兼容的,除非你使用这个精确的轴进行合并,这是第二个轴,考虑到你有(BatchSize,TimeSteps,Units)

    要创建最终模型,请使用 Model ,指定输入和输出:

    model = Model([forward.input,backward.input], junction)
    

    在要反转的模型中,只使用 Lambda 图层 . MaskLambda不仅仅是你想要的功能 . 我还建议你使用张量函数的keras后端:

    import keras.backend as K
    
    #instead of the MaskLambda:
    backward.add(Lambda(lambda x: K.reverse(x,axes=[1]), output_shape=(12,?))
    

    这里, ? 是LSTM图层具有的单位数量 . 最后见PS .


    PS:我不确定 output_dim 在LSTM层中是否有用 . 它's necessary in Lambda layers, but I never use it anywhere else. Shapes are natural consequences of the amount of 851598 you put in your layers. Strangely, you didn' t指定单位数量 .

    PS2:你究竟想要连接两个不同大小的序列?

  • 1

    如上面的答案所述,使用Functional API为多输入/输出模型提供了很大的灵活性 . 您只需将 go_backwards 参数设置为 True 即可反转 LSTM 层对输入向量的遍历 .

    我已经定义了下面的 smart_merge 函数,该函数将前向和后向LSTM层合并在一起并处理单个遍历情况 .

    from keras.models import Model
    from keras.layers import Input, merge
    
    def smart_merge(vectors, **kwargs):
            return vectors[0] if len(vectors)==1 else merge(vectors, **kwargs)      
    
    input1 = Input(shape=(10,4), dtype='int32')
    input2 = Input(shape=(12,4), dtype='int32')
    
    LtoR_LSTM = LSTM(56, return_sequences=False)
    LtoR_LSTM_vector = LtoR_LSTM(input1)
    RtoL_LSTM = LSTM(56, return_sequences=False, go_backwards=True)
    RtoL_LSTM_vector = RtoL_LSTM(input2)
    
    BidireLSTM_vector = [LtoR_LSTM_vector]
    BidireLSTM_vector.append(RtoL_LSTM_vector)
    BidireLSTM_vector= smart_merge(BidireLSTM_vector, mode='concat')
    

相关问题