我想应用Tensorflow's Dataset API来训练数据集,该数据集在每次通过网络传播一批数据时都会发生变化 .
我遇到了这个代码(下面),它使用了向Tensorflow提供数据的feed_dict实现,我想让它适应使用Tensorflow API,因为Tensorflow自己说
代码的相关部分(关于Q-learning的实现)是:
def generate_session(t_max=1000, epsilon=0, train=False):
"""play env with approximate q-learning agent and train it at the same time"""
total_reward = 0
s = env.reset()
for t in range(t_max):
a = get_action(s, epsilon=epsilon)
next_s, r, done, _ = env.step(a)
if train:
sess.run(train_step,feed_dict={
states_ph: [s], actions_ph: [a], rewards_ph: [r],
next_states_ph: [next_s], is_done_ph: [done]
})
total_reward += r
s = next_s
if done: break
return total_reward
我想使用Tensorflow Data API,但是这里的问题是所有被提供的数据: s, a, r, next_s, is_done_ph
取决于训练迭代的输出 . 换句话说, t=50
的输入值 t=50
由 s, a, r, next_s, is_done_ph
的输出 t=49
创建 . 这是因为这条线
a = get_action(s, epsilon=epsilon)
这将根据预设训练步骤中的s输出创建新动作
next_s, r, done, _ = env.step(a)
基本上为我们提供了训练循环的剩余新输入 .
现在我遇到的问题是Tensorflow数据集API中的示例使用了训练开始之前已知的训练数据,但我不确定如何使用这个不断发展的数据集来实现Tensorflow数据集API .