首页 文章

Tensorflow如何管理图表?

提问于
浏览
1

我已经意识到Tensorflow似乎正在管理图形的方式正在发生一些时髦的事情 .

由于构建(和重建)模型是如此繁琐,我决定将我的自定义模型包装在一个类中,以便我可以轻松地在其他地方重新实例化它 .

当我训练和测试代码时(在原始位置)它可以正常工作,但是在我加载图形变量的代码中,我会得到各种奇怪的错误 - 变量重新定义和其他所有内容 . 这(从我关于类似事情的最后一个问题)是暗示一切都被调用了两次 .

在进行TON跟踪之后,它归结为我使用加载代码的方式 . 它是在具有类似结构的类中使用的

class MyModelUser(object):
    def forecast(self):
       # .. build the model in the same way as in the training code
       # load the model checkpoint
       # call the "predict" function on the model
       # manipulate the prediction and return it

然后在一些使用 MyModelUser 的代码中

def test_the_model(self):
   model_user = MyModelUser()
   print(model_user.forecast())  # 1
   print(model_user.forecast())  # 2

而且我(显然)预计会看到两个预测 . 相反,第一个预测被调用并按预期工作,但第二个调用抛出了一个变量重用Value的TON,其中一个例子是:

ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope?

我设法通过添加一系列try / except块来平息错误,这些块使用 get_variable 来创建变量,然后在异常上,在作用域上调用 reuse_variables 然后 get_variable ,除了名称之外没有任何东西 . 这带来了一系列新的令人讨厌的错误,其中之一是:

tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files

我突然想到“如果我将建模构建代码移动到 __init__ ,那么它只构建一次会怎么样?”

我的新模特用户:

class MyModelUser(object):
    def __init__(self):
       # ... build the model in the same way as in the training code
       # load the model checkpoint


    def forecast(self):
       # call the "predict" function on the model
       # manipulate the prediction and return it

现在:

def test_the_model(self):
   model_user = MyModelUser()
   print(model_user.forecast())  # 1
   print(model_user.forecast())  # 2

按预期工作,打印两个没有错误的预测 . 这让我相信我也可以摆脱变量重用的东西 .

我的问题是:

为什么要修复它?理论上,图表应该在原始预测方法中每次重新安装,因此不应该创建多个图表 . 即使在函数完成后,Tensorflow是否仍然保留图形?这是为什么将创建代码移动到 __init__ 工作的原因?这让我无可救药地困惑 .

2 回答

  • 3

    默认情况下,TensorFlow使用在您第一次调用TensorFlow API时创建的单个全局tf.Graph实例 . 如果未显式创建 tf.Graph ,则将在该默认实例中创建所有操作,张量和变量 . 这意味着您的代码中的每个调用都将 model_user.forecast() 添加到同一个全局图中,这有点浪费 .

    这里有(至少)两种可能的行动方案:

    • 理想的操作是重构代码,以便 MyModelUser.__init__() 使用执行预测所需的所有操作构建整个 tf.GraphMyModelUser.forecast() 只是在现有图形上执行 sess.run() 调用 . 理想情况下,您也只能创建一个 tf.Session ,因为TensorFlow会在会话中缓存有关图形的信息,并且执行效率会更高 .

    • 侵入性较小但可能效率较低的变化是为每次调用 MyModelUser.forecast() 创建一个新的 tf.Graph . 从问题中不清楚 MyModelUser.__init__() 方法中创建了多少状态,但您可以执行以下操作将两个调用放在不同的图中:

    def test_the_model(self):
      with tf.Graph():  # Create a local graph
        model_user_1 = MyModelUser()
        print(model_user_1.forecast())
      with tf.Graph():  # Create another local graph
        model_user_2 = MyModelUser()
        print(model_user_2.forecast())
    
  • 0

    TF有一个默认图表,可以添加新的操作等 . 当你两次调用你的函数时,你会将相同的东西两次添加到同一个图形中 . 因此,要么构建一次图形并多次评估它(正如您所做的那样,这也是"normal"方法),或者,如果您想要更改内容,可以使用reset_default_graph https://www.tensorflow.org/versions/r0.11/api_docs/python/framework.html#reset_default_graph重置图形以便重新获得州 .

相关问题