首页 文章

从梯度/损耗计算中去耦出队操作

提问于
浏览
1

我目前正试图摆脱使用Feed并开始使用队列以支持更大的数据集 . 使用队列对张量流中的优化器工作正常,因为它们仅为每个出列操作计算一次渐变 . 但是,我已经与执行行搜索的其他优化器实现了接口,我不仅需要评估渐变,还需要评估同一批次的多个点的损失 . 不幸的是,对于正常的排队系统,每个损失评估将执行出队而不是多次计算同一批次 .

有没有办法将出列操作与梯度/损耗计算分离,使得我可以执行一次出列,然后在当前批次上多次执行梯度/损失计算?

编辑:请注意我的输入张量的大小在批次之间是可变的 . 我们使用分子数据,每个分子都有不同数量的原子 . 这与图像数据完全不同,图像数据通常按比例缩放以具有相同的尺寸 .

1 回答

  • 4

    通过创建变量存储出列值来解耦它,然后依赖于此变量而不是出列操作 . 推进队列发生在 assign

    Solution #1 :固定大小的数据,使用变量

    (image_batch_live,) = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)
    
    image_batch = tf.Variable(
      tf.zeros((batch_size, image_size, image_size, color_channels)),
      trainable=False,
      name="input_values_cached")
    
    advance_batch = tf.assign(image_batch, image_batch_live)
    

    现在 image_batch 给出队列的最新值而不推进它,并且 advance_batch 使队列前进 .

    Solution #2 :可变大小的数据,使用持久性的Tensors

    在这里,我们通过引入 dequeue_opdequeue_op2 来解耦工作流程 . 所有计算都取决于 dequeue_op2 ,它被提供 dequeue_op 的保存值 . 使用 get_session_tensor/get_session_handle 确保实际数据保留在TensorFlow运行时中,并且通过 feed_dict 传递的值是一个短字符串标识符 . 由于 dummy_handle ,API有点尴尬,我提出了这个问题here

    import tensorflow as tf
    def create_session():
        sess = tf.InteractiveSession(config=tf.ConfigProto(operation_timeout_in_ms=3000))
        return sess
    
    tf.reset_default_graph()
    
    sess = create_session()
    dt = tf.int32
    dummy_handle = sess.run(tf.get_session_handle(tf.constant(1)))
    q = tf.FIFOQueue(capacity=20, dtypes=[dt])
    enqueue_placeholder = tf.placeholder(dt, shape=[None])
    enqueue_op = q.enqueue(enqueue_placeholder)
    dequeue_op = q.dequeue()
    size_op = q.size()
    
    dequeue_handle_op = tf.get_session_handle(dequeue_op)
    dequeue_placeholder, dequeue_op2 = tf.get_session_tensor(dummy_handle, dt)
    compute_op1 = tf.reduce_sum(dequeue_op2)
    compute_op2 = tf.reduce_sum(dequeue_op2)+1
    
    
    # fill queue with variable size data
    for i in range(10):
        sess.run(enqueue_op, feed_dict={enqueue_placeholder:[1]*(i+1)})
    sess.run(q.close())
    
    try:
        while(True):
            dequeue_handle = sess.run(dequeue_handle_op) # advance the queue
            val1 = sess.run(compute_op1, feed_dict={dequeue_placeholder: dequeue_handle.handle})
            val2 = sess.run(compute_op2, feed_dict={dequeue_placeholder: dequeue_handle.handle})
            size = sess.run(size_op)
            print("val1 %d, val2 %d, queue size %d" % (val1, val2, size))
    except tf.errors.OutOfRangeError:
        print("Done")
    

    当你运行它时,你会看到类似下面的内容

    val1 1, val2 2, queue size 9
    val1 2, val2 3, queue size 8
    val1 3, val2 4, queue size 7
    val1 4, val2 5, queue size 6
    val1 5, val2 6, queue size 5
    val1 6, val2 7, queue size 4
    val1 7, val2 8, queue size 3
    val1 8, val2 9, queue size 2
    val1 9, val2 10, queue size 1
    val1 10, val2 11, queue size 0
    Done
    

相关问题