首页 文章

具有分类输入的回归树或随机森林回归量

提问于
浏览
11

我一直试图在回归树(或随机森林回归器)中使用分类的inpust,但sklearn不断返回错误并要求输入数字 .

import sklearn as sk
MODEL = sk.ensemble.RandomForestRegressor(n_estimators=100)
MODEL.fit([('a',1,2),('b',2,3),('a',3,2),('b',1,3)], [1,2.5,3,4]) # does not work
MODEL.fit([(1,1,2),(2,2,3),(1,3,2),(2,1,3)], [1,2.5,3,4]) #works

MODEL = sk.tree.DecisionTreeRegressor()
MODEL.fit([('a',1,2),('b',2,3),('a',3,2),('b',1,3)], [1,2.5,3,4]) # does not work
MODEL.fit([(1,1,2),(2,2,3),(1,3,2),(2,1,3)], [1,2.5,3,4]) #works

据我了解,这些方法中的分类输入应该是可能的,没有任何转换(例如WOE替代) .

有没有其他人有这个困难?

谢谢!

2 回答

  • 16

    scikit-learn 没有分类变量的专用表示(在R中是a.k.a因子),一种可能的解决方案是使用 LabelEncoder 将字符串编码为 int

    import numpy as np
    from sklearn.preprocessing import LabelEncoder  
    from sklearn.ensemble import RandomForestRegressor
    
    X = np.asarray([('a',1,2),('b',2,3),('a',3,2),('c',1,3)]) 
    y = np.asarray([1,2.5,3,4])
    
    # transform 1st column to numbers
    X[:, 0] = LabelEncoder().fit_transform(X[:,0]) 
    
    regressor = RandomForestRegressor(n_estimators=150, min_samples_split=2)
    regressor.fit(X, y)
    print(X)
    print(regressor.predict(X))
    

    输出:

    [[ 0.  1.  2.]
     [ 1.  2.  3.]
     [ 0.  3.  2.]
     [ 2.  1.  3.]]
    [ 1.61333333  2.13666667  2.53333333  2.95333333]
    

    但请记住,如果 ab 是独立的类别,这只是一个轻微的黑客,它只适用于基于树的估算器 . 为什么?因为 b 并不比 a 大 . 正确的方法是在 LabelEncoderpd.get_dummies 之后使用 OneHotEncoder ,为 X[:, 0] 产生两个单独的,一个热的编码列 .

    import numpy as np
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder
    from sklearn.ensemble import RandomForestRegressor
    
    X = np.asarray([('a',1,2),('b',2,3),('a',3,2),('c',1,3)]) 
    y = np.asarray([1,2.5,3,4])
    
    # transform 1st column to numbers
    import pandas as pd
    X_0 = pd.get_dummies(X[:, 0]).values
    X = np.column_stack([X_0, X[:, 1:]])
    
    regressor = RandomForestRegressor(n_estimators=150, min_samples_split=2)
    regressor.fit(X, y)
    print(X)
    print(regressor.predict(X))
    
  • 1

    你必须在python中手动编写代码 . 我建议使用pandas.get_dummies()进行一次热编码 . 对于Boosted树,我使用factorize()成功实现了Ordinal编码 .

    对于这种事情,还有一整套包装here .

    有关更详细的说明,请参阅this Data Science Stack Exchange帖子 .

相关问题