首页 文章

Keras GridSearch scikit学习冻结

提问于
浏览
0

我很难使用scikit learn在Keras中实现网格搜索 . 基于this tutorial,我编写了以下代码:

from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV

    def create_model():
            model = Sequential()
            model.add(Dense(100, input_shape=(max_len, len(alphabet)), kernel_regularizer=regularizers.l2(0.001)))
            model.add(Dropout(0.85))
            model.add(LSTM(100, input_shape=(100,))) 
            model.add(Dropout(0.85))
            model.add(Dense(num_output_classes, activation='softmax'))

            adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=1e-6)

            model.compile(loss='categorical_crossentropy',
                      optimizer=adam,
                      metrics=['accuracy']) 

            return model

    seed = 7
    np.random.seed(seed)

    model = KerasClassifier(build_fn=create_model, epochs=10, verbose=0)

    batch_size = [10,20]
    param_grid = dict(batch_size=batch_size)
    grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
    grid_result = grid.fit(train_data_reduced, train_labels_reduced)

    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
    means = grid_result.cv_results_['mean_test_score']
    stds = grid_result.cv_results_['std_test_score']
    params = grid_result.cv_results_['params']
    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))

它没有给我任何错误消息,但它只是永远运行而不打印任何东西 . 我故意用很少的时代,非常少的训练样例和很少的超参数进行搜索 . 没有网格搜索,一个时代非常快,所以我不认为我只需要给它更多的时间 . 它根本不做任何事情 .

任何人都可以指出我错过了什么?

非常感谢!

1 回答

  • 0

    我有同样的问题 .

    从参数列表中删除 n_jobs=-1 可能会有所帮助!也尝试不做热编码 .

相关问题