我正在尝试使用tf.train.Saver()来保存和恢复张量流图 . 为了保存图表,我正在进行如下操作:
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
sess.run(init_op)
with sess.as_default():
for ep in range(epoch):
train_step.run(feed_dict={X: X_train,
Y: y_train.reshape(-1,1)})
saver.save(sess, 'my_test_model')
要恢复模型,我这样做:
with tf.Session() as sess:
saver.restore(sess, 'my_test_model')
test_accuracy = sess.run(x,feed_dict={X: X_valid,
Y: y_valid.reshape(-1,1)})
mse_test=loss.eval(feed_dict={X: X_valid,
Y: y_valid.reshape(-1,1)})
print(mse_test)
虽然在同一个笔记本中这样做,一切都很完美 . 但是当我试图在另一个笔记本中做同样的事情时,问题就开始了:在新笔记本中没有定义保护程序所以我试图再次将其定义为tf.train.Saver()并得到错误ValueError:No要保存的变量 .
我也试过了
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
但图形变量未保存在元文件中 .
有人遇到过类似的问题吗?我将不胜感激任何帮助 .
提前致谢!