首页 文章

如何从另一个数组索引到张量张量流

提问于
浏览
0

我正在尝试为AI中的问题编写一个深入的q学习网络 . 我有一个函数 predict() ,它产生一个形状张量 (None, 3) 接受形状 (None, 5) 的输入 . (None, 3) 中的3对应于每个州可以采取的每个动作的q值 . 现在,在训练步骤中,我必须多次调用 predict() 并使用结果来计算成本并训练模型 . 为此,我还有另一个名为 current_actions 的数据数组,它是一个包含前一次迭代中特定状态所采取的动作索引的列表 .

需要发生的是 current_states_outputs 应该是从 predict() 的输出创建的张量,其中每行只包含一个q值(而不是 predict() 的输出中的三个),并且应该选择哪个q值应该取决于相应的索引 current_actions .

例如,如果 current_states_output = [[1,2,3],[4,5,6],[7,8,9]]current_actions=[0,2,1] ,则操作后的结果应为 [1,6,8] (已更新)

我该怎么做呢?

我试过以下 -

current_states_outputs = self.sess.run(self.prediction, feed_dict={self.X:current_states})
    current_states_outputs = np.array([current_states_outputs[a][current_actions[a]] for a in range(len(current_actions))])

我基本上在 predict() 上运行会话并使用普通的python方法完成了所需的操作 . 但是因为这会从图的前几层切断成本的连接,所以不能进行任何培训 . 因此,我需要将此操作保持在张量流中并将所有内容保持为张量流张量本身 . 我该怎么办呢?

1 回答

  • 2

    你可以试试,

    tf.squeeze(tf.gather_nd(a,tf.stack([tf.range(b.shape[0])[...,tf.newaxis], b[...,tf.newaxis]], axis=2)))
    

    示例代码:

    a = tf.Variable(current_states_outputs)
    b = tf.Variable(current_actions)
    out = tf.squeeze(tf.gather_nd(a,tf.stack([tf.range(b.shape[0])[...,tf.newaxis], b[...,tf.newaxis]], axis=2)))
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    sess.run(out)
    
    #output
    [1, 6, 8]
    

相关问题