当我的网格搜索选择Tensorflow(版本1.12.0)模型的超参数因内存消耗爆炸而崩溃时,我注意到了这一点 .

请注意,与此处类似的问题不同,我会关闭图形和会话(使用上下文管理器),而不是在循环中向节点添加节点 .

我怀疑也许tensorflow维护了迭代之间没有清除的全局变量,因此我在迭代之前和之后调用了globals(),但是在每次迭代之前和之后都没有观察到全局变量集的任何差异 .

我做了一个小例子来重现问题 . 我在循环中训练一个简单的MNIST分类器并绘制进程消耗的内存:

import matplotlib.pyplot as plt
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import psutil
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
process = psutil.Process(os.getpid())

N_REPS = 100
N_ITER = 10
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x_test, y_test = mnist.test.images, mnist.test.labels

# Runs experiment several times.
mem = []
for i in range(N_REPS):
    with tf.Graph().as_default():
        net = tf.contrib.layers.fully_connected(x_test, 200)
        logits = tf.contrib.layers.fully_connected(net, 10, activation_fn=None)
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_test, logits=logits))
        train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            # training loop.
            sess.run(init)
            for _ in range(N_ITER):
                sess.run(train_op)
    mem.append(process.memory_info().rss)
plt.plot(range(N_REPS), mem)

结果情节看起来像这样:

在我的实际项目中,进程内存从几百MB(取决于数据集大小)开始,并且在我的系统内存不足之前上升到64 GB . 我尝试过一些可以减缓增长的事情,例如使用占位符和feed_dicts而不是依赖convert_to_tensor . 但是持续增长仍然存在,只是变慢 .