首页 文章

TensorFlow图中的条件评估

提问于
浏览
2

这可以通过 tf.cond 完成,但是它将从manual更新图形的两个分支:

请注意,条件执行仅适用于true_fn和false_fn中定义的操作 . 考虑以下简单程序:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

如果x <y,将执行tf.add操作,并且不会执行tf.square操作 . 由于cond的至少一个分支需要z,因此无条件地始终执行tf.multiply操作 .

我如何实现这一点,以便有条件地执行 tf.multiply (即仅在 x > Y 时)?

更具体地说,我正在尝试做什么:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
update_var1 = tf.assign(var1,var1 +1)
training = tf.placeholder(tf.bool)

def f1():
  with tf.control_dependencies([update_var1]):
    return var1*1.1

def f2():
  return var1 * 1.1

final = tf.cond(training, f1, f2)
sess.run(final, feed_dict={training:False})

每次评估final时,这将使var1增加1,无论 training 的值如何,问题都在 tf.cond 中,因为手动它可以工作:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
update_var1 = tf.assign(var1,var1 +1)
training = tf.placeholder(tf.bool)

with tf.control_dependencies([update_var1]):
  f1 = var1 * 1.1

f2 = var1 * 1.1

sess.run(f1)
>> array([1.1,1.1,1.1,1.1])
sess.run(f1)
>> array([2.2,2.2,2.2,2.2])
# var1 gets updated every call
sess.run(f2)
>> array([2.2,2.2,2.2,2.2])
sess.run(f2)
>> array([2.2,2.2,2.2,2.2])
# var1 does not get updated

1 回答

  • 5

    一般解决方案如下:移动要有条件地执行的代码 into the body of the lambda (或者通常是可调用对象),用于 tf.cond() 的相应分支 . 例如,要确保 tf.multiply(a, b) 仅在 x < y 时执行,请将其移动到 true_fn lambda中:

    result = tf.cond(x < y, lambda: tf.add(x, tf.multiply(a, b)), lambda: tf.square(y))
    

    相同的原理可以应用于变量更新操作,例如 tf.assign() . 重要的细节是您必须创建用于其中一个分支的 tf.assign() op inside the body of the function . 这里's how you' d修改你的第二个例子:

    var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
    training = tf.placeholder(tf.bool)
    
    def f1():
      with tf.control_dependencies([tf.assign(var1, var1 + 1)]):
        return var1 * 1.1
    
    def f2():
      return var1 * 1.1
    
    final = tf.cond(training, f1, f2)
    sess.run(final, feed_dict={training: False})
    

    赋值的控制依赖关系有点繁琐,所以您也可以将 f1() 写为:

    def f1():
      return tf.assign(var1, var1 + 1) * 1.1
    

    ......或者将整个事情放在一行:

    final = tf.cond(training, lambda: tf.assign(var1, var1 + 1) * 1.1, lambda: var1 * 1.1)
    

相关问题