
Random forest with GridSearchCV - error with param_grid


I am trying to build a random forest model with GridSearchCV, but I am getting an error related to param_grid: "ValueError: Invalid parameter max_features for estimator Pipeline. Check the list of available parameters with `estimator.get_params().keys()`". I am classifying documents, so I also push the tf-idf vectorizer into the pipeline. Here is the code:

from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.grid_search import GridSearchCV  # sklearn.model_selection in newer versions
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, confusion_matrix
from sklearn.pipeline import Pipeline

# Classifier pipeline
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('classifier', RandomForestClassifier())
])
# Params for classifier
params = {"max_depth": [3, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [1, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              # "bootstrap": [True, False],
              "criterion": ["gini", "entropy"]}

# Grid Search Execute
rf_grid = GridSearchCV(estimator=pipeline, param_grid=params)  # cv=10
rf_detector = rf_grid.fit(X_train, Y_train)
print(rf_grid.grid_scores_)

I cannot figure out why the error is shown. The same thing happens, by the way, when I run a decision tree with GridSearchCV. (Scikit-learn 0.17)

2 Answers

  • 5

    You have to assign the parameters to the named step in the pipeline, in your case classifier. Try prepending classifier__ to the parameter names, for example:

    params = {"classifier__max_depth": [3, None],
                  "classifier__max_features": [1, 3, 10],
                  "classifier__min_samples_split": [1, 3, 10],
                  "classifier__min_samples_leaf": [1, 3, 10],
                  # "bootstrap": [True, False],
                  "classifier__criterion": ["gini", "entropy"]}
    
  • 14

    Try running get_params() on the final pipeline object, not just on the estimator. That way it generates all the unique keys of the pipeline items that are available as grid parameters.

    sorted(pipeline.get_params().keys())
    

    ['classifier', 'classifier__bootstrap', 'classifier__class_weight', 'classifier__criterion', 'classifier__max_depth', 'classifier__max_features', 'classifier__max_leaf_nodes', 'classifier__min_impurity_split', 'classifier__min_samples_leaf', 'classifier__min_samples_split', 'classifier__min_weight_fraction_leaf', 'classifier__n_estimators', 'classifier__n_jobs', 'classifier__oob_score', 'classifier__random_state', 'classifier__verbose', 'classifier__warm_start', 'steps', 'tfidf', 'tfidf__analyzer', 'tfidf__binary', 'tfidf__decode_error', 'tfidf__dtype', 'tfidf__encoding', 'tfidf__input', 'tfidf__lowercase', 'tfidf__max_df', 'tfidf__max_features', 'tfidf__min_df', 'tfidf__ngram_range', 'tfidf__norm', 'tfidf__preprocessor', 'tfidf__smooth_idf', 'tfidf__stop_words', 'tfidf__strip_accents', 'tfidf__sublinear_tf', 'tfidf__token_pattern', 'tfidf__tokenizer', 'tfidf__use_idf', 'tfidf__vocabulary']

    This is especially useful when you use the short make_pipeline() syntax for pipelines, where you don't label the pipeline items yourself:

    from sklearn.pipeline import make_pipeline

    pipeline = make_pipeline(TfidfVectorizer(), RandomForestClassifier())
    sorted(pipeline.get_params().keys())
    

    ['randomforestclassifier', 'randomforestclassifier__bootstrap', 'randomforestclassifier__class_weight', 'randomforestclassifier__criterion', 'randomforestclassifier__max_depth', 'randomforestclassifier__max_features', 'randomforestclassifier__max_leaf_nodes', 'randomforestclassifier__min_impurity_split', 'randomforestclassifier__min_samples_leaf', 'randomforestclassifier__min_samples_split', 'randomforestclassifier__min_weight_fraction_leaf', 'randomforestclassifier__n_estimators', 'randomforestclassifier__n_jobs', 'randomforestclassifier__oob_score', 'randomforestclassifier__random_state', 'randomforestclassifier__verbose', 'randomforestclassifier__warm_start', 'steps', 'tfidfvectorizer', 'tfidfvectorizer__analyzer', 'tfidfvectorizer__binary', 'tfidfvectorizer__decode_error', 'tfidfvectorizer__dtype', 'tfidfvectorizer__encoding', 'tfidfvectorizer__input', 'tfidfvectorizer__lowercase', 'tfidfvectorizer__max_df', 'tfidfvectorizer__max_features', 'tfidfvectorizer__min_df', 'tfidfvectorizer__ngram_range', 'tfidfvectorizer__norm', 'tfidfvectorizer__preprocessor', 'tfidfvectorizer__smooth_idf', 'tfidfvectorizer__stop_words', 'tfidfvectorizer__strip_accents', 'tfidfvectorizer__sublinear_tf', 'tfidfvectorizer__token_pattern', 'tfidfvectorizer__tokenizer', 'tfidfvectorizer__use_idf', 'tfidfvectorizer__vocabulary']
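
    With make_pipeline the grid keys simply use those auto-generated, lowercased class-name prefixes instead; a minimal sketch under the same assumptions (docs and labels stand in for your training data):

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.pipeline import make_pipeline
    from sklearn.grid_search import GridSearchCV  # sklearn.model_selection in >= 0.18

    pipeline = make_pipeline(TfidfVectorizer(), RandomForestClassifier())

    # Step names default to the lowercased class names
    params = {"randomforestclassifier__max_depth": [3, None],
              "tfidfvectorizer__ngram_range": [(1, 1), (1, 2)]}

    grid = GridSearchCV(estimator=pipeline, param_grid=params)
    grid.fit(docs, labels)  # docs: raw text documents, labels: target classes (assumed names)
    print(grid.best_params_)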
