首页 文章

具有tensorflow的keras运行正常,直到我添加回调

提问于
浏览
3

我正在使用Keras和TensorFlow后端运行模型 . 一切都很完美:

model = Sequential()
model.add(Dense(dim, input_dim=dim, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(1, activation='linear'))

model.compile(loss='mse', optimizer='Adam', metrics=['mae'])

history = model.fit(X, Y, epochs=12, 
                    batch_size=100, 
                    validation_split=0.2, 
                    shuffle=True, 
                    verbose=2)

但是一旦我包含 Logger 和回调,所以我可以登录到tensorboard,我得到了

InvalidArgumentError(请参见上面的回溯):您必须为占位符张量'input_layer_input_2'提供一个值,其中dtype为float和shape [?,1329] ...

这是我的代码:(实际上,它工作了一次,第一次,然后ecer从那以后得到了这个错误)

model = Sequential()
model.add(Dense(dim, input_dim=dim, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(1, activation='linear'))

model.compile(loss='mse', optimizer='Adam', metrics=['mae'])

logger = keras.callbacks.TensorBoard(log_dir='/tf_logs',
                                     write_graph=True,
                                     histogram_freq=1)

history = model.fit(X, Y, 
                    epochs=12,
                    batch_size=100,
                    validation_split=0.2,
                    shuffle=True,
                    verbose=2,
                    callbacks=[logger])

1 回答

  • 2

    tensorboard 回调使用 tf.summary.merge_all 函数来收集直方图计算的所有张量 . 因此 - 您的摘要是从以前的模型中收集张量,而不是从以前的模型运行中清除 . 为了清除这些以前的型号,请尝试:

    from keras import backend as K
    
    K.clear_session()
    
    model = Sequential()
    model.add(Dense(dim, input_dim=dim, activation='relu'))
    model.add(Dense(200, activation='relu'))
    model.add(Dense(1, activation='linear'))
    
    model.compile(loss='mse', optimizer='Adam', metrics=['mae'])
    
    logger = keras.callbacks.TensorBoard(log_dir='/tf_logs',
                                     write_graph=True,
                                     histogram_freq=1)
    
    history = model.fit(X, Y, 
                    epochs=12,
                    batch_size=100,
                    validation_split=0.2,
                    shuffle=True,
                    verbose=2,
                    callbacks=[logger])
    

相关问题