首页 文章

TensorFlow CIFAR-10示例教程中的distorted_inputs()函数如何获得每批128个图像?

提问于
浏览
8

我正在TensorFlow getting started guide for CNN浏览CIFAR-10示例

现在在 cifar10_train.py 的火车功能中,我们得到的图像为

images,labels = cifar10.distorted_inputs()

distorted_inputs() 函数中,我们在队列中生成文件名,然后将单个记录读取为

# Create a queue that produces the filenames to read.
 filename_queue = tf.train.string_input_producer(filenames)

 # Read examples from files in the filename queue.
 read_input = cifar10_input.read_cifar10(filename_queue)
 reshaped_image = tf.cast(read_input.uint8image, tf.float32)

当我添加调试代码时, read_input 变量只包含1条记录,其中包含图像及其高度,宽度和标签名称 .

然后,该示例将一些失真应用于读取的图像/记录,然后将其传递给 _generate_image_and_label_batch() 函数 .

然后该函数返回一个4D Tensor of shape [batch_size, 32, 32, 3] ,其中 batch_size = 128 .

返回批处理时,上述功能使用 tf.train.shuffle_batch() 函数 .

我的问题是 tf.train.shuffle_batch() 函数中的额外记录来自何处?我们没有传递任何文件名或读者对象 .

有人可以说明我们如何从1条记录转到128条记录吗?我查看了文档,但不明白 .

1 回答

  • 7

    tf.train.shuffle_batch()函数可用于生成包含一批输入的(一个或多个)张量 . 在内部, tf.train.shuffle_batch() 创建一个tf.RandomShuffleQueue,在其上使用图像和标签张量调用q.enqueue()以将单个元素(图像标签对)排入队列 . 然后它返回q.dequeue_many(batch_size)的结果,该结果将 batch_size 随机选择的元素(图像 - 标签对)连接成一批图像和一批标签 .

    请注意,虽然它看起来像 read_inputfilename_queue 这样的代码具有功能关系,但还有一个额外的皱纹 . 简单地评估 tf.train.shuffle_batch() 的结果将永远阻止,因为没有元素添加到内部队列 . 为简化此操作,当您调用 tf.train.shuffle_batch() 时,TensorFlow会将QueueRunner添加到图形中的内部集合 . 稍后调用tf.train.start_queue_runners()(例如here in cifar10_train.py)将启动一个向队列添加元素的线程,并使训练继续进行 . Threading and Queues HOWTO有关于其工作原理的更多信息 .

相关问题