首页 文章

使用scikit-learn并行生成随机森林

提问于
浏览
9

主要问题:如何在python和scikit-learn中组合不同的randomForests?

我目前正在使用R中的randomForest包来使用弹性贴图reduce生成randomforest对象 . 这是为了解决分类问题 .

由于我的输入数据太大而无法放入一台机器的内存中,因此我将数据采样为较小的数据集并生成包含较小树集的随机林对象 . 然后,我使用修改的组合函数将不同的树组合在一起,以创建新的随机森林对象 . 此随机林对象包含特征重要性和最终树集 . 这不包括oob错误或树的投票 .

虽然这在R中运行良好,但我想在Python中使用scikit-learn做同样的事情 . 我可以创建不同的随机森林对象,但我没有办法将它们组合在一起形成一个新对象 . 任何人都可以指出一个可以结合森林的功能吗?这可能是使用scikit-learn吗?

以下是关于如何在R中进行此过程的问题的链接:Combining random forests built with different training sets in R .

编辑:生成的随机森林对象应包含可用于预测的树以及特征重要性 .

任何帮助,将不胜感激 .

2 回答

  • 8

    当然,只是聚合所有的树,例如从pyrallel看这个片段:

    def combine(all_ensembles):
        """Combine the sub-estimators of a group of ensembles
    
            >>> from sklearn.datasets import load_iris
            >>> from sklearn.ensemble import ExtraTreesClassifier
            >>> iris = load_iris()
            >>> X, y = iris.data, iris.target
    
            >>> all_ensembles = [ExtraTreesClassifier(n_estimators=4).fit(X, y)
            ...                  for i in range(3)]
            >>> big = combine(all_ensembles)
            >>> len(big.estimators_)
            12
            >>> big.n_estimators
            12
            >>> big.score(X, y)
            1.0
    
        """
        final_ensemble = copy(all_ensembles[0])
        final_ensemble.estimators_ = []
    
        for ensemble in all_ensembles:
            final_ensemble.estimators_ += ensemble.estimators_
    
        # Required in old versions of sklearn
        final_ensemble.n_estimators = len(final_ensemble.estimators_)
    
        return final_ensemble
    
  • 2

    根据您的编辑,听起来您只询问如何提取特征重要性并查看随机森林中使用的各个树 . 如果是这样,这两个都是您的随机森林模型的属性,分别名为“feature_importances_”和“estimators_” . 可以在下面找到说明这一点的示例:

    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import make_blobs
    >>> X, y = make_blobs(n_samples=10000, n_features=10, centers=100,random_state=0)
    >>> clf = RandomForestClassifier(n_estimators=5, max_depth=None, min_samples_split=1, random_state=0)
    >>> clf.fit(X,y)
    RandomForestClassifier(bootstrap=True, compute_importances=None,
                criterion='gini', max_depth=None, max_features='auto',
                min_density=None, min_samples_leaf=1, min_samples_split=1,
                n_estimators=5, n_jobs=1, oob_score=False, random_state=0,
                verbose=0)
    >>> clf.feature_importances_
    array([ 0.09396245,  0.07052027,  0.09951226,  0.09095071,  0.08926362,
            0.112209  ,  0.09137607,  0.11771107,  0.11297425,  0.1215203 ])
    >>> clf.estimators_
    [DecisionTreeClassifier(compute_importances=None, criterion='gini',
                max_depth=None, max_features='auto', min_density=None,
                min_samples_leaf=1, min_samples_split=1,
                random_state=<mtrand.RandomState object at 0x2b6f62d9b408>,
                splitter='best'), DecisionTreeClassifier(compute_importances=None, criterion='gini',
                max_depth=None, max_features='auto', min_density=None,
                min_samples_leaf=1, min_samples_split=1,
                random_state=<mtrand.RandomState object at 0x2b6f62d9b3f0>,
                splitter='best'), DecisionTreeClassifier(compute_importances=None, criterion='gini',
                max_depth=None, max_features='auto', min_density=None,
                min_samples_leaf=1, min_samples_split=1,
                random_state=<mtrand.RandomState object at 0x2b6f62d9b420>,
                splitter='best'), DecisionTreeClassifier(compute_importances=None, criterion='gini',
                max_depth=None, max_features='auto', min_density=None,
                min_samples_leaf=1, min_samples_split=1,
                random_state=<mtrand.RandomState object at 0x2b6f62d9b438>,
                splitter='best'), DecisionTreeClassifier(compute_importances=None, criterion='gini',
                max_depth=None, max_features='auto', min_density=None,
                min_samples_leaf=1, min_samples_split=1,
                random_state=<mtrand.RandomState object at 0x2b6f62d9b450>,
                splitter='best')]
    

相关问题