首页 文章

keras:返回model.summary()vs scikit学习包装器

提问于
浏览
0

在使用keras时,我了解到使用包装器会对keras产生负面影响,并且scikit会学习api请求 . 我对两者都有解决方案感兴趣 .

Variant 1: scikit Wrapper

from keras.wrappers.scikit_learn import KerasClassifier

    def model():
        model = Sequential()
        model.add(Dense(10, input_dim=4, activation='relu'))
        model.add(Dense(3, activation='softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

estimator = KerasClassifier(build_fn=model, epochs=100, batch_size=5)
model.fit(X, y)

-> 这让我可以打印scikit命令,例如accuracy_score()或classification_report() . 但是,model.summary()不起作用:

AttributeError:'KerasClassifier'对象没有属性'summary'

Variant 2: No Wrapper

model = Sequential()
model.add(Dense(10, input_dim=4, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, y, epochs=100, batch_size=5)

-> 这让我打印model.summary()但不打印scikit命令 .

ValueError:混合类型的y不允许,得到类型{'multiclass','multilabel-indicator'}

有没有办法可以同时使用两者?

2 回答

  • 2

    KerasClassifier 只是 keras 中实际 Model 的包装器,因此keras api的实际方法可以路由到scikit中使用的方法,因此它可以与scikit实用程序一起使用 . 但在内部它只使用可以使用 estimator.model 访问的模型 .

    说明以上内容的示例:

    from keras.models import Sequential
    from keras.layers import Dense
    from keras.wrappers.scikit_learn import KerasClassifier
    from sklearn.datasets import make_classification
    def model():
        model = Sequential()
        model.add(Dense(10, input_dim=20, activation='relu'))
        model.add(Dense(2, activation='softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model
    
    estimator = KerasClassifier(build_fn=model, epochs=100, batch_size=5)
    X, y = make_classification()
    estimator.fit(X, y)
    
    # This is what you need
    estimator.model.summary()
    

    这个输出是:

    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_9 (Dense)              (None, 10)                210       
    _________________________________________________________________
    dense_10 (Dense)             (None, 2)                 22        
    =================================================================
    Total params: 232
    Trainable params: 232
    Non-trainable params: 0
    _________________________________________________________________
    
  • 0

    摘要的功能在此库中: from keras. models import Model 您可以看到:
    for example

    Figure

相关问题