首页 文章

Tensorflow使用标签读取图像

提问于
浏览
53

我正在使用Tensorflow构建标准图像分类模型 . 为此,我有输入图像,每个图像都分配有一个标签({0,1}中的数字) . 因此,可以使用以下格式将数据存储在列表中:

/path/to/image_0 label_0
/path/to/image_1 label_1
/path/to/image_2 label_2
...

我想使用TensorFlow的排队系统来读取我的数据并将其提供给我的模型 . 忽略标签,可以通过使用 string_input_producerwholeFileReader 轻松实现此目的 . 这里的代码:

def read_my_file_format(filename_queue):
  reader = tf.WholeFileReader()
  key, value = reader.read(filename_queue)
  example = tf.image.decode_png(value)
  return example

#removing label, obtaining list containing /path/to/image_x
image_list = [line[:-2] for line in image_label_list]

input_queue = tf.train.string_input_producer(image_list)                                                     
input_images = read_my_file_format(input_queue)

但是,标签在该过程中丢失,因为图像数据被有意地作为输入管道的一部分混洗 . 通过输入队列将标签与图像数据一起推送的最简单方法是什么?

3 回答

  • 2

    使用 slice_input_producer 提供了更清洁的解决方案 . Slice Input Producer允许我们创建一个包含任意多个可分离值的输入队列 . 这个问题的片段如下所示:

    def read_labeled_image_list(image_list_file):
        """Reads a .txt file containing pathes and labeles
        Args:
           image_list_file: a .txt file with one /path/to/image per line
           label: optionally, if set label will be pasted after each line
        Returns:
           List with all filenames in file image_list_file
        """
        f = open(image_list_file, 'r')
        filenames = []
        labels = []
        for line in f:
            filename, label = line[:-1].split(' ')
            filenames.append(filename)
            labels.append(int(label))
        return filenames, labels
    
    def read_images_from_disk(input_queue):
        """Consumes a single filename and label as a ' '-delimited string.
        Args:
          filename_and_label_tensor: A scalar string tensor.
        Returns:
          Two tensors: the decoded image, and the string label.
        """
        label = input_queue[1]
        file_contents = tf.read_file(input_queue[0])
        example = tf.image.decode_png(file_contents, channels=3)
        return example, label
    
    # Reads pfathes of images together with their labels
    image_list, label_list = read_labeled_image_list(filename)
    
    images = ops.convert_to_tensor(image_list, dtype=dtypes.string)
    labels = ops.convert_to_tensor(label_list, dtype=dtypes.int32)
    
    # Makes an input queue
    input_queue = tf.train.slice_input_producer([images, labels],
                                                num_epochs=num_epochs,
                                                shuffle=True)
    
    image, label = read_images_from_disk(input_queue)
    
    # Optional Preprocessing or Data Augmentation
    # tf.image implements most of the standard image augmentation
    image = preprocess_image(image)
    label = preprocess_label(label)
    
    # Optional Image and Label Batching
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size)
    

    另请参阅TensorVision示例中的generic_input_producer以获取完整输入管道 .

  • 48

    解决此问题有三个主要步骤:

    • 使用包含原始空格分隔字符串的字符串列表填充tf.train.string_input_producer(),该字符串包含文件名和标签 .

    • 使用tf.read_file(filename)而不是 tf.WholeFileReader() 来读取图像文件 . tf.read_file() 是一个无状态的op,它使用单个文件名并生成一个包含文件内容的字符串 . 它的优点是's a pure function, so it'易于将数据与输入和输出相关联 . 例如,您的 read_my_file_format 函数将变为:

    def read_my_file_format(filename_and_label_tensor):
      """Consumes a single filename and label as a ' '-delimited string.
    
      Args:
        filename_and_label_tensor: A scalar string tensor.
    
      Returns:
        Two tensors: the decoded image, and the string label.
      """
      filename, label = tf.decode_csv(filename_and_label_tensor, [[""], [""]], " ")
      file_contents = tf.read_file(filename)
      example = tf.image.decode_png(file_contents)
      return example, label
    
    • 通过从 input_queue 传递一个出列的元素来调用 read_my_file_format 的新版本:
    image, label = read_my_file_format(input_queue.dequeue())
    

    然后,您可以在模型的其余部分中使用 imagelabel 张量 .

  • 21

    除了提供的答案之外,您还可以做其他一些事情:

    Encode your label into the filename. 如果您有N个不同的类别,则可以将文件重命名为: 0_file001, 5_file002, N_file003 . 之后当您从reader key, value = reader.read(filename_queue) 读取数据时,您的键/值为:

    Read的输出将是文件名(键)和该文件的内容(值)

    然后解析文件名,提取标签并将其转换为int . 这将需要对数据进行一些预处理 .

    Use TFRecords 允许您将数据和标签存储在同一文件中 .

相关问题