首页 文章

在tensorflow中使用queue runner时出错

提问于
浏览
1

我是tensorflow的新手,我现在正在学习如何使用队列运行器 . 我想要做的是从dir读取二进制文件并使每个文件成为一个数组 . 我使用两个线程并批量生成4个数组 . 代码如下 .

import glob

  import tensorflow as tf

  def readfile(filenames_queue):

        filename = filenames_queue.dequeue()
        value_strings = tf.read_file(filename)
        array = tf.decode_raw(value_strings,tf.uint8)
        return [array]
 def input_pipeline(filenames,batch_size,num_threads=2):

       filenames_queue = tf.train.string_input_producer(filenames)
       thread_lists = [readfile(filenames_queue) for _ in range(num_threads)] 
       min_after_dequeue = 1000 
       capacity = min_after_dequeue+3*batch_size
       arrays = tf.train.shuffle_batch_join(thread_lists,batch_size,capacity,min_after_dequeue)
       return arrays
if __name__ == "__main__":

      filenames = glob.glob('dir/*')
      arrays_batch = input_pipeline(filenames,4)
      with tf.Session() as sess:
           tf.global_variables_initializer().run()
           coord = tf.train.Coordinator()
           threads = tf.train.start_queue_runners(sess,coord)
           for i in range(100):
                 print sess.run(arrays_batch)
           coord.request_stop()
           coord.join(threads)

我修正了Victor和Sorin指出的错误,但是出现了新的错误:

文件“input_queue.py”,第36行,打印sess.run(im_arrays_batch)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第889行,运行run_metadata_ptr)

在_run feed_dict_tensor,options,run_metadata中输入文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第1120行

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第1317行,在_do_run选项中,run_metadata)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第1336行,在_do_call中提升类型(e)(node_def,op,message)tensorflow.python .framework.errors_impl.OutOfRangeError:RandomShuffleQueue'_1_shuffle_batch_join / random_shuffle_queue'已关闭且元素不足(请求2,当前大小为0)[[Node:shuffle_batch_join = QueueDequeueManyV2 [component_types = [DT_UINT8],timeout_ms = -1,_device =“/ job:localhost / replica:0 / task:0 / device:CPU:0“](shuffle_batch_join / random_shuffle_queue,shuffle_batch_join / n)]]

由op u'shuffle_batch_join'引起,定义于:

文件“input_queue.py”,第30行,in im_arrays_batch = input_pipeline(filenames,2)

文件“input_queue.py”,第23行,在input_pipeline arrays_batch = tf.train.shuffle_batch_join(thread_lists,batch_size,capacity,min_after_dequeue)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/input.py”,第1367行,shuffle_batch_join name = name)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/input.py”,第833行,在_shuffle_batch_join dequeued = queue.dequeue_many(batch_size,name = name)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.py”,第464行,在dequeue_many self._queue_ref中,n = n,component_types = self._dtypes,name =名)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py”,第2418行,在_queue_dequeue_many_v2 component_types = component_types,timeout_ms = timeout_ms,name = name)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py”,第787行,在_apply_op_helper中op_def = op_def)

文件“/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”,第2956行,在create_op中op_def = op_def)

文件"/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py",第1470行,在 init self._traceback = self._graph._extract_stack()#pylint:disable = protected-access

OutOfRangeError(参见上面的回溯):RandomShuffleQueue'_1_shuffle_batch_join / random_shuffle_queue'关闭且元素不足(请求2,当前大小为0)[[Node:shuffle_batch_join = QueueDequeueManyV2 [component_types = [DT_UINT8],timeout_ms = -1,_device =“ / job:localhost / replica:0 / task:0 / device:CPU:0“](shuffle_batch_join / random_shuffle_queue,shuffle_batch_join / n)]]

2 回答

  • 0

    您的 readfile(...): 函数应该返回一个iterable,以便您可以返回功能和标签或其他类似的东西 .

    所以要修改你的代码更改 readfile(...):

    return [arrays]
    
  • 0

    tf.train.shuffle_batch_join

    tensors_list参数是张量元组的列表

    在这里,您调用了tf.decode_raw produces Tensor instances,然后将它们放在一个带有 thread_lists = [readfile(filenames_queue) for _ in range(num_threads)] 的列表中 .

    因此,它不是您提供的张量元组的列表,而是张量的列表,因此张量试图被迭代,因此错误 TypeError: 'Tensor' object is not iterable .

相关问题