我正在TensorFlow getting started guide for CNN浏览CIFAR-10示例
现在在 cifar10_train.py 的火车功能中,我们得到的图像为
images,labels = cifar10.distorted_inputs()
在 distorted_inputs()
函数中,我们在队列中生成文件名,然后将单个记录读取为
# Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames)
# Read examples from files in the filename queue.
read_input = cifar10_input.read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
当我添加调试代码时, read_input
变量只包含1条记录,其中包含图像及其高度,宽度和标签名称 .
然后,该示例将一些失真应用于读取的图像/记录,然后将其传递给 _generate_image_and_label_batch()
函数 .
然后该函数返回一个4D Tensor of shape [batch_size, 32, 32, 3]
,其中 batch_size = 128
.
返回批处理时,上述功能使用 tf.train.shuffle_batch()
函数 .
我的问题是 tf.train.shuffle_batch()
函数中的额外记录来自何处?我们没有传递任何文件名或读者对象 .
有人可以说明我们如何从1条记录转到128条记录吗?我查看了文档,但不明白 .
1 回答
tf.train.shuffle_batch()函数可用于生成包含一批输入的(一个或多个)张量 . 在内部,
tf.train.shuffle_batch()
创建一个tf.RandomShuffleQueue,在其上使用图像和标签张量调用q.enqueue()以将单个元素(图像标签对)排入队列 . 然后它返回q.dequeue_many(batch_size)的结果,该结果将batch_size
随机选择的元素(图像 - 标签对)连接成一批图像和一批标签 .请注意,虽然它看起来像
read_input
和filename_queue
这样的代码具有功能关系,但还有一个额外的皱纹 . 简单地评估tf.train.shuffle_batch()
的结果将永远阻止,因为没有元素添加到内部队列 . 为简化此操作,当您调用tf.train.shuffle_batch()
时,TensorFlow会将QueueRunner添加到图形中的内部集合 . 稍后调用tf.train.start_queue_runners()(例如here in cifar10_train.py)将启动一个向队列添加元素的线程,并使训练继续进行 . Threading and Queues HOWTO有关于其工作原理的更多信息 .