0%

K最近邻算法

K 最近邻算法(KNN):原理、应用与实现

K 最近邻算法(K-Nearest Neighbours,KNN)是一种简单直观的监督学习算法,核心思想是 “物以类聚”—— 通过样本周围最近的 K 个邻居的信息来预测其类别或数值。它无需训练过程,属于 “惰性学习”(Lazy Learning),适用于分类和回归任务。

KNN 的核心原理

基本思想

对于未知样本,KNN 通过以下步骤进行预测:

  1. 计算距离:计算未知样本与训练集中所有已知样本的距离(如欧氏距离、曼哈顿距离)。
  2. 找邻居:选取距离最近的K 个样本(K 为超参数,通常为奇数,如 3、5、7)。
  3. 投票 / 平均:
    • 分类任务:K 个邻居中出现次数最多的类别即为未知样本的预测类别(多数投票)。
    • 回归任务:K 个邻居的数值的平均值即为未知样本的预测值。

关键概念

  • K 值选择
    • K 过小:易受噪声影响,模型过拟合(决策边界复杂)。
    • K 过大:邻居中可能包含其他类别的样本,模型欠拟合(决策边界模糊)。
    • 通常通过交叉验证选择最优 K 值(如 3、5)。
  • 距离度量
    • 欧氏距离(最常用):适用于连续特征,公式为 (d(x,y) = \sqrt{\sum_{i=1}^{n}(x_i - y_i)^2})。
    • 曼哈顿距离:适用于高维数据,公式为 (d(x,y) = \sum_{i=1}^{n}|x_i - y_i|)。

KNN 的分类与回归应用

分类任务(离散结果)

示例:预测鸢尾花类别(Setosa、Versicolor、Virginica)。

  • 已知样本:不同鸢尾花的花瓣长度、宽度等特征及对应类别。
  • 未知样本:计算其与所有已知样本的距离,选最近的 5 个邻居,若其中 3 个为 Versicolor,则预测为 Versicolor。

回归任务(连续结果)

示例:预测房价。

  • 已知样本:房屋面积、房间数等特征及对应房价。
  • 未知样本:选最近的 5 个邻居,计算其房价的平均值作为预测值。

KNN 的实现步骤(以分类为例)

1. 数据准备

  • 训练集:包含特征(如花瓣长度、宽度)和标签(如鸢尾花类别)。
  • 未知样本:仅含特征,需预测标签。

2. 距离计算

以欧氏距离为例,计算未知样本与每个训练样本的距离:

1
2
3
4
5
6
7
8
// 计算两个样本的欧氏距离(features1和features2为特征数组)
public static double euclideanDistance(double[] features1, double[] features2) {
double sum = 0.0;
for (int i = 0; i < features1.length; i++) {
sum += Math.pow(features1[i] - features2[i], 2);
}
return Math.sqrt(sum);
}

3. 寻找 K 个最近邻居

  • 对所有距离排序,选取前 K 个最小距离对应的样本。

4. 多数投票预测

  • 统计 K 个邻居中出现次数最多的类别,作为未知样本的预测结果。

完整代码示例(简化版)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import java.util.*;

public class KNNClassifier {
private List<double[]> trainFeatures; // 训练集特征
private List<String> trainLabels; // 训练集标签
private int k; // K值

public KNNClassifier(List<double[]> features, List<String> labels, int k) {
this.trainFeatures = features;
this.trainLabels = labels;
this.k = k;
}

// 预测单个样本的类别
public String predict(double[] sample) {
// 1. 计算与所有训练样本的距离
List<DistanceLabel> distances = new ArrayList<>();
for (int i = 0; i < trainFeatures.size(); i++) {
double dist = euclideanDistance(sample, trainFeatures.get(i));
distances.add(new DistanceLabel(dist, trainLabels.get(i)));
}

// 2. 按距离排序,取前K个邻居
Collections.sort(distances);
List<String> topKLabels = new ArrayList<>();
for (int i = 0; i < k; i++) {
topKLabels.add(distances.get(i).label);
}

// 3. 多数投票
return majorityVote(topKLabels);
}

// 欧氏距离计算
private double euclideanDistance(double[] a, double[] b) {
double sum = 0.0;
for (int i = 0; i < a.length; i++) {
sum += Math.pow(a[i] - b[i], 2);
}
return Math.sqrt(sum);
}

// 多数投票
private String majorityVote(List<String> labels) {
Map<String, Integer> count = new HashMap<>();
for (String label : labels) {
count.put(label, count.getOrDefault(label, 0) + 1);
}
// 找到出现次数最多的标签
return count.entrySet().stream()
.max(Map.Entry.comparingByValue())
.get()
.getKey();
}

// 辅助类:存储距离和对应的标签
private static class DistanceLabel implements Comparable<DistanceLabel> {
double distance;
String label;

DistanceLabel(double distance, String label) {
this.distance = distance;
this.label = label;
}

@Override
public int compareTo(DistanceLabel other) {
return Double.compare(this.distance, other.distance);
}
}

// 测试
public static void main(String[] args) {
// 训练集:假设为鸢尾花数据(特征:花瓣长度、宽度)
List<double[]> features = Arrays.asList(
new double[]{1.4, 0.2}, // Setosa
new double[]{1.3, 0.2}, // Setosa
new double[]{4.5, 1.5}, // Versicolor
new double[]{4.9, 1.5}, // Versicolor
new double[]{6.0, 2.5} // Virginica
);
List<String> labels = Arrays.asList("Setosa", "Setosa", "Versicolor", "Versicolor", "Virginica");

// 初始化KNN分类器(K=3)
KNNClassifier knn = new KNNClassifier(features, labels, 3);

// 预测未知样本(花瓣长度2.0,宽度0.3)
double[] sample = {2.0, 0.3};
System.out.println("预测类别:" + knn.predict(sample)); // 输出:Setosa
}
}

KNN 的优缺点与优化

优点

  • 简单直观:无需训练过程,易于理解和实现。
  • 适应性强:可处理多分类问题,且对异常值不敏感(当 K 较大时)。
  • 无需假设数据分布:适用于非线性数据。

缺点

  • 计算成本高:预测时需与所有训练样本计算距离,时间复杂度为 O (n)(n 为训练样本数),不适用于大规模数据。
  • 对 K 值敏感:K 值选择不当会严重影响结果。
  • 受特征尺度影响:如 “身高(米)” 和 “体重(千克)” 的尺度差异会导致距离计算偏差,需先标准化特征。

优化方向

  • 降维:减少特征数量(如 PCA),降低距离计算成本。
  • KD 树 / 球树:高效的近邻搜索数据结构,加速 KNN 查询。
  • 特征标准化:将所有特征缩放到同一尺度(如 Z-score 标准化)。

适用场景

  • 小规模数据集(如医学诊断、客户分类)。
  • 对模型可解释性要求高的场景(邻居可直观展示决策依据)。
  • 非线性问题(如手写数字识别、推荐系统)。

总结

KNN 是一种 “懒惰却有效” 的算法,通过 “近朱者赤” 的逻辑实现预测,无需复杂的训练过程。其核心在于 K 值选择和距离度量,适用于小规模、非线性的数据任务。尽管计算效率较低,但凭借简单性和适应性,KNN 在实际应用中仍被广泛使用(如推荐系统、模式识别)

欢迎关注我的其它发布渠道

表情 | 预览
快来做第一个评论的人吧~
Powered By Valine
v1.3.10