首页 文章

在`tf.estimator`中,如何在训练结束时(而不是在每次迭代时)对变量进行“tf.assign”?

提问于
浏览
1

我正在使用 tf.estimator API来训练模型 .

据我了解, model_fn 定义了计算图,根据 mode 返回不同的 tf.estimator.EstimatorSpec .

mode==tf.estimator.ModeKeys.TRAIN 中,可以指定在每次训练迭代时调用 train_op ,这反过来会改变 tf.Variabletrainable 实例,以优化某种损失 .

Let's call the train_op optimizer, and the variables A and B.

为了加快预测和评估,我希望有一个辅助的非训练 tf.Variable Tensor C ,完全取决于已经训练过的变量 . 因此,这个张量的值是可输出的 . 此Tensor不会影响训练损失 . 我们假设我们想要:

C = tf.Variable(tf.matmul(A,B))
update_op = tf.assign(C, tf.matmul(A,B))
  • 我尝试了什么:

EstimatorSpec works good but slows down training a lot 中将 tf.group(optimizer, update_op) 作为 train_op 传递,因为 train_op 现在在每次迭代时更新 C .

因为 C 仅在eval / predict时间需要,所以在训练结束时调用 update_op 就足够了 .

Is it possible to assign a Variable at the end of training a tf.estimator.Estimator?

1 回答

  • 1

    通常,模型函数的单次迭代不知道训练是否会在运行后结束,所以我怀疑这可以直接完成 . 我看到两个选择:

    • 如果仅在训练后需要辅助变量,则可以使用 tf.estimator.Estimator.get_variable_value (参见here)在训练为numpy数组后提取变量 AB 的值,并进行计算以获得 C . 但是 C 不会成为该模型的一部分 .

    • 使用钩子(见here) . 您可以使用 end 方法编写一个钩子,该方法将在会话结束时调用(即训练停止时) . 你可能需要研究如何定义/使用钩子 - 例如here您可以在Tensorflow中找到大多数"basic"钩子的实现 . 粗糙的骨架看起来像这样:

    class UpdateHook(SessionRunHook):
        def __init__(update_variable, other_variables):
            self.update_op = tf.assign(update_variable, some_fn(other_variables))
    
        def end(session):
            session.run(self.update_op)
    

    由于钩子需要访问变量,因此您需要在模型函数内定义钩子 . 您可以将此类挂钩传递给 EstimatorSpec 中的训练过程(请参阅here) .

    我没有测试过这个!我不确定你是否可以在钩子中定义操作 . 如果没有,它应该有效地在模型函数内定义更新操作并直接将其传递给钩子 .

相关问题