欢迎访问我的网站,希望内容对您有用,感兴趣的可以加入我们的社群。

K-近邻算法

算法 迷途小书童 4年前 (2021-03-11) 3946次浏览 0个评论

简介

K -近邻即 K Nearest Neighbor,简写为 KNN,它是一种分类和回归算法,是最简单的机器学习算法之一。它的思路是:在特征空间中,如果一个样本附近的 k 个最近(即特征空间中最邻近)样本的大多数属于某一个类别,则该样本也属于这个类别。

经典案例

knn

看上面的图,已知有2个类别,红色的三角形和蓝色的正方形,现在我们要判断中间的那个绿色的圆是属于哪一类?使用 KNN,就从它的邻居下手,但需要看多少个邻居呢?

  • K=3,绿色圆点的最近的3个邻居是2个红色三角形和1个蓝色正方形,少数从属于多数,基于统计的方法,就认为绿色的圆属于红色三角形这一类
  • K=5,绿色圆点的最近的5个邻居是2个红色三角形和3个蓝色正方形,还是少数从属于多数,基于统计的方法,就认为绿色的圆属于蓝色正方形这一类

可以看到,K 值的选择,对我们最后的结果影响很大。K 值越小,很容易受到单个个体的影响,K 值太大,很容易受到较远的特殊距离的影响。这里讲到的距离,常见的计算方法有

K 值应该如何设定呢?

很不幸,这里没有一个明确的结论。K 的取值受到问题本身和数据集大小的影响,很多时候,需要自己进行多次的尝试,然后选择最佳的值。

KNN的缺点

KNN 需要计算与所有样本之间的距离,这样的话,计算量就很大,效率很低,很难应用到较大的数据集当中。

代码示例

sklearn 这个库,提供了完整的 KNN 实现,使用起来也非常简单,通过 pip install scikit-learn 安装

from sklearn.neighbors import KNeighborsClassifier

...

# n_neighbors就是K值
knn_classifier = KNeighborsClassifier(n_neighbors=5)
knn_classifier.fit(x_train, y_train)

# X_test是待分类的数据
pred = knn_classifier.predict(X_test)

参考资料

喜欢 (0)

您必须 登录 才能发表评论!