首页 文章

在scikit-learn中混淆嵌套交叉验证的例子

提问于
浏览
3

我正在从scikit-learn文档中查看此示例:http://scikit-learn.org/0.18/auto_examples/model_selection/plot_nested_cross_validation_iris.html

在我看来,交叉验证不是以无偏见的方式在这里执行的 . GridSearchCV (据说是内部CV循环)和 cross_val_score (假设是外部CV循环)都使用相同的数据和相同的折叠 . 因此,分类器训练和评估的数据之间存在重叠 . 我错了什么?

3 回答

  • 0

    @Gael - 由于我无法添加评论,我将在答案部分发布此内容 . 我不确定盖尔的意思是"the first split is done inside cross_val_score, and the second split is done inside GridSearchCV (that the whole point of the GridSearchCV object)" . 您是否试图暗示cross_val_score函数将(k-1) - 折叠数据(用于外循环训练)传递给clf对象?这似乎并非如此,因为我可以注释掉cross_val_score函数并将nested_score [i]设置为虚拟变量,并且仍然获得完全相同的clf.best_score_ . 这意味着GridSearchCV是单独评估的,并且确实使用 all 可用数据,而不是训练数据的子集 .

    在嵌套的CV中,据我所知,内部循环将在较小的训练数据子集上进行超参数搜索,然后外部循环将使用这些参数进行交叉验证 . 在内循环中使用较小训练数据的原因之一是避免信息泄漏 . 这似乎不是这里发生的事情 . 内循环首先使用所有数据来搜索超参数,然后将其用于外循环中的交叉验证 . 因此,内环已经看到所有数据,并且在外环中进行的任何测试都将遭受信息泄漏 . 如果我弄错了,你能不能指出你在答案中提到的代码部分?

  • 0

    完全同意,嵌套-cv过程是错误的,cross_val_score采用GridSearchCV计算的最佳超参数,并使用这样的超参数计算cv分数 . 在nested-cv中,您需要外部循环来评估模型性能和内部循环以进行模型选择,这样,内部循环中用于模型选择的数据部分不能与用于评估模型性能的部分相同 . 一个例子是用于评估性能的LOOCV外环(或者,它将是5cv,10cv或任何你喜欢的),以及在内环中使用网格搜索进行模型选择的10cv倍 . 这意味着,如果您有N个观测值,那么您将在内部循环中执行模型选择(例如,使用网格搜索和10-CV)对N-1观测值进行模型选择,您将在LOO观测值上评估模型性能(或者如果您选择其他方法,则在保留数据样本中) . (请注意,您正在内部估计超参数中的N个最佳模型) . 访问cross_val_score和GridSearchCV代码的链接会很有帮助 . 嵌套CV的一些参考是:

    • Christophe Ambroise和Georey J McLachlan . 基于微阵列基因表达数据的基因提取中的选择偏倚 . 国家科学院院刊99,10(2002),6562-6566 .

    • Gavin C Cawley和Nicola LC Talbot . 关于模型选择中的过度拟合和性能评估中的后续选择偏差 . 机器学习研究杂志11,Jul(2010),2079 {2107 .

    注意:我在cross_val_score的文档中没有找到任何内容,表明在内部使用参数搜索,网格搜索交叉验证,例如k-1数据折叠,以及在保留时使用这些优化参数来优化超参数数据样本(我所说的与http://scikit-learn.org/dev/auto_examples/model_selection/plot_nested_cross_validation_iris.html中的代码不同)

  • 2

    他们没有使用相同的数据 . 当然,示例的代码并没有明显,因为拆分是不可见的:第一个拆分在cross_val_score内完成,第二个拆分在GridSearchCV内完成(GridSearchCV对象的整个点) . 使用函数和对象而不是手写的循环可能会使事情变得不那么透明,但它:

    • 启用重用

    • 添加许多会使for循环繁琐的"little things",例如并行计算,支持不同的评分函数等 .

    • 在避免数据泄漏方面实际上更安全,因为我们的拆分代码已被多次审核 .

    如果您不相信,请查看cross_val_score和GridSearchCV的代码 .

    该最近改进的示例在评论中指定了这个:http://scikit-learn.org/dev/auto_examples/model_selection/plot_nested_cross_validation_iris.html

    (在https://github.com/scikit-learn/scikit-learn/pull/7949上拉请求)

相关问题