首页 文章

批量迭代如何在Tensorflow中工作?

提问于
浏览
2

我试图在我的数据上重用PTB language model但是缺乏Tensorflow的知识来理解它如何处理训练数据的批量迭代 . 以下是我在培训期间理解批量迭代的方法:

while epoch <= maxepoch do
  for minibatch in data_iterator() do
    model.forward(minibatch)
    (...)
  end
end

不能比这更简单,可以吗?类似的东西在许多其他框架中完成但在Tensorflow中没有完成:)以下是官方PTB语言模型教程中的minibatch函数示例:

def ptb_producer(raw_data, batch_size, num_steps, name=None):
    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
        raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

        data_len = tf.size(raw_data)
        batch_len = data_len // batch_size
        data = tf.reshape(raw_data[0 : batch_size * batch_len],
                                            [batch_size, batch_len])

        epoch_size = (batch_len - 1) // num_steps
        assertion = tf.assert_positive(
                epoch_size,
                message="epoch_size == 0, decrease batch_size or num_steps")
        with tf.control_dependencies([assertion]):
            epoch_size = tf.identity(epoch_size, name="epoch_size")

        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
        x = tf.strided_slice(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps])
        x.set_shape([batch_size, num_steps])
        y = tf.strided_slice(data, [0, i * num_steps + 1], [batch_size, (i + 1) * num_steps + 1])
        y.set_shape([batch_size, num_steps])
        return x, y

一旦调用,该函数返回 x 个输入和 y 个目标 . 我在这里看不到Python迭代器的迹象,但是有一个 tf.strided_slice 的调用,它使用 tf.train.range_input_producer 生成的 i 索引,所以这应该模拟数据的滑动窗口 . 但是,在训练之前只调用一次该函数,那么它如何迭代我的数据呢?目前还不清楚 . 有人可以解释这个和完全模糊的Tensorflow机制吗?

1 回答

  • 2

    "magic"隐藏在调用tf.train.range_input_producer的行中:

    i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
    

    ...创建一个 pops the values from the queue 持有 0..epoch_size-1 整数的操作 . 换句话说,它迭代 0..epoch_size-1 范围 .


    是的,这似乎违反直觉 . 所以这是一个在tensorflow中使用队列的简单可运行示例:

    index = tf.train.range_input_producer(10, shuffle=False).dequeue()
    
    with tf.Session() as sess:
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)
    
      for i in range(15):
        print(sess.run(index))
    
      coord.request_stop()
      coord.join(threads)
    

    运行时,您应该看到 09 的值,然后从 04 再看到5个值 . 请注意 sess.run 评估相同的张量 index ,但它得到 different value each time . 可以添加依赖于 index 的其他操作,并且将使用新值 index 对它们进行求值 .

    另请注意,队列在另一个线程中运行,因此为了使用 tf.train.range_input_producer ,必须启动 Coordinator 并生成多个线程(并在最后停止它们) . 如果您尝试在没有 Coordinator 的情况下运行相同的示例,则执行 sess.run(index) will block 脚本 .

    您可以使用此示例,例如,设置 shuffle=True 等 .


    回到PTB制作人片段:

    i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
    x = tf.strided_slice(data, [0, i*num_steps], [batch_size, (i+1)*num_steps])
    x.set_shape([batch_size, num_steps])
    y = tf.strided_slice(data, [0, i*num_steps+1], [batch_size, (i+1)*num_steps+1])
    y.set_shape([batch_size, num_steps])
    

    现在应该很清楚,即使 xy 被定义为简单张量,它们实际上也是 data 切片上的迭代器 . 所有的线程工作都由tf.train.Supervisor处理 . 因此,调用优化操作(取决于 xy )将自动获取新批次 .


    建议阅读:

相关问题