首页 文章

Tensorflow ValueError:没有要保存的变量(但有很多)

提问于
浏览
0

我试图从我从磁盘加载的图表中保存模型 . 我可以加载图表并检查它没有问题,并运行训练操作,但我不能创建一个没有得到ValueError的保护程序:没有要保存的变量 .

图形定义:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path
import os
import tensorflow as tf

outdir = os.path.dirname(__file__)
outfile = Path(__file__).stem + ".pb"

print(os.path.join(outdir, outfile))

# The input is the state of a Tic Tac Toe game.
# This is represented as two length-9 Vec<u8>.
# The first plane holds the location of the first player's stones,
# The second plane, the second player's.
# A 19th byte holds 0 for first player, 1 for second player.
x = tf.placeholder(tf.uint8, shape=[None, 9 * 2 + 1], name ='x')

# Training makes makes the net more likely to pick the picked move.
# The picked move will be 1.0, the other 8 spaces will be 0.0.
y_true = tf.placeholder(tf.float32, shape=[None, 9], name='y_true')

dense = tf.layers.dense(tf.cast(x, tf.float32), units=64, activation=tf.nn.relu)
logits = tf.layers.dense(dense, units=9, activation=tf.nn.relu)
softmax = tf.nn.softmax(logits, name='softmax')

sess = tf.Session()
init = tf.variables_initializer(tf.global_variables(), name='init')
sess.run(init)

loss = tf.losses.mean_squared_error(labels=y_true, predictions=softmax)
optimizer = tf.train.GradientDescentOptimizer(.01)
train = optimizer.minimize(loss, name='train')

definition = tf.Session().graph_def
tf.train.write_graph(definition, outdir, outfile, as_text=False)

加载图表:

import tensorflow as tf
import glob

num_epochs = 100
minibatch_size = 128
dataset_dir = "src/tictactoe/gamedata"
model_dir = "src/tictactoe/simple_model/checkpoint"
graph_filename = "src/tictactoe/simple_net.pb"

def make_dataset(num_epochs, minibatch_size, dataset_dir):
    files = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print("loading", files)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(parse)
    dataset = dataset.shuffle(buffer_size=100000)
    dataset = dataset.batch(minibatch_size)
    print("loaded data")
    return dataset

def parse(bytes):
  features = {"game": tf.FixedLenFeature((), tf.string),
              "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True)}
  parsed_features = tf.parse_single_example(bytes, features)
  game = tf.decode_raw(parsed_features["game"], tf.uint8)
  choice =  parsed_features["choice"]
  return tf.reshape(game, [19]), tf.reshape(choice, [9])


with tf.gfile.FastGFile(graph_filename,'rb') as f:
    sess = tf.InteractiveSession()

    dataset = make_dataset(num_epochs, minibatch_size, dataset_dir)
    print("loading graph at '{}'".format(graph_filename))

    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='',input_map={'x': example, 'y_true':label})

    init = tf.group(
        tf.global_variables_initializer(), 
        tf.local_variables_initializer(), 
        iterator.initializer, 
        sess.graph.get_operation_by_name('init'))

    train = sess.graph.get_operation_by_name('train')
    for name in [n.name for n in tf.get_default_graph().as_graph_def().node]:
        print(name)

    saver = tf.train.Saver()

    sess.run(init)

    for i in range(num_epochs):
        sess.run(iterator.initializer)

        while True:
            try:
                sess.run(train)
            except tf.errors.OutOfRangeError:
                break
        save_path = saver.save(sess, model_dir)
        print("Model saved in path: %s" % save_path)

Tensorflow抛出 saver = tf.train.Saver()

我试图确认图表已正确恢复,并且它包含的变量通过打印出Saver线上方默认图形中的所有变量而加载到当前默认图形中 . 那里有数百个,包括我在图表创建文件(x,y_true,train等)中手工命名的那些 .

相关问题似乎不是我的问题 . 例如,我找到的最接近的相关问题是:No variable to save error in Tensorflow

OP的问题是他的变量在错误的图表中 . 对于我的,只有一个图,它肯定包含变量 .

1 回答

  • 1

    如果你想让tensorflow识别变量,你需要导入元图; graphdef本身没有足够的信息来重建一切 . 查看tf.train.import_meta_graph的文档 .

相关问题