首页 文章

对整个数据集或每次调用iterator.next()进行一次Tensorflow数据集数据预处理?

提问于
浏览
4

您好我正在研究tensorflow中的数据集API,我对datat.map()函数有一个问题,该函数执行数据预处理 .

file_name = ["image1.jpg", "image2.jpg", ......]
im_dataset = tf.data.Dataset.from_tensor_slices(file_names)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser(), [image], [tf.float32, tf.float32, tf.float32])))
im_dataset = im_dataset.batch(batch_size)
iterator = im_dataset.make_initializable_iterator()

数据集接收图像名称并将其解析为3个张量(关于图像的3个信息) .

如果我的训练文件夹中有大量图像,预处理它们需要很长时间 . 我的问题是,由于数据集API据说是为高效的输入管道而设计的,因此在我将它们提供给我的工作人员(比如说GPU)之前对整个数据集进行预处理,或者每次我只预处理一批图像调用iterator.get_next()?

1 回答

  • 5

    如果预处理管道很长且输出很小,则处理后的数据应该适合内存 . 如果是这种情况,您可以使用tf.data.Dataset.cache将已处理的数据缓存在内存或文件中 .

    来自官方performance guide

    tf.data.Dataset.cache转换可以在内存或本地存储中缓存数据集 . 如果传递给映射转换的用户定义函数很昂贵,则只要生成的数据集仍然适合内存或本地存储,就可以在映射转换后应用缓存转换 . 如果用户定义的函数增加了存储数据集超出缓存容量所需的空间,请考虑在训练作业之前预处理数据以减少资源使用 .


    在内存中使用缓存的示例

    以下是每个预处理需要花费大量时间(0.5秒)的示例 . 数据集上的第二个时期将比第一个时期快得多

    def my_fn(x):
        time.sleep(0.5)
        return x
    
    def parse_fn(x):
        return tf.py_func(my_fn, [x], tf.int64)
    
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.map(parse_fn)
    dataset = dataset.cache()    # cache the processed dataset, so every input will be processed once
    dataset = dataset.repeat(2)  # repeat for multiple epochs
    
    res = dataset.make_one_shot_iterator().get_next()
    
    with tf.Session() as sess:
        for i in range(10):
            # First 5 iterations will take 0.5s each, last 5 will not
            print(sess.run(res))
    

    缓存到文件

    如果要将缓存数据写入文件,可以为 cache() 提供参数:

    dataset = dataset.cache('/tmp/cache')  # will write cached data to a file
    

    这将允许您只处理数据集一次,并对数据运行多个实验,而无需再次重新处理它 .

    Warning :缓存到文件时必须是 careful . 如果您更改数据,但保留 /tmp/cache.* 文件,它仍将读取缓存的旧数据 . 例如,如果我们使用上面的数据并将数据范围更改为 [10, 15] ,我们仍将获取 [0, 5] 中的数据:

    dataset = tf.data.Dataset.range(10, 15)
    dataset = dataset.map(parse_fn)
    dataset = dataset.cache('/tmp/cache')
    dataset = dataset.repeat(2)  # repeat for multiple epochs
    
    res = dataset.make_one_shot_iterator().get_next()
    
    with tf.Session() as sess:
        for i in range(10):
            print(sess.run(res))  # will still be in [0, 5]...
    

    每当要缓存的数据发生更改时,始终删除缓存的文件 .

    可能出现的另一个问题是,如果在缓存所有数据之前中断脚本 . 您将收到如下错误:

    AlreadyExistsError(参见上面的回溯):似乎有一个并发的缓存迭代器正在运行 - 缓存锁定文件已经存在('/tmp/cache.lockfile') . 如果您确定没有其他正在运行的TF计算正在使用此缓存前缀,请删除锁定文件并重新初始化迭代器 .

    确保您处理整个数据集以包含整个缓存文件 .

相关问题