我发现了lstm细胞的特殊属性(不限于lstm,但我只用这个检查过)的张量流,据我所知还没有报道 . 我不知道它是否真的有,所以我把这篇文章留在了SO中 . 以下是此问题的玩具代码:
import tensorflow as tf
import numpy as np
import time
def network(input_list):
input,init_hidden_c,init_hidden_m = input_list
cell = tf.nn.rnn_cell.BasicLSTMCell(256, state_is_tuple=True)
init_hidden = tf.nn.rnn_cell.LSTMStateTuple(init_hidden_c, init_hidden_m)
states, hidden_cm = tf.nn.dynamic_rnn(cell, input, dtype=tf.float32, initial_state=init_hidden)
net = [v for v in tf.trainable_variables()]
return states, hidden_cm, net
def action(x, h_c, h_m):
t0 = time.time()
outputs, output_h = sess.run([rnn_states[:,-1:,:], rnn_hidden_cm], feed_dict={
rnn_input:x,
rnn_init_hidden_c: h_c,
rnn_init_hidden_m: h_m
})
dt = time.time() - t0
return outputs, output_h, dt
rnn_input = tf.placeholder("float", [None, None, 512])
rnn_init_hidden_c = tf.placeholder("float", [None,256])
rnn_init_hidden_m = tf.placeholder("float", [None,256])
rnn_input_list = [rnn_input, rnn_init_hidden_c, rnn_init_hidden_m]
rnn_states, rnn_hidden_cm, rnn_net = network(rnn_input_list)
feed_input = np.random.uniform(low=-1.,high=1.,size=(1,1,512))
feed_init_hidden_c = np.zeros(shape=(1,256))
feed_init_hidden_m = np.zeros(shape=(1,256))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(10000):
_, output_hidden_cm, deltat = action(feed_input, feed_init_hidden_c, feed_init_hidden_m)
if i % 10 == 0:
print 'Running time: ' + str(deltat)
(feed_init_hidden_c, feed_init_hidden_m) = output_hidden_cm
feed_input = np.random.uniform(low=-1.,high=1.,size=(1,1,512))
[不重要]此代码的作用是从包含LSTM的'network()'函数生成输出,其中输入的时间维度为1,因此输出也为1,并为每个运行步骤拉入和拉出初始状态 .
[重要]查看'sess.run()'部分 . 出于某些原因,在我的真实代码中,我碰巧将[:, - 1:,:]放入'rnn_states' . 然后发生了什么 the time spent for each 'sess.run()' increases . 对于我自己的一些检查,我发现这种减速源于[:, - 1:,:] . 我只是想在最后一步得到输出 . 如果你执行'outputs, output_h = sess.run([rnn_states, rnn_hidden_cm], feed_dict{~' w / o [:, - 1:,:]并在'sess.run()'之后取'last_output = outputs[:,-1:,:]',则不会发生减速 .
我不知道为什么这个指数增量的时间发生在[:, - 1:,:]运行 . 这是张量流的本质还没有记录,但特别慢(可能是自己添加更多的图形?)?谢谢你,希望这篇文章不会给其他用户带来这个错误 .
2 回答
我遇到了同样的问题,TensorFlow在我运行它的每次迭代时都会减慢,并在尝试调试时发现了这个问题 . 这是我的情况的简短描述以及我如何解决它以供将来参考 . 希望它可以指向某人正确的方向并节省他们一些时间 .
在我的情况下,问题主要是我在执行
sess.run()
时没有使用feed_dict
来提供网络状态 . 相反,我每次重复都重新声明outputs
,final_state
和prediction
. https://github.com/tensorflow/tensorflow/issues/1439#issuecomment-194405649的答案让我意识到这是多么愚蠢......我不断在每次迭代中创建新的图形节点,使它变得越来越慢 . 有问题的代码看起来像这样:解决方案当然只是在开始时声明一次节点,并使用
feed_dict
提供新数据 . 代码从半慢(开始时> 15 ms)变为每次迭代变慢,在大约1 ms内执行每次迭代 . 我的新代码看起来像这样:从for循环中移出声明也消除了OP sdr2002所具有的问题,在for循环内的
sess.run()
中执行切片outputs[-1]
.如上所述,对于这种情况,“sess.run()”没有切片输出 .