我有一个简单的目标,即在tensorflow保存中训练模型并在以后恢复它,以便继续训练或使用某些功能/操作 .
这是模型的简单示例
import tensorflow as tf
import numpy as np
BATCH_SIZE = 3
VECTOR_SIZE = 1
LEARNING_RATE = 0.1
x = tf.placeholder(tf.float32, [BATCH_SIZE, VECTOR_SIZE],
name='input_placeholder')
y = tf.placeholder(tf.float32, [BATCH_SIZE, VECTOR_SIZE],
name='labels_placeholder')
W = tf.get_variable('W', [VECTOR_SIZE, BATCH_SIZE])
b = tf.get_variable('b', [VECTOR_SIZE], initializer=tf.constant_initializer(0.0))
y_hat = tf.matmul(W, x) + b
predict = tf.matmul(W, x) + b
total_loss = tf.reduce_mean(y-y_hat)
train_step = tf.train.AdagradOptimizer(LEARNING_RATE).minimize(total_loss)
X = np.ones([BATCH_SIZE, VECTOR_SIZE])
Y = np.ones([BATCH_SIZE, VECTOR_SIZE])
all_saver = tf.train.Saver()
sess= tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([train_step], feed_dict = {x: X, y:Y}))
save_path = r'C:\some_path\save\\'
all_saver.save(sess,save_path)
现在我们在这里恢复它:
meta_path = r'C:\some_path\save\.meta'
new_all_saver = tf.train.import_meta_graph(meta_path)
graph = tf.get_default_graph()
all_ops = graph.get_operations()
for el in all_ops:
print(el)
在恢复的操作中,甚至无法从原始代码中找到 predict
或 train_step
. 我需要在保存之前命名此操作吗?我怎样才能得到 predict
并运行这样的东西
sess=tf.Session()
sess.run([predict], feed_dict = {x:X})
附:我在tensorflow中阅读了很多关于保存和恢复的教程,但是仍然很难理解它是如何工作的 .
1 回答
1)您的操作存在于已恢复的模型中,但由于您尚未对它们进行命名,因此将根据某些默认规则对它们进行命名 . 例如,因为你有:
然后代表
predict
的操作可能如下所示:在此示例中,当您执行
for el in all_ops:
并打印结果时打印,您会看到操作的名称是"add",它是自动分配的;操作类型("op")是"Add",它对应于代码行中执行的最后一次操作(即);输入是"MatMul"和"b/read",对应于你的总和 . 为了清楚起见,我不确定这个操作只对应于给定的代码行,因为打印中存在具有相同类型输入的其他添加,但这是可能的 .所以总结到现在为止:你的操作在那里,你在打印时看到它们 . 但为什么你不看“预测”这个词?好吧,因为这不是Tensorflow图中张量或操作的名称,它只是代码中变量的名称 .
展望未来,你怎么能访问这个“预测”?答案是通过它的名称,如图中所示 . 在上面的例子中,预测的名称可以是“添加”,如果我对我的猜测是正确的,但是让我们将其命名为“预测”,这样您就可以轻松控制哪个操作对应于它 .
为了命名"predict",让我们在
predict = tf.matmul(W, x) + b
下面添加以下代码行:这一行正在做的是创建一个新操作,它接收“predict”中定义的操作作为输入,并产生一个等于输入结果的输出 . 操作本身没有做太多 - 只是重复一个值 - 但通过这个操作,我可以为它添加一个名称 . 现在,如果您在打印中搜索,您将能够找到:
太好了!现在你1)可以使用名称“here_i_put_a_name”访问你的“预测”,2)我们可以确认你的“预测”实际上是名为“add_1”的操作 - 只需检查操作的“输入”属性上方“here_i_put_a_name” .
完成后,让我们访问操作“here_i_put_a_name”并完成一些预测 . 首先,更改save_path和meta_path,最后放置一个可能的文件名,例如:
然后,在还原代码的最后,添加:
使用此块,您将创建一个新的Tensorflow会话并使用存储在变量“graph”中的图形 . 在此上下文中,您正在将会话从save_path还原到当前会话 . 然后,您正在运行预测,或者更确切地说,您正在运行操作“here_i_put_a_name”并获取此操作的第一个输出(之后我们有“:0”) . feed dict将值[[1],[2],[3]]赋给张量“input_placeholder:0”(同样,“:0”表示这是张量,而不是操作) .
随着所有这些以及(希望)回答的问题,我有一些意见:
1)根据我的经验,使用库
tf.saved_model
来保存和恢复模块会很不错 . 但这是我个人的建议 .2)我限制自己回答有关命名和调用操作的问题,所以我忽略了训练和预测例程 . 但是当我把变量X和BATCH_SIZE作为大小时,我认为你在解决这个问题 .
3)注意“blabla”和“blabla:0”之间的区别 . 第一个是操作,最后一个是张量 .