首页 文章

如何实时在Keras RNN中实现前向传递?

提问于
浏览
1

我正在尝试在一个实时运行的应用程序中运行在Keras训练的RNN . 这里的循环网络中的“时间”(它是LSTM)是接收数据的实际时刻 .

我想以在线方式获得RNN的输出 . 对于非循环模型,我只是将输入形状为 inputDatum=1,input_shape 并在其上运行 Model.predict . 我不确定这是在Keras中为应用程序使用前向传递的预期方法,但它对我有用 .

但对于循环模块, Model.predict 期望输入整个输入,包括时间维度 . 所以它不起作用......

有没有办法在Keras做到这一点,还是我需要转到Tensorflow并在那里实施操作?

1 回答

  • 2

    您可以将 LSTM 图层设置为有状态 . LSTM的内部状态将一直保持到您手动调用 model.reset_states() 为止 .

    例如,假设我们已经训练了一个简单的LSTM模型 .

    x = Input(shape=(None, 10))
    h = LSTM(8)(x)
    out = Dense(4)(h)
    model = Model(x, out)
    model.compile(loss='mse', optimizer='adam')
    
    X_train = np.random.rand(100, 5, 10)
    y_train = np.random.rand(100, 4)
    model.fit(X_train, y_train)
    

    然后,可以使用 stateful=True 将权重加载到另一个模型上进行预测(请记住在 Input 层中设置 batch_shape ) .

    x = Input(batch_shape=(1, None, 10))
    h = LSTM(8, stateful=True)(x)
    out = Dense(4)(h)
    predict_model = Model(x, out)
    
    # copy the weights from `model` to this model
    predict_model.set_weights(model.get_weights())
    

    对于您的用例,由于 predict_model 是有状态的,因此对长度为1的子序列的连续 predict 调用将给出与预测整个序列相同的结果 . 只需记住在预测新序列之前调用 reset_states() .

    X = np.random.rand(1, 3, 10)
    print(model.predict(X))
    # [[-0.09485822,  0.03324107,  0.243945  , -0.20729265]]
    
    predict_model.reset_states()
    for t in range(3):
        print(predict_model.predict(X[:, t:(t + 1), :]))
    # [[-0.04117237 -0.06340873  0.10212967 -0.06400848]]
    # [[-0.12808001  0.0039286   0.23223262 -0.23842749]]
    # [[-0.09485822  0.03324107  0.243945   -0.20729265]]
    

相关问题