您好我正在研究tensorflow中的数据集API,我对datat.map()函数有一个问题,该函数执行数据预处理 .
file_name = ["image1.jpg", "image2.jpg", ......]
im_dataset = tf.data.Dataset.from_tensor_slices(file_names)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser(), [image], [tf.float32, tf.float32, tf.float32])))
im_dataset = im_dataset.batch(batch_size)
iterator = im_dataset.make_initializable_iterator()
数据集接收图像名称并将其解析为3个张量(关于图像的3个信息) .
如果我的训练文件夹中有大量图像,预处理它们需要很长时间 . 我的问题是,由于数据集API据说是为高效的输入管道而设计的,因此在我将它们提供给我的工作人员(比如说GPU)之前对整个数据集进行预处理,或者每次我只预处理一批图像调用iterator.get_next()?
1 回答
如果预处理管道很长且输出很小,则处理后的数据应该适合内存 . 如果是这种情况,您可以使用tf.data.Dataset.cache将已处理的数据缓存在内存或文件中 .
来自官方performance guide:
在内存中使用缓存的示例
以下是每个预处理需要花费大量时间(0.5秒)的示例 . 数据集上的第二个时期将比第一个时期快得多
缓存到文件
如果要将缓存数据写入文件,可以为
cache()
提供参数:这将允许您只处理数据集一次,并对数据运行多个实验,而无需再次重新处理它 .
Warning :缓存到文件时必须是 careful . 如果您更改数据,但保留
/tmp/cache.*
文件,它仍将读取缓存的旧数据 . 例如,如果我们使用上面的数据并将数据范围更改为[10, 15]
,我们仍将获取[0, 5]
中的数据:可能出现的另一个问题是,如果在缓存所有数据之前中断脚本 . 您将收到如下错误:
确保您处理整个数据集以包含整个缓存文件 .