首页 文章

keras tensorflow load_weights失败

提问于
浏览
1

我正在使用keras 1.2和tensorflow 1.0.0后端 .

我有一个函数从json加载预先校准的模型,然后从hdf5文件加载其权重 .

def load():
    model = model_from_json(open(model_path).read())
    model.load_weights(model_weights_path)

此函数,更确切地说是对 load_weights 的调用导致以下异常:

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run()

我想知道这是否是由于我在模块的开头设置了tensorflow种子以获得再现性的这些行:

tf.set_random_seed(123) # To set Tensorflow seed
sess = tf.Session()
keras.backend.set_session(sess)

似乎keras会话不会自动将加载的模型设置为与会话关联的图形,因此无法初始化权重 .

避免异常的任何解释和解决方法?

2 回答

  • 0

    我几乎使用与你相同的代码,它对我有用 .

    from keras.models import Sequential
            from keras.models import Model
            from keras.layers import Dense, Dropout, Activation, Flatten, Input, GlobalAveragePooling2D
            from keras.optimizers import RMSprop
            from keras.utils import np_utils
            from keras.models import model_from_json
            from keras.layers import Convolution2D, MaxPooling2D
            from keras.layers.pooling import AveragePooling2D
            from keras.layers.normalization import BatchNormalization
            from keras.layers.convolutional import ZeroPadding2D
            from keras.engine.topology import Merge
            from keras.layers import merge
            from keras.optimizers import Adam
            from keras import backend as K
            from keras.layers.pooling import MaxPooling2D
            from keras.layers.convolutional import ZeroPadding2D
    
            import PIL
            import inception
            import tensorflow as tf
            import keras
            import glob
            import pandas as pd
            import pickle
            import numpy as np
            import matplotlib.pyplot as plt
            from PIL import Image
    
         # load json and create model
            json_file = open('model.json', 'r')
            loaded_model_json = json_file.read()
            json_file.close()
            model = model_from_json(loaded_model_json)
            # load weights into new model
            model.load_weights("model.h5")
            print("Loaded model from disk")
    
    model.summary()
    model.compile(Adam(lr=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
    
    score = model.predict(transfer_values_test)
    
  • 0

    事实上,在加载模型时,Keras似乎不尊重set_session设置的会话 .

    尝试强制Keras使用Tensorflow的上下文管理器的特定会话:

    def load():
        with sess.as_default():    
            model = model_from_json(open(model_path).read())
            model.load_weights(model_weights_path)''
    

    如果Keras仍然抱怨,请预定义图形( graph=tf.Graph() )并强制model.load_weights通过引入额外的 with 语句来使用它:

    def load():
        with graph.as_default():
            with sess.as_default():    
                model = model_from_json(open(model_path).read())
                model.load_weights(model_weights_path)''
    

相关问题