首页 文章

如何在r中的rpart()中关闭k折叠交叉验证

提问于
浏览
0

我有比特币时间序列,我使用11个技术指标作为功能,我想为数据拟合回归树 . 据我所知,r中有2个函数可以创建回归树,即rpart()和tree(),但这两个函数似乎都不合适 . rpart()使用k-fold交叉验证来验证最优成本复杂度参数cp,并且在tree()中,不可能指定cp的值 .

我知道cv.tree()通过交叉验证查找cp的最佳值,但cv.tee()再次使用k-fold交叉验证 . 由于我有时间序列,因此有时间依赖性,我不想使用k折交叉验证,因为k折交叉验证会将数据随机分成k倍,在k-1折叠上拟合模型并计算左边的MSE在第k折,然后我的时间序列的序列明显被破坏了 .

我找到了一个rpart()函数的参数,即xval,它应该让我指定交叉验证的数量,但是当我在xval = 0时查看rpart()函数调用的输出时,它不会似乎交叉验证已关闭 . 下面你可以看到我的函数调用和输出:

tree.model= rpart(Close_5~ M+ DSMA+ DWMA+ DEMA+ CCI+ RSI+ DKD+ R+ FI+ DVI+ 
OBV, data= train.subset, method= "anova", control= 
rpart.control(cp=0.01,xval= 0, minbucket = 5))

> summary(tree.model)
Call:
rpart(formula = Close_5 ~ M + DSMA + DWMA + DEMA + CCI + RSI + 
DKD + R + FI + DVI + OBV, data = train.subset, method = "anova", 
control = rpart.control(cp = 0.01, xval = 0, minbucket = 5))
n= 590 

           CP nsplit rel error
1  0.35433076      0 1.0000000
2  0.10981049      1 0.6456692
3  0.06070669      2 0.5358587
4  0.04154720      3 0.4751521
5  0.02415633      5 0.3920576
6  0.02265346      6 0.3679013
7  0.02139752      8 0.3225944
8  0.02096500      9 0.3011969
9  0.02086543     10 0.2802319
10 0.01675277     11 0.2593665
11 0.01551861     13 0.2258609
12 0.01388126     14 0.2103423
13 0.01161287     15 0.1964610
14 0.01127722     16 0.1848482
15 0.01000000     18 0.1622937

似乎rpart()交叉验证了15个不同的cp值 . 如果用k折交叉验证测试这些值,那么我的时间序列的序列将被破坏,我基本上不能使用这些结果 . 有谁知道如何有效地关闭rpart()中的交叉验证,或者如何在tree()中改变cp的值?

更新:我按照我们的一位同事的建议设置了xval = 1,但这似乎没有解决问题 . 当xval = 1 here时,您可以看到完整的函数输出 . 顺便说一句,参数[j]是参数向量的第j个元素 . 当我调用此函数时,参数[j] = 0.0009765625

提前谢谢了

1 回答

  • 1

    为了证明 rpart() 正在通过迭代 cp 的下降值与重新采样来创建树节点,我们将使用 mlbench 包中的 Ozone 数据来比较 rpart()caret::train() 的结果,如对OP的注释中所讨论的那样 . 我们将设置臭氧数据,如Support Vector Machines的CRAN文档中所示,它支持非线性回归并且与 rpart() 相当 .

    library(rpart)
    library(caret)
    data(Ozone, package = "mlbench")
    # split into test and training
    index <- 1:nrow(Ozone)
    set.seed(01381708)
    testIndex <- sample(index, trunc(length(index) / 3))
    testset <- na.omit(Ozone[testIndex,-3])
    trainset <- na.omit(Ozone[-testIndex,-3])
    
    
    # rpart version
    set.seed(95014) #reset seed to ensure sample is same as caret version
    rpart.model <- rpart(V4 ~ .,data = trainset,xval=0)
    # summary(rpart.model)
    # calculate RMSE
    rpart.pred <- predict(rpart.model, testset[,-3])
    crossprod(rpart.pred - testset[,3]) / length(testIndex)
    

    ...以及RMSE计算的输出:

    > crossprod(rpart.pred - testset[,3]) / length(testIndex)
             [,1]
    [1,] 18.25507
    

    接下来,我们将根据对OP的评论中提出的 caret::train() 运行相同的分析 .

    # caret version
    set.seed(95014)
    rpart.model <- caret::train(x = trainset[,-3],
                                y = trainset[,3],method = "rpart", trControl = trainControl(method = "none"), 
                                metric = "RMSE", tuneGrid = data.frame(cp=0.01), 
                                preProcess = c("center", "scale"), xval = 0, minbucket = 5)
    # summary(rpart.model)
    # demonstrate caret version did not do resampling
    rpart.model
    # calculate RMSE, which matches RMSE from rpart() 
    rpart.pred <- predict(rpart.model, testset[,-3])
    crossprod(rpart.pred - testset[,3]) / length(testIndex)
    

    当我们从 caret::train() 打印模型输出时,它清楚地指出没有重新采样 .

    > rpart.model
    CART 
    
    135 samples
     11 predictor
    
    Pre-processing: centered (9), scaled (9), ignore (2) 
    Resampling: None
    

    caret::train() 版本的RMSE与 rpart() 的RMSE匹配 .

    > # calculate RMSE, which matches RMSE from rpart() 
    > rpart.pred <- predict(rpart.model, testset[,-3])
    > crossprod(rpart.pred - testset[,3]) / length(testIndex)
             [,1]
    [1,] 18.25507
    >
    

    结论

    首先,如上所述, caret::train()rpart() 都不进行重采样 . 但是,如果打印出模型输出,则会看到多个 cp 值用于通过两种技术生成47个节点的最终树 .

    _002_来自插入符摘要的输出(rpart.model)

    CP nsplit rel error
    1 0.58951537      0 1.0000000
    2 0.08544094      1 0.4104846
    3 0.05237152      2 0.3250437
    4 0.04686890      3 0.2726722
    5 0.03603843      4 0.2258033
    6 0.02651451      5 0.1897648
    7 0.02194866      6 0.1632503
    8 0.01000000      7 0.1413017
    

    rpart摘要输出(rpart.model)

    CP nsplit rel error
    1 0.58951537      0 1.0000000
    2 0.08544094      1 0.4104846
    3 0.05237152      2 0.3250437
    4 0.04686890      3 0.2726722
    5 0.03603843      4 0.2258033
    6 0.02651451      5 0.1897648
    7 0.02194866      6 0.1632503
    8 0.01000000      7 0.1413017
    

    其次,两个模型通过包含 monthday 变量作为自变量来解释时间值 . 在 Ozone 数据集中, V1 是月份变量, V2 是日期变量 . 所有数据都是在1976年期间收集的,因此数据集中没有包含年份变量,并且在原始分析中,在分析之前,星期几被丢弃了.638908_小插图 .

    第三,考虑到使用 rpart()svm() 等算法的其他基于时间的效果,当日期属性未用作模型中的要素时,必须将滞后效应作为模型中的要素包含,因为这些算法不直接考虑时间组件 . 如何使用一系列滞后值对回归树集合进行此操作的一个示例是Ensemble Regression Trees for Time Series Predictions .

相关问题