我刚刚将Tensorflow的本地安装更新为0.11rc2,我收到一条消息,说我应该为我的保护程序添加一个参数,使其保存在版本2.我更新了这个,现在我无法加载以这种格式保存的模型 . 当我运行我的模型时,它会在每个时代之后保存 . 保存时,它用于保存名为 translate.ckpt-3916
和 translate.ckpt-3916.meta
的文件 . 现在我得到三个文件而不是两个,名为 translate.ckpt-3916.index
, translate.ckpt-3916.meta
和 translate.ckpt-3916.data-000000-of-000001
.
要加载数据,我使用以下代码:
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
model.saver.restore(session, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
session.run(tf.initialize_all_variables())
return model
其中 model
是已使用我的程序的标准超参数初始化的模型对象 . 这与saver v1没有问题 . ckpt.model_checkpoint_path
计算到 translate.ckpt-3916
的路径,无论版本如何,因此如果检查点是使用v2保存的,则不会找到任何文件 .
该目录中 checkpoint
文件的内容(使用任一版本保存时)为:
model_checkpoint_path: "translate.ckpt-3916"
all_model_checkpoint_paths: "translate.ckpt-3916"
是否有一种新方法可以使用saver v2加载数据?否则,我如何加载检查点?
编辑:将 if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
更改为 if ckpt and ckpt.model_checkpoint_path:
之类的行显示在this question似乎进一步工作但随后抛出以下错误:
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [84] rhs shape= [98]
[[Node: save/Assign_54 = Assign[T=DT_FLOAT, _class=["loc:@NLC/Logistic/Linear/Bias"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](NLC/Logistic/Linear/Bias, save/RestoreV2_54)]]
1 回答
我在编辑中发布的方法实际上是使其正常工作的正确方法 . 我得到的错误是因为数据在我创建检查点和我尝试加载它之间发生了变化 .
只是为了使它可见,从上面的代码中的V2检查点加载是通过将行
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
更改为if ckpt and ckpt.model_checkpoint_path:
来完成的