首页 文章

对于列车数据,使用tensorflow tf.Variable而不是tf.placeholder

提问于
浏览
0

我使用tf.Variable表示W(权重)和b(偏差),但是tf.placeholder表示X(输入批处理)和Y(此批处理的期望值) . 一切正常 . 但今天我发现了这个话题:Tensorflow github issues并且引用:

Feed_dict从Python运行时到TensorFlow运行时执行单线程内容memcpy . 如果GPU上需要数据,那么您将获得额外的CPU-> GPU传输 . 从feed_dict切换到本机TensorFlow(变量/队列)时,我习惯看到性能提高10倍

现在我尝试找到如何使用tf.Variable或Queue作为输入数据而没有feed_dict,以提高速度,特别是批量 . 因为我需要逐个更改数据批量 . 当所有批次都完成 - 时代结束 . 而不是从开始,第二纪元等...

但对不起,我不明白我该怎么用 .

1 回答

  • 1

    以下是如何使用队列来提供培训批次的自包含示例:

    import tensorflow as tf
    
    IMG_SIZE = [30, 30, 3]
    BATCH_SIZE_TRAIN = 50
    
    def get_training_batch(batch_size):
        ''' training data pipeline -- normally you would read data from files here using
        a TF reader of some kind. '''
        image = tf.random_uniform(shape=IMG_SIZE)
        label = tf.random_uniform(shape=[])
    
        min_after_dequeue = 100
        capacity = min_after_dequeue + 3 * batch_size
        images, labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, capacity=capacity,
            min_after_dequeue=min_after_dequeue)
        return images, labels
    
    # define the graph
    images_train, labels_train = get_training_batch(BATCH_SIZE_TRAIN)
    '''inference, training and other ops generally are defined here too'''
    
    # start a session
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
        ''' do something interesting here -- training, validation, etc'''
        for _ in range(5):
            # typical training step where batch data are drawn from the training queue
            py_images, py_labels = sess.run([images_train, labels_train])
            print('\nData from queue:')
            print('\tImages shape, first element: ', py_images.shape, py_images[0][0, 0, 0])
            print('\tLabels shape, first element: ', py_labels.shape, py_labels[0])
    
        # close threads
        coord.request_stop()
        coord.join(threads)
    

相关问题