首页 文章

来自插入符号中训练数据的ROC曲线

提问于
浏览
19

使用R包插入符号,如何根据train()函数的交叉验证结果生成ROC曲线?

说,我做以下事情:

data(Sonar)
ctrl <- trainControl(method="cv", 
  summaryFunction=twoClassSummary, 
  classProbs=T)
rfFit <- train(Class ~ ., data=Sonar, 
  method="rf", preProc=c("center", "scale"), 
  trControl=ctrl)

训练函数超过一系列mtry参数并计算ROC AUC . 我想看看相关的ROC曲线 - 我该怎么做?

注意:如果用于采样的方法是LOOCV,那么 rfFit 将在 rfFit$pred 槽中包含一个非空数据帧,这似乎正是我所需要的 . 但是,我需要"cv"方法(k倍验证)而不是LOO .

另外:不,之前版本的插入符号中包含的 roc 函数不是答案 - 这是一个低级函数,您可以为每个交叉验证的样本提供预测概率 .

2 回答

  • 27

    ctrl 中只缺少 savePredictions = TRUE 参数(这也适用于其他重采样方法):

    library(caret)
    library(mlbench)
    data(Sonar)
    ctrl <- trainControl(method="cv", 
                         summaryFunction=twoClassSummary, 
                         classProbs=T,
                         savePredictions = T)
    rfFit <- train(Class ~ ., data=Sonar, 
                   method="rf", preProc=c("center", "scale"), 
                   trControl=ctrl)
    library(pROC)
    # Select a parameter setting
    selectedIndices <- rfFit$pred$mtry == 2
    # Plot:
    plot.roc(rfFit$pred$obs[selectedIndices],
             rfFit$pred$M[selectedIndices])
    

    ROC

    也许我错过了一些东西,但是一个小问题是 train 总是估计AUC值略微不同于 plot.rocpROC::auc (绝对差<0.005),尽管 twoClassSummary 使用 pROC::auc 来估算AUC . Edit: 我认为这是因为 train 中的ROC是使用单独的CV集的AUC的平均值,这里我们同时计算所有重采样的AUC以获得总AUC .

    Update 由于这引起了一些关注,这是一个使用 plotROC::geom_roc() for ggplot2 的解决方案:

    library(ggplot2)
    library(plotROC)
    ggplot(rfFit$pred[selectedIndices, ], 
           aes(m = M, d = factor(obs, levels = c("R", "M")))) + 
        geom_roc(hjust = -0.4, vjust = 1.5) + coord_equal()
    

    ggplot_roc

  • 10

    在这里,我正在修改@ thei1e的情节,其他人可能会觉得有帮助 .

    训练模型并做出预测

    library(caret)
    library(ggplot2)
    library(mlbench)
    library(plotROC)
    
    data(Sonar)
    
    ctrl <- trainControl(method="cv", summaryFunction=twoClassSummary, classProbs=T,
                         savePredictions = T)
    
    rfFit <- train(Class ~ ., data=Sonar, method="rf", preProc=c("center", "scale"), 
                   trControl=ctrl)
    
    # Select a parameter setting
    selectedIndices <- rfFit$pred$mtry == 2
    

    更新了ROC曲线图

    g <- ggplot(rfFit$pred[selectedIndices, ], aes(m=M, d=factor(obs, levels = c("R", "M")))) + 
      geom_roc(n.cuts=0) + 
      coord_equal() +
      style_roc()
    
    g + annotate("text", x=0.75, y=0.25, label=paste("AUC =", round((calc_auc(g))$AUC, 4)))
    

    enter image description here

相关问题