首页 文章

我是否只是在sklearn KNN分类器中发现了一个错误,或者一切都按预期工作?

提问于
浏览
1

我一直在玩python sklearn k最近邻分类器,我相信它不能正常工作 - k大于1的结果是错误的 . 我试图想象出不同的k-nn方法与我的示例代码有何不同 .

代码有点长但不是很复杂 . 继续自己运行以获取图片 . 我以大约10个点的列的形式生成样本2D数据 . 大多数代码都是以动画方式在图表上很好地绘制它 . 所有分类都是在for循环中调用“ main ”中的构造库对象KNeighborsClassifier之后发生的 .

我尝试了不同的算法方法,怀疑它是kd-tree问题,但我得到了相同的结果(swap算法=“kdtree”代表“粗暴”或球树)

以下是我得到的结果的说明:

result of classifier with k=3 and uniform weights, kdtrees

Picture comment :正如您在x = 2周围的第3列中所见,红点周围的所有区域都应为红色,并且在x = -4附近区域应为蓝色,因为下一个最近的红点位于相邻列中 . 我相信's not how the classifier should behave and I'我不确定我是否'm not doing something right or it'是库方法错误 . 我是代码,但决定同时问这个问题 . 我也写了'm not familiar with C-Python it' .

Sources and version :我使用scikit-learn documentation和mathplotlib示例制作了代码 . 我运行python 3.6.1和sklearn的0.18.1版本 .

Bonus question :k邻居是否使用近似或确定的kd树回答?根据我的理解,对于k = 1,它可以很容易地完美地工作但是我不确定答案是否总是正确的,k大于1 .

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
import random


random.seed(905) # 905
# interesting seed 2293
def generate_points(sizex, sizey):
    # sizex = 4
    # sizey = 10
    apart = 5
    # generating at which X coordinate my data column will be
    columns_x = [random.normalvariate(0, 5) for i in range(sizex)]
    columns_y = list()
    # randomising for each column the Y coordinate at which it starts
    for i in range(sizex):
        y_column = [random.normalvariate(-50, 100) for j in range(sizey)]
        y_column.sort()
        columns_y.append(y_column)

    # preparing lists of datapoints with classification
    datapoints = np.ndarray((sizex * sizey, 2))
    dataclass = list()

    # genenerating random split for each column
    for i in range(sizex):
        division = random.randint(0, sizey)
        for j in range(sizey):
            datapoints[i * sizey + j][0] = columns_x[i]
            datapoints[i * sizey + j][1] = -j * apart
            dataclass.append(j < division)

    return datapoints, dataclass


if __name__ == "__main__":
    datapoints, dataclass = generate_points(4, 10)

    #### VISUALISATION PART ####
    x_min, y_min = np.argmin(datapoints, axis=0)
    x_min, y_min = datapoints[x_min][0], datapoints[y_min][1]
    x_max, y_max = np.argmax(datapoints, axis=0)
    x_max, y_max = datapoints[x_max][0], datapoints[y_max][1]
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_min -= 0.15*x_range
    x_max += 0.15*x_range
    y_min -= 0.15*y_range
    y_max += 0.15*y_range

    mesh_step_size = .1

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) # for meshgrid
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) # for points

    plt.ion() # plot interactive mode
    for weights in ['uniform', 'distance']: # two types of algorithm
        for k in range(1, 13, 2): # few k choices
            # we create an instance of Neighbours Classifier and fit the data.
            clf = neighbors.KNeighborsClassifier(k, weights=weights, algorithm="kd_tree")
            clf.fit(datapoints, dataclass)

            # Plot the decision boundary. For that, we will assign a color to each
            # point in the mesh [x_min, x_max]x[y_min, y_max].
            xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                                 np.arange(y_min, y_max, mesh_step_size))
            Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

            # Put the result into a color plot
            Z = Z.reshape(xx.shape)

            plt.figure(1)
            plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

            # Plot also the training points
            plt.scatter(datapoints[:, 0], datapoints[:, 1], c=dataclass, cmap=cmap_bold, marker='.')
            plt.xlim(xx.min(), xx.max())
            plt.ylim(yy.min(), yy.max())
            plt.title("K-NN classifier (k = %i, weights = '%s')"
                      % (k, weights))

            plt.draw()
            input("Press Enter to continue...")
            plt.clf()

我也决定在发布之前设置种子,所以我们都得到相同的结果,随意设置随机种子 .

1 回答

  • 0

    你的输出似乎很好 .

    从图中可能不明显的是,点之间的水平距离实际上是 shorter 而不是垂直距离 . 即使两个相邻列之间的最远水平间隔是4.something,而任何两个相邻行之间的垂直间隔是5 .

    对于归类为红色的点,他们在训练集中的3个最近邻居中的大多数为 really are red . 如果接下来的两个邻居是红色的话,它不会超级接近蓝点 . 对于在红点附近归类为蓝色的点也是如此 .

相关问题