
Stratified train-test split in sklearn based on a column


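I have a fairly large CSV file of Amazon review data, which I read into a pandas DataFrame. I want to split the data 80-20 (train-test), but I want the split to represent the values of one column (Categories) proportionally. In other words, reviews from every category should appear in both the train and the test data, in the same proportions.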

The data looks like this:

**ReviewerID**       **ReviewText**        **Categories**       **ProductId**

1212                   good product         Mobile               14444425
1233                   will buy again       drugs                324532
5432                   not recomended       dvd                  789654123

I am using the following code to do this:

import pandas as pd
Meta = pd.read_csv('C:\\Users\\xyz\\Desktop\\WM Project\\Joined.csv')
import numpy as np
from sklearn.cross_validation import train_test_split

train, test = train_test_split(Meta.categories, test_size = 0.2, stratify=y)

It raises the following error:

NameError: name 'y' is not defined

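Since I'm relatively new to Python, I can't figure out what I'm doing wrong, or whether this code would even stratify on the Categories column. When I removed both the stratify option and the categories column from the train_test_split call, it seemed to work fine.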

Any help would be appreciated.

2 Answers

  • 9
    >>> import pandas as pd
    >>> Meta = pd.read_csv('C:\\Users\\*****\\Downloads\\so\\Book1.csv')
    >>> import numpy as np
    >>> from sklearn.model_selection import train_test_split
    >>> y = Meta.pop('Categories')
    >>> Meta
    ReviewerID      ReviewText  ProductId
    0        1212    good product   14444425
    1        1233  will buy again     324532
    2        5432  not recomended  789654123
    >>> y
    0    Mobile
    1     drugs
    2       dvd
    Name: Categories, dtype: object
    >>> X = Meta
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, stratify=y)
    >>> X_test
    ReviewerID    ReviewText  ProductId
    0        1212  good product   14444425
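    One caveat if you run this on the three sample rows themselves: scikit-learn's stratified splitting requires every class to appear at least twice, and each category in the sample appears only once, so train_test_split raises a ValueError instead of returning a split. The sketch below rebuilds the sample in memory (instead of reading Book1.csv) and catches that error; the variable name error_message is just for illustration.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# The three sample rows from the question, built in memory instead of read from CSV
meta = pd.DataFrame({
    "ReviewerID": [1212, 1233, 5432],
    "ReviewText": ["good product", "will buy again", "not recomended"],
    "Categories": ["Mobile", "drugs", "dvd"],
    "ProductId": [14444425, 324532, 789654123],
})
y = meta.pop("Categories")  # class labels: each category occurs exactly once
X = meta

error_message = None
try:
    train_test_split(X, y, test_size=0.33, random_state=42, stratify=y)
except ValueError as err:
    # Stratification needs at least two rows per class, so this call fails
    error_message = str(err)

print(error_message)
```

    With the real Joined.csv, where each category has many reviews, the same call should work. Any category that has only a single review would have to be handled first, for example by dropping it.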
    
  • 15

    From the sklearn.model_selection.train_test_split docs: "stratify : array-like or None (default is None). If not None, data is split in a stratified fashion, using this as the class labels."

    Going by the API docs, I think you need something like X_train, X_test, y_train, y_test = train_test_split(Meta_X, Meta_Y, test_size = 0.2, stratify=Meta_Y).

    You need to assign Meta_X and Meta_Y yourself. Based on your code, I think Meta_Y should be Meta.categories.
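    Because stratify only needs an array of class labels, a variant of the same idea is to split the whole DataFrame in one call and pass the Categories column as stratify. This is a sketch, not taken from either answer. train_test_split then returns a train DataFrame and a test DataFrame, much like the train, test = ... line in the question. The 100 reviews below are hypothetical stand-ins for Joined.csv (40 Mobile, 40 drugs, 20 dvd). The code assumes the column is named Categories, as in the sample table. Pandas column names are case-sensitive, so Meta.categories would only work if the column really were lowercase.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for Joined.csv: 40 Mobile, 40 drugs, 20 dvd reviews
meta = pd.DataFrame({
    "ReviewerID": range(100),
    "ReviewText": ["some review"] * 100,
    "Categories": ["Mobile"] * 40 + ["drugs"] * 40 + ["dvd"] * 20,
    "ProductId": range(1000, 1100),
})

# 80-20 split of the full frame, stratified on the Categories column
train, test = train_test_split(
    meta, test_size=0.2, random_state=42, stratify=meta["Categories"]
)

print(len(train), len(test))  # 80 20

# Each category should have the same share in both parts
print(train["Categories"].value_counts(normalize=True))
print(test["Categories"].value_counts(normalize=True))
```

    Comparing the normalized value counts of both parts is a quick way to confirm that the split preserved the category proportions.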
