首页 文章

如何从party ::: ctree模型中删除训练数据?

提问于
浏览
7

我创建了几个ctree模型(大约40到80),我想要经常评估 .

一个问题是模型对象非常大(40个模型需要超过2.8G的内存),在我看来,他们存储了训练数据,可能是modelname @ data和modelname @ response,而不仅仅是相关的信息预测新数据 .

大多数其他R学习包具有可配置选项,是否将数据包含在模型对象中,但我在文档中找不到任何提示 . 我还尝试通过分配空的ModelEnv对象

modelname@data <- new("ModelEnv")

但是对各个RData文件的大小没有影响 .

任何人都知道ctree是否真的存储了训练数据以及如何从ctree模型中删除与新预测无关的所有数据,以便我可以将其中的许多数据放入内存中?

非常感谢,

斯特凡


感谢您的反馈,这已经非常有帮助 .

我使用了 dputstr 来深入了解对象,发现模型中没有包含任何训练数据,但是有一个 responses 插槽,它似乎有训练标签和rownames . 无论如何,我注意到每个节点都有一个每个训练样本的权重向量 . 经过一段时间检查代码后,我最后搜索了一下,并在 party NEWS日志中找到以下注释:

CHANGES IN party VERSION 0.9-13 (2007-07-23)

o   update `mvt.f'

o   improve the memory footprint of RandomForest objects
    substancially (by removing the weights slots from each node).

事实证明,派对包中有一个C函数可以删除这些称为 R_remove_weights 的权重,其定义如下:

SEXP R_remove_weights(SEXP subtree, SEXP removestats) {
    C_remove_weights(subtree, LOGICAL(removestats)[0]);
    return(R_NilValue);
}

它也工作正常:

# cc is my model object

sum(unlist(lapply(slotNames(cc), function (x)  object.size(slot(cc, x)))))
# returns: [1] 2521256
save(cc, file="cc_before.RData")

.Call("R_remove_weights", cc@tree, TRUE, PACKAGE="party")
# returns NULL and removes weights and node statistics

sum(unlist(lapply(slotNames(cc), function (x)  object.size(slot(cc, x)))))
# returns: [1] 1521392
save(cc, file="cc_after.RData")

正如您所看到的,它大大减小了对象大小,从大约2.5MB到1.5MB .

然而,奇怪的是,相应的RData文件非常庞大,并且对它们没有影响:

$ ls -lh cc*
-rw-r--r-- 1 user user 9.6M Aug 24 15:44 cc_after.RData
-rw-r--r-- 1 user user 9.6M Aug 24 15:43 cc_before.RData

解压缩文件显示2.5MB对象占用近100MB的空间:

$ cp cc_before.RData cc_before.gz
$ gunzip cc_before.gz 
$ ls -lh cc_before*
-rw-r--r-- 1 user user  98M Aug 24 15:45 cc_before

任何想法,是什么原因造成的?

2 回答

  • 1

    我找到了解决问题的方法,所以如果有人遇到同样的问题,我会写下这个答案 . 我会描述我的过程,所以它可能有点漫无边际,所以请耐心等待 .

    没有任何线索,我想到了插槽和删除权重以使对象尽可能小,并至少节省一些内存,以防万一找不到修复 . 所以我删除了 @data@responses 作为开始,如果没有它们,预测仍然很好,但对.RData文件大小没有影响 .

    我反过来创建并清空了ctree模型,只需将树插入其中:

    > library(party)
    
    ## create reference predictions for the dataset
    > predictions.org <- treeresponse(c1, d)
    
    ## save tree object for reference
    save(c1, "testSize_c1.RData")
    

    检查原始对象的大小:

    $ ls -lh testSize_c1.RData 
    -rw-r--r-- 1 user user 9.6M 2011-08-25 14:35 testSize_c1.RData
    

    现在,让我们创建一个空的CTree并仅复制树:

    ## extract the tree only 
    > c1Tree <- c1@tree
    
    ## create empty tree and plug in the extracted one 
    > newCTree <- new("BinaryTree")
    > newCTree@tree <- c1Tree
    
    ## save tree for reference 
    save(newCTree, file="testSize_newCTree.RData")
    

    这个新的树对象现在要小得多:

    $ ls -lh testSize_newCTree.RData 
    -rw-r--r-- 1 user user 108K 2011-08-25 14:35 testSize_newCTree.RData
    

    但是,它不能用于预测:

    ## predict with the new tree
    > predictions.new <- treeresponse(newCTree, d)
    Error in object@cond_distr_response(newdata = newdata, ...) : 
      unused argument(s) (newdata = newdata)
    

    我们没有设置 @cond_distr_response ,这可能会导致错误,所以复制原始错误并尝试再次预测:

    ## extract cond_distr_response from original tree
    > cdr <- c1@cond_distr_response
    > newCTree@cond_distr_response <- cdr
    
    ## save tree for reference 
    save(newCTree, file="testSize_newCTree_with_cdr.RData")
    
    ## predict with the new tree
    > predictions.new <- treeresponse(newCTree, d)
    
    ## check correctness
    > identical(predictions.org, predictions.new)
    [1] TRUE
    

    这很好用,但现在RData文件的大小恢复为原始值:

    $ ls -lh testSize_newCTree_with_cdr.RData 
    -rw-r--r-- 1 user user 9.6M 2011-08-25 14:37 testSize_newCTree_with_cdr.RData
    

    只需打印插槽,就可以将其显示为绑定到环境的功能:

    > c1@cond_distr_response
    function (newdata = NULL, mincriterion = 0, ...) 
    {
        wh <- RET@get_where(newdata = newdata, mincriterion = mincriterion)
        response <- object@responses
        if (any(response@is_censored)) {
            swh <- sort(unique(wh))
            RET <- vector(mode = "list", length = length(wh))
            resp <- response@variables[[1]]
            for (i in 1:length(swh)) {
                w <- weights * (where == swh[i])
                RET[wh == swh[i]] <- list(mysurvfit(resp, weights = w))
            }
            return(RET)
        }
        RET <- .Call("R_getpredictions", tree, wh, PACKAGE = "party")
        return(RET)
    }
    <environment: 0x44e8090>
    

    因此,初始问题的答案似乎是对象的方法将环境绑定到它,然后将其与对象一起保存在相应的RData文件中 . 这也可以解释为什么在读取RData文件时会加载几个包 .

    因此,为了摆脱环境,我们也可以在没有它们的情况下进行预测 . 相反"dirty"解决方案是模拟原始方法的功能并直接调用底层C代码 . 经过一些挖掘源代码,这确实是可能的 . 正如上面复制的代码所示,我们需要调用 get_where ,它确定输入到达的树的终端节点 . 然后,我们需要调用 R_getpredictions 来确定每个输入样本的终端节点的响应 . 棘手的部分是我们需要以正确的输入格式获取数据,因此必须调用ctree中包含的数据预处理:

    ## create a character string of the formula which was used to fit the free
    ## (there might be a more neat way to do this)
    > library(stringr)
    > org.formula <- str_c(
                       do.call(str_c, as.list(deparse(c1@data@formula$response[[2]]))),
                       "~", 
                       do.call(str_c, as.list(deparse(c1@data@formula$input[[2]]))))
    
    ## call the internal ctree preprocessing 
    > data.dpp <- party:::ctreedpp(as.formula(org.formula), d)
    
    ## create the data object necessary for the ctree C code
    > data.ivf <- party:::initVariableFrame.df(data.dpp@menv@get("input"), 
                                               trafo = ptrafo)
    
    ## now call the tree traversal routine, note that it only requires the tree
    ## extracted from the @tree slot, not the whole object
    > nodeID <- .Call("R_get_nodeID", c1Tree, data.ivf, 0, PACKAGE = "party")
    
    ## now determine the respective responses
    > predictions.syn <- .Call("R_getpredictions", c1Tree, nodeID, PACKAGE = "party")
    
    ## check correctness
    > identical(predictions.org, predictions.syn)
    [1] TRUE
    

    我们现在只需要保存提取的树和公式字符串,以便能够预测新数据:

    > save(c1Tree, org.formula, file="testSize_extractedObjects.RData")
    

    我们可以进一步删除上面更新的问题中描述的不必要的权重:

    > .Call("R_remove_weights", c1Tree, TRUE, PACKAGE="party")
    > save(c1Tree, org.formula, file="testSize_extractedObjects__removedWeights.RData")
    

    现在让我们再看一下文件大小:

    $ ls -lh testSize_extractedObjects*
    -rw-r--r-- 1 user user 109K 2011-08-25 15:31 testSize_extractedObjects.RData
    -rw-r--r-- 1 user user  43K 2011-08-25 15:31 testSize_extractedObjects__removedWeights.RData
    

    最后,代替(压缩)9.6M,只需要43K即可使用该模型 . 我现在应该可以在我的3G堆空间中尽可能多地适应 . 万岁!

  • 5

    你要找的是删除插槽 . 一个字谨慎:考虑到 party 函数如何处理对象,这可能相当危险 .

    尽管如此,请看一下 slotNames(yourModel) . 您也可以尝试 object.size(slot(yourModel), slotNameOfInterest) 来检查不同插槽的大小 . 您可以轻松创建排序表,以确保每个插槽中对象的大小 .

    在任何情况下, data 的插槽都是 ModelEnvFormula (我称之为"MEF")对象 . 您可以创建一个虚拟MEF: dummyMEF <- ModelEnvFormula(1 ~ 1) 然后将其分配给 dataslot(yourModel, "data") <- dummyMEF .

    那会破坏那个特定的插槽 . 您应该看看是否有其他插槽导致存储方面的麻烦 - object.size() 功能将有所帮助 . 我同意能够从模型对象中省略训练数据是很好的 .

相关问题