首页 文章

scikit learn:train_test_split,我可以确保在不同的数据集上进行相同的拆分

提问于
浏览
6

据我所知,train_test_split方法将数据集拆分为随机序列和测试子集 . 并且使用random_state = int可以确保每次调用方法时我们在此数据集上都有相同的拆分 .

我的问题略有不同 .

我有两个数据集,A和B,它们包含相同的示例集,每个数据集中出现的这些示例的顺序也相同 . 但它们的关键区别在于每个数据集中的exmaples使用不同的功能集 .

我想测试一下A中使用的功能是否比B中使用的功能更好 . 所以我想确保当我在A和B上调用train_test_split时,我可以在两个数据集上获得相同的分割,以便比较是有意义的 .

这可能吗?我是否只需要确保两个数据集的两个方法调用中的random_state都相同?

谢谢

2 回答

  • 6

    是的,随机状态就足够了 .

    >>> X, y = np.arange(10).reshape((5, 2)), range(5)
    >>> X2 = np.hstack((X,X))
    >>> X_train, X_test, _, _ = train_test_split(X,y, test_size=0.33, random_state=42)
    >>> X_train2, X_test2, _, _ = train_test_split(X2,y, test_size=0.33, random_state=42)
    >>> X_train
    array([[4, 5],
           [0, 1],
           [6, 7]])
    >>> X_train2
    array([[4, 5, 4, 5],
           [0, 1, 0, 1],
           [6, 7, 6, 7]])
    >>> X_test
    array([[2, 3],
           [8, 9]])
    >>> X_test2
    array([[2, 3, 2, 3],
           [8, 9, 8, 9]])
    
  • 2

    查看 train_test_split 函数的代码,它会在每次调用时在函数内设置随机种子 . 因此每次都会产生相同的分割 . 我们可以检查这很简单

    X1 = np.random.random((200, 5))
    X2 = np.random.random((200, 5))
    y = np.arange(200)
    
    X1_train, X1_test, y1_train, y1_test = model_selection.train_test_split(X1, y,
                                                                            test_size=0.1,
                                                                            random_state=42)
    X2_train, X2_test, y2_train, y2_test = model_selection.train_test_split(X1, y,
                                                                            test_size=0.1,
                                                                            random_state=42)
    
    print np.all(y1_train == y2_train)
    print np.all(y1_test == y2_test)
    

    哪个输出:

    True
    True
    

    哪个好!解决此问题的另一种方法是在所有功能上创建一个训练和测试分割,然后在训练之前分割您的功能 . 但是,如果您想要训练集中的测试功能,那么您可以使用 StratifiedShuffleSplit 函数返回属于每个集合的数据的索引 . 例如:

    n_splits = 1 
    sss = model_selection.StratifiedShuffleSplit(n_splits=n_splits, 
                                                 test_size=0.1,
                                                 random_state=42)
    train_idx, test_idx = list(sss.split(X, y))[0]
    

相关问题