博主个人微信:xituxiaoshutong100

K-近邻算法

算法 迷途小书童 0评论

简介

K-近邻即K Nearest Neighbor,简写为KNN,它是一种分类和回归算法,是最简单的机器学习算法之一。它的思路是:在特征空间中,如果一个样本附近的k个最近(即特征空间中最邻近)样本的大多数属于某一个类别,则该样本也属于这个类别。

经典案例

knn

看上面的图,已知有2个类别,红色的三角形和蓝色的正方形,现在我们要判断中间的那个绿色的圆是属于哪一类?使用KNN,就从它的邻居下手,但需要看多少个邻居呢?

  • K=3,绿色圆点的最近的3个邻居是2个红色三角形和1个蓝色正方形,少数从属于多数,基于统计的方法,就认为绿色的圆属于红色三角形这一类
  • K=5,绿色圆点的最近的5个邻居是2个红色三角形和3个蓝色正方形,还是少数从属于多数,基于统计的方法,就认为绿色的圆属于蓝色正方形这一类

可以看到,K值的选择,对我们最后的结果影响很大。K值越小,很容易受到单个个体的影响,K值太大,很容易受到较远的特殊距离的影响。这里讲到的距离,常见的计算方法有

K值应该如何设定呢?

很不幸,这里没有一个明确的结论。K的取值受到问题本身和数据集大小的影响,很多时候,需要自己进行多次的尝试,然后选择最佳的值。

KNN的缺点

KNN需要计算与所有样本之间的距离,这样的话,计算量就很大,效率很低,很难应用到较大的数据集当中。

代码示例

sklearn这个库,提供了完整的KNN实现,使用起来也非常简单,通过pip install scikit-learn安装

from sklearn.neighbors import KNeighborsClassifier

...

# n_neighbors就是K值
knn_classifier = KNeighborsClassifier(n_neighbors=5)
knn_classifier.fit(x_train, y_train)

# X_test是待分类的数据
pred = knn_classifier.predict(X_test)

参考资料

喜欢 (0)
发表我的评论
取消评论

表情