首页 文章

TensorFlow:在Android中的推断之间初始化RNN的状态

提问于
浏览
1

我们在Android上运行了一个有效的TensorFlow网络(graphdef),我注意到随着时间的推移推断结果往往是相关的 . 也就是说,如果返回标签A,那么即使输入数据切换到应该生成B标签的数据,也会有及时跟随的A流 . 最终,结果将切换到B,但似乎存在延迟并且表明RNN在推理调用之间是有状态的 . 我们的网络使用RNN / LSTM .

cellLSTM    = tf.nn.rnn_cell.BasicLSTMCell(nHidden)
cellsLSTM   = tf.nn.rnn_cell.MultiRNNCell([cellLSTM] * 2)
RNNout, RNNstates = tf.nn.rnn(cellsLSTM, Xin)

我想知道是否需要在推理调用之间重新初始化RNN状态 . 我会注意到TensorFlowInferenceInterface.java接口中没有这样的方法 . 我想可以将RNN单元初始化节点插入到可以用节点值激活的图形中(使用FillNodeInt或类似方法) .

所以我的问题是:在Tensorflow中使用RNN / LSTM的最佳做法是什么 . 是否需要在推论之间清除状态?如果是这样,那怎么做呢?

2 回答

  • 0

    是否需要在推论之间清除状态?

    我认为这取决于RNN的训练方式以及如何使用它 . 但是,我猜想无论有没有重置状态,网络都可以正常工作 .

    怎么做的?

    评估与初始状态关联的每个张量的初始化操作 .

  • 0

    虽然我不能评论RNN状态初始化的一般实践,但这里是我们如何设法强制初始状态定义 . 问题是虽然批量大小确实是训练集的常量参数,但它不适用于测试集 . 测试集始终是数据语料库的20%,因此每次更改语料库时其大小都不同 .
    解决方案是为batchsize创建一个新变量:

    batch_size_T  = tf.shape(Xin)[0]
    

    其中 Xin 是大小为[b x m x n]的输入张量,其中 b 是批量大小, m x n 是训练帧的大小 . 辛从 feed_dict 被喂入 .

    然后可以将初始状态定义为:

    initial_state = lstm_cells.zero_state(batch_size_T, tf.float32)
    

    最后,RNN是根据新的动态RNN定义的:

    outputs, state = tf.nn.dynamic_rnn(cell=lstm_cells, inputs=Xin, dtype=tf.float32, initial_state=initial_state)
    

相关问题