首页 文章

循环中的Estimator.predict导致张量流中的内存泄漏

提问于
浏览
2

当我使用tensorflow estimator.predict 时,这发生在我身上 . 说,我通过以下方式从保存的模型中加载估算器:

estimator = tf.contrib.learn.Estimator(
    model_fn=model_fn, model_dir=FLAGS.model_dir, config=run_cfg)

get_input_fn() 将返回 input_fn ,如下所示:

def get_input_fn(arg1, arg2):
    def input_fn():
        # do something
        #    ....
        return features, None
    return input_fn

然后,循环将用于预测 file_iter 的所有输入,如下所示:

for idx, data in enumerate(file_iter):
    predicts = estimator.predict(input_fn=get_input_fn(data['query'],
                                                    data['responses']))

这将导致内存泄漏 . 每次调用 estimator.predict 后,内存会增加一点,但永远不会下降 . 我使用 objgraph 调试我的代码,并在每次调用 estimator.predict 后找到一些引用计数增加 .

我真的不知道 estimator.predict 的见解 . 我想问题可能是因为我不止一次调用input_fn . 我的张量流的版本是v1.2 .


[更新]

这里是 objgraph 的结果,左边是在调用 estimator.predict 之前,mid是在调用它之后,右边是另一个调用结果 . 如我所见, tuplelistdic 在每次调用 estimator.predict 后稍微增加一点 . 我没有绘制参考图,因为我不熟悉它 .

objgraph.show_most_common_types()    
tuple            146247 | tuple            180157   | tuple            213976
list             60745  | list             73107    | list             86111
dict             43412  | dict             50925    | dict             58437
function         28482  | function         28497    | function         28512
TensorShapeProto 9434   | TensorShapeProto 11793    | TensorShapeProto 14152
Dimension        8286   | Dimension        10360    | Dimension        12434
Operation        6098   | Operation        7625     | Operation        9152
AttrValue        6098   | NodeDef          7625     | NodeDef          9152
NodeDef          6098   | TensorShape      7575     | TensorShape      9092
TensorShape      6058   | Tensor           7575     | Tensor           9092

2 回答

  • 1

    最后,我发现这是由于调用太多 tf.convert_to_tensor 引起的,每次调用该函数都会在tensorflow图中生成一个新节点,这需要一些内存 .

    要解决此问题,只需使用 tf.placeholder 来提供数据 . 此外,tensorflow v1.3添加了一个新方法 tf.contrib.predictor 来执行此操作 . 阅读更多https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/predictor

  • 0

    你能发布你的objgraph的结果吗?如果这是一个张量流问题或一般的python问题,它将有助于清楚 .

相关问题