当我使用tensorflow estimator.predict
时,这发生在我身上 . 说,我通过以下方式从保存的模型中加载估算器:
estimator = tf.contrib.learn.Estimator(
model_fn=model_fn, model_dir=FLAGS.model_dir, config=run_cfg)
get_input_fn()
将返回 input_fn
,如下所示:
def get_input_fn(arg1, arg2):
def input_fn():
# do something
# ....
return features, None
return input_fn
然后,循环将用于预测 file_iter
的所有输入,如下所示:
for idx, data in enumerate(file_iter):
predicts = estimator.predict(input_fn=get_input_fn(data['query'],
data['responses']))
这将导致内存泄漏 . 每次调用 estimator.predict
后,内存会增加一点,但永远不会下降 . 我使用 objgraph
调试我的代码,并在每次调用 estimator.predict
后找到一些引用计数增加 .
我真的不知道 estimator.predict
的见解 . 我想问题可能是因为我不止一次调用input_fn . 我的张量流的版本是v1.2 .
[更新]
这里是 objgraph
的结果,左边是在调用 estimator.predict
之前,mid是在调用它之后,右边是另一个调用结果 . 如我所见, tuple
, list
, dic
在每次调用 estimator.predict
后稍微增加一点 . 我没有绘制参考图,因为我不熟悉它 .
objgraph.show_most_common_types()
tuple 146247 | tuple 180157 | tuple 213976
list 60745 | list 73107 | list 86111
dict 43412 | dict 50925 | dict 58437
function 28482 | function 28497 | function 28512
TensorShapeProto 9434 | TensorShapeProto 11793 | TensorShapeProto 14152
Dimension 8286 | Dimension 10360 | Dimension 12434
Operation 6098 | Operation 7625 | Operation 9152
AttrValue 6098 | NodeDef 7625 | NodeDef 9152
NodeDef 6098 | TensorShape 7575 | TensorShape 9092
TensorShape 6058 | Tensor 7575 | Tensor 9092
2 回答
最后,我发现这是由于调用太多
tf.convert_to_tensor
引起的,每次调用该函数都会在tensorflow图中生成一个新节点,这需要一些内存 .要解决此问题,只需使用
tf.placeholder
来提供数据 . 此外,tensorflow v1.3添加了一个新方法tf.contrib.predictor
来执行此操作 . 阅读更多https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/predictor你能发布你的objgraph的结果吗?如果这是一个张量流问题或一般的python问题,它将有助于清楚 .