
Accessing the indices of each CV fold in a custom metric function in caret


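I want to define a custom metric function in caret, but inside this function I want to use additional information that is not used for training. I therefore need the indices (row numbers) of the data used for validation in the current fold.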

Here is a silly example:

Generate the data:

library(caret)
set.seed(1234)

x <- matrix(rnorm(10),nrow=5,ncol=2 )
y <- factor(c("y","n","y","y","n"))

priors <- c(1,3,2,7,9)

This is my example metric function, which should use the information from the priors vector:

my.metric <- function (data,
                   lev = NULL,
                   model = NULL) {
          out <- priors[-->INDICES.OF.DATA<--] + data$pred/data$obs   
          names(out) <- "MYMEASURE"
          out
}

myControl <- trainControl(summaryFunction = my.metric, method="repeatedcv", number=10, repeats=2)

fit <- train(y=y,x=x, metric = "MYMEASURE",method="gbm", trControl = myControl)

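To make this a bit clearer: I could use this in a survival setting, where priors holds days, and use it in a Surv object to measure survival AUC in the metric function.

How can I do this in caret?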

1 Answer

  • 1

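    You can access the row numbers using data$rowIndex. Note that the summary function should return a single number as its metric (e.g. ROC, Accuracy, RMSE, ...). The function above seems to return a vector whose length equals the number of observations in the held-out CV data.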
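    As a minimal sketch of the two points above, the function below looks up the side information for the held-out rows via data$rowIndex and collapses the result into one named number. The prior-weighted accuracy and the mock `fold` data frame are assumptions made for illustration only; they are not part of the question, and the mock merely imitates the obs, pred and rowIndex columns that caret passes to a summary function.

```r
# Side information, one value per row of the full training set
# (same values as the priors vector in the question).
priors <- c(1, 3, 2, 7, 9)

# Sketch of a summary function: index priors by the held-out rows
# and return a single named number.
my.metric <- function(data, lev = NULL, model = NULL) {
  fold_priors <- priors[data$rowIndex]            # side information for this fold only
  correct <- as.numeric(data$pred == data$obs)    # 1 where the prediction is right
  out <- weighted.mean(correct, w = fold_priors)  # illustrative prior-weighted accuracy
  names(out) <- "MYMEASURE"
  out
}

# Hand-made stand-in for a fold that holds out rows 2 and 5.
fold <- data.frame(
  obs = factor(c("n", "n"), levels = c("n", "y")),
  pred = factor(c("n", "y"), levels = c("n", "y")),
  rowIndex = c(2L, 5L)
)

result <- my.metric(fold)
print(result)
```

    Calling the function directly on a hand-made data frame like this is a quick way to check that it returns a length-one named vector before handing it to trainControl.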

    If you are interested in seeing the resamples along with their predictions, you can add print(data) to the my.metric function.
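    A lighter alternative to printing the whole data frame is to report only which rows are held out and which columns are available. The wrapper below is a hypothetical sketch: the name `inspect.metric`, the plain accuracy it returns and the mock `fold` data frame are assumptions for illustration.

```r
# Hypothetical debugging summary function: show the held-out rows and the
# available columns, then return plain accuracy as a stand-in metric.
inspect.metric <- function(data, lev = NULL, model = NULL) {
  cat("held-out rows:", sort(data$rowIndex), "\n")
  cat("columns:", paste(names(data), collapse = ", "), "\n")
  c(MYMEASURE = mean(data$pred == data$obs))
}

# Hand-made stand-in for the data frame of one resample.
fold <- data.frame(
  obs = factor(c("y", "n", "y", "n"), levels = c("n", "y")),
  pred = factor(c("y", "y", "n", "n"), levels = c("n", "y")),
  rowIndex = c(4L, 1L, 3L, 2L)
)

inspected <- inspect.metric(fold)
```

    Passed as summaryFunction, a function like this would print once per resample and tuning combination, so it is best used only while debugging.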

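    Here is an example using your data (enlarged a bit) and Metrics::auc as the performance measure, computed after multiplying the predicted class probability by the prior: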

    library(caret)
    library(Metrics)
    
    set.seed(1234)
    x <- matrix(rnorm(100), nrow=100, ncol=2 )
    set.seed(1234)
    y <- factor(sample(x = c("y", "n"), size = 100, replace = TRUE))
    
    priors <- runif(n = length(y), min = 0.1, max = 0.9)
    
    my.metric <- function(data, lev = NULL, model = NULL) 
    {
        # The performance metric should be a single number
        # data$y are the predicted probabilities of  
        # the observations in the fold belonging to class "y"
        out <- Metrics::auc(actual = as.numeric(data$obs == "y"),
                            predicted = priors[data$rowIndex] * data$y)
        names(out) <- "MYMEASURE"
        out
    }
    
    fitControl <- trainControl(method = "repeatedcv",
                               number = 10,
                               classProbs = TRUE,
                               repeats = 2,
                               summaryFunction = my.metric)
    
    set.seed(1234)
    fit <- train(y = y, 
                 x = x,
                 metric = "MYMEASURE",
                 method="gbm", 
                 verbose = FALSE,
                 trControl = fitControl)
    fit
    
    # Stochastic Gradient Boosting 
    # 
    # 100 samples
    # 2 predictor
    # 2 classes: 'n', 'y' 
    # 
    # No pre-processing
    # Resampling: Cross-Validated (10 fold, repeated 2 times) 
    # 
    # Summary of sample sizes: 90, 90, 90, 90, 90, 89, ... 
    # 
    # Resampling results across tuning parameters:
    #     
    # interaction.depth  n.trees  MYMEASURE  MYMEASURE SD
    # 1                   50      0.5551667  0.2348496   
    # 1                  100      0.5682500  0.2297383   
    # 1                  150      0.5797500  0.2274042   
    # 2                   50      0.5789167  0.2246845   
    # 2                  100      0.5941667  0.2053826   
    # 2                  150      0.5900833  0.2186712   
    # 3                   50      0.5750833  0.2291999   
    # 3                  100      0.5488333  0.2312470   
    # 3                  150      0.5577500  0.2202638   
    # 
    # Tuning parameter 'shrinkage' was held constant at a value of 0.1
    # Tuning parameter 'n.minobsinnode' was held constant at a value of 10
    # MYMEASURE was used to select the optimal model using  the largest value.
    

    I don't know much about survival analysis, but I hope this helps.
