我正在使用 tf.estimator
API来训练模型 .
据我了解, model_fn
定义了计算图,根据 mode
返回不同的 tf.estimator.EstimatorSpec
.
在 mode==tf.estimator.ModeKeys.TRAIN
中,可以指定在每次训练迭代时调用 train_op
,这反过来会改变 tf.Variable
的 trainable
实例,以优化某种损失 .
Let's call the train_op optimizer, and the variables A and B.
为了加快预测和评估,我希望有一个辅助的非训练 tf.Variable
Tensor C ,完全取决于已经训练过的变量 . 因此,这个张量的值是可输出的 . 此Tensor不会影响训练损失 . 我们假设我们想要:
C = tf.Variable(tf.matmul(A,B))
update_op = tf.assign(C, tf.matmul(A,B))
- 我尝试了什么:
在 EstimatorSpec
works good but slows down training a lot 中将 tf.group(optimizer, update_op)
作为 train_op
传递,因为 train_op
现在在每次迭代时更新 C
.
因为 C
仅在eval / predict时间需要,所以在训练结束时调用 update_op
就足够了 .
Is it possible to assign a Variable at the end of training a tf.estimator.Estimator?
1 回答
通常,模型函数的单次迭代不知道训练是否会在运行后结束,所以我怀疑这可以直接完成 . 我看到两个选择:
如果仅在训练后需要辅助变量,则可以使用
tf.estimator.Estimator.get_variable_value
(参见here)在训练为numpy数组后提取变量A
和B
的值,并进行计算以获得C
. 但是C
不会成为该模型的一部分 .使用钩子(见here) . 您可以使用
end
方法编写一个钩子,该方法将在会话结束时调用(即训练停止时) . 你可能需要研究如何定义/使用钩子 - 例如here您可以在Tensorflow中找到大多数"basic"钩子的实现 . 粗糙的骨架看起来像这样:由于钩子需要访问变量,因此您需要在模型函数内定义钩子 . 您可以将此类挂钩传递给
EstimatorSpec
中的训练过程(请参阅here) .我没有测试过这个!我不确定你是否可以在钩子中定义操作 . 如果没有,它应该有效地在模型函数内定义更新操作并直接将其传递给钩子 .