
predict() with sklearn LogisticRegression and RandomForest models always predicts the minority class (1)


I am building a logistic regression model to predict whether a transaction is valid (1) or not (0), with a dataset of only 150 observations. My data is split between the two classes as follows:

  • 106 observations are 0 (invalid)

  • 44 observations are 1 (valid)

I am using two predictors (both numeric). Even though the data is mostly 0s, my classifier predicts 1 for every transaction in my test set, even though most of them should be 0. The classifier never outputs 0 for any observation.

Here is my entire code:

# Logistic Regression
import numpy as np
import pandas as pd
from pandas import Series, DataFrame

import scipy
from scipy.stats import spearmanr
from pylab import rcParams
import seaborn as sb
import matplotlib.pyplot as plt
import sklearn
from sklearn.preprocessing import scale
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn import preprocessing

address = "dummy_csv-150.csv"
trades = pd.read_csv(address)
trades.columns=['location','app','el','rp','rule1','rule2','rule3','validity','transactions']
trades.head()

# select the 'app' and 'transactions' columns as predictors
# (note: .ix is deprecated/removed in pandas; use .iloc instead)
trade_data = trades.iloc[:, [1, 8]].values
trade_data_names = ['app', 'transactions']

# set dependent/response variable
y = trades.iloc[:, 7].values

# center around the data mean
X= scale(trade_data)

LogReg = LogisticRegression()

LogReg.fit(X,y)
print(LogReg.score(X,y))

y_pred = LogReg.predict(X)

from sklearn.metrics import classification_report

print(classification_report(y,y_pred)) 

log_prediction = LogReg.predict_log_proba(
    [
       [2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]
    ])
prediction = LogReg.predict([[2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]])
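For reference, `scale` standardizes each column to zero mean and unit variance. A minimal sketch with made-up numbers (independent of the CSV above, which isn't available) showing what it does to the features:

```python
import numpy as np
from sklearn.preprocessing import scale

# Hypothetical two-feature data, shaped like trade_data (app, transactions)
data = np.array([[1.0, 345.0],
                 [1.0, 222.0],
                 [2.0, 500.0],
                 [3.0, 120.0]])

scaled = scale(data)          # (x - column mean) / column std
print(scaled.mean(axis=0))    # each column mean is ~0
print(scaled.std(axis=0))     # each column std is ~1
```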

My model is defined as:

LogReg = LogisticRegression()  
LogReg.fit(X,y)

where X looks like this:

X = array([[1, 345],
       [1, 222],
       [1, 500],
       [2, 120]]....)

Y is just 0 or 1 for each observation.

The normalized X passed to the model looks like this:

[[-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]]

and Y is:

[0 0 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0
 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0
 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0
 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0]

The model metrics are:

precision    recall  f1-score   support

          0       0.78      1.00      0.88        98
          1       1.00      0.43      0.60        49

avg / total       0.85      0.81      0.78       147

The score is 0.80.

When I run model.predict_log_proba(test_data), I get log-probabilities that look like this:

array([[ -1.10164032e+01,  -1.64301095e-05],
       [ -2.06326947e+00,  -1.35863187e-01],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00]])
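As an aside, `predict_log_proba` returns natural-log probabilities, so `-inf` and `0.0` correspond to probabilities of exactly 0 and 1: for those rows the model is completely certain of class 1. A quick check in pure NumPy (the first two rows above, not tied to the fitted model):

```python
import numpy as np

log_probs = np.array([[-np.inf, 0.0],
                      [-1.10164032e+01, -1.64301095e-05]])
probs = np.exp(log_probs)   # convert log-probabilities back to probabilities
print(probs)                # first row is [0., 1.]: certainty of class 1
```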

My test set is below; all but 2 of the observations should be 0, yet every one of them is classified as 1. This happens with every test set, even ones containing values the model was trained on:

[2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]

I found a similar question here: https://stats.stackexchange.com/questions/168929/logistic-regression-is-predicting-all-1-and-no-0, but in that question the data was mostly 1's, so it made sense that the model would output 1s. My case is the opposite: the training data is mostly 0's, yet for some reason my model always outputs 1's for everything, even though 1's are relatively rare. I also tried a random forest classifier to rule out the model, but the same thing happened. Maybe it's my data, but I don't see how, since it meets all the assumptions.

What could be wrong? The data satisfies all the assumptions of a logistic model (the two predictors are independent, the output is binary, there are no missing data points). Any suggestions are appreciated.

1 Answer

You did not scale your test data. Scaling the training data here is correct:

    X= scale(trade_data)
    

But after training the model, you do not do the same to the test data:

    log_prediction = LogReg.predict_log_proba(
    [
       [2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]
    ])
    

The model's coefficients were fit on normalized inputs, but your test data is not normalized. Any positive coefficient in the model gets multiplied by a huge raw value, which is likely what pushes all of your predictions to 1.

The general rule is that any transformation you apply to the training set must also be applied to the test set, using the parameters learned from the training set. Instead of:

    X = scale(trade_data)
    

you should create a scaler from the training data, like this (note the original answer had a typo, `trade_date`):

    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler().fit(trade_data)
    X = scaler.transform(trade_data)
    

and then apply that same scaler to the test data:

    scaled_test = scaler.transform(test_x)
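Putting the answer together, a runnable end-to-end sketch. The arrays here are small placeholders standing in for the CSV, which isn't available; only the fit-scaler-on-train / reuse-on-test pattern is the point:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Placeholder training data shaped like trade_data (app, transactions)
trade_data = np.array([[1, 345], [1, 222], [1, 500], [2, 120],
                       [2, 553], [3, 85], [3, 91], [1, 2120]], dtype=float)
y = np.array([0, 0, 1, 0, 1, 0, 0, 1])

# Fit the scaler on the training data only, then reuse it everywhere
scaler = StandardScaler().fit(trade_data)
X = scaler.transform(trade_data)

LogReg = LogisticRegression()
LogReg.fit(X, y)

# Test data must pass through the SAME scaler before predicting
test_x = np.array([[2, 14], [3, 1], [1, 503]], dtype=float)
scaled_test = scaler.transform(test_x)
prediction = LogReg.predict(scaled_test)
print(prediction)
```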
    
