我正在尝试在一个实时运行的应用程序中运行在Keras训练的RNN . 这里的循环网络中的“时间”(它是LSTM)是接收数据的实际时刻 .
我想以在线方式获得RNN的输出 . 对于非循环模型,我只是将输入形状为 inputDatum=1,input_shape
并在其上运行 Model.predict
. 我不确定这是在Keras中为应用程序使用前向传递的预期方法,但它对我有用 .
但对于循环模块, Model.predict
期望输入整个输入,包括时间维度 . 所以它不起作用......
有没有办法在Keras做到这一点,还是我需要转到Tensorflow并在那里实施操作?
1 回答
您可以将
LSTM
图层设置为有状态 . LSTM的内部状态将一直保持到您手动调用model.reset_states()
为止 .例如,假设我们已经训练了一个简单的LSTM模型 .
然后,可以使用
stateful=True
将权重加载到另一个模型上进行预测(请记住在Input
层中设置batch_shape
) .对于您的用例,由于
predict_model
是有状态的,因此对长度为1的子序列的连续predict
调用将给出与预测整个序列相同的结果 . 只需记住在预测新序列之前调用reset_states()
.