欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 旅游 > 机器学习-基于KNN算法手动实现kd树

机器学习-基于KNN算法手动实现kd树

2025/3/25 17:58:31 来源:https://blog.csdn.net/m0_73426548/article/details/146397449  浏览:    关键词:机器学习-基于KNN算法手动实现kd树

目录

一、概括

二、KD树的构建流程

1.循环选轴

2.选择分裂点

三、kd树的查询

1.输入我们要搜索的点

2.递归向下遍历:

3.记录最近点

4.回溯父节点:

四、KD树的优化与变种:

五、KD树代码:


上一章我们将了机器学习-手搓KNN算法,这一章我们加上kd树对它进行优化,下面先来讲讲kd树。

KD 树(K-Dimensional Tree)是一种高效的K 维空间数据索引结构,主要用于最近邻搜索和范围搜索。以下从原理、构建、查询、优化等方面详细讲解:

一、概括

KD树通过递归划分k维空间,将数据点组织成二叉树结构:每一个节点代表一个k维超矩形空,比如在二维空间中,就是一个矩形包围一个点,三维就是一个体来包围一个点。然后使用二叉树将这些点连接起来,父节点选择一个维度作为分裂轴,用该维度的中位数将区域划分维左子树(小于分裂轴的点)和右子树(大于等于分裂轴的点)

二、KD树的构建流程

以X = [(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)]为例

1.循环选轴

先计算每个维度的方差,X的x轴数据有(2,5,9,4,8,7)方差为5.8055。x轴数据有(3,4,6,7,1,2)方差为4.4722。因为x轴的方差大于y轴,那么先选择x轴。等到x轴分裂后下一次就是y轴,如果还有别的维度那么继续循环,循环结束后又回到y轴开始下一轮的循环

2.选择分裂点

第一次在上面选择完成后的轴x上,选择该轴的中位数,数据为((2,5,9,4,8,7)那么中位数为(5,4)那么在该点上分裂,分裂后的左子树为[(2,3)],右子树为[(9,6), (4,7), (8,1), (7,2)]

第二次选择y轴:在上面的右子树中的中位数为(7,2),那么根据中位数分裂后左子树为[(4,7), (8,1)],右子树为:[(9,6)]。继续循环,循环结束后树结构为:

      (5,4) (x轴分裂)/        \
(2,3)      (7,2) (y轴分裂)/      \(4,7)    (9,6)\(8,1)

三、kd树的查询

既然设计到树,那么肯定有增删改查。

1.输入我们要搜索的点

最近邻搜索的目的是找到我们要查询的点的最近的K个点,那么目标就变成了在我们的KD树中寻找到距离搜索点的最小距离的K个点。

2.递归向下遍历:

从根节点开始,根据当前分裂轴比较我们要搜索的点,如果比我们要搜索的点大就去右子树,小就去左子树。

3.记录最近点

等到第二步递归到叶子节点时,那么这个叶子节点就是距离我们要搜索的点最近的点,将这个点记录下来

4.回溯父节点:

计算我们搜索到的点到我们要搜索的点的距离,因为还要遍历另外一边的最近点,比如刚刚遍历的是左子树,那么现在要遍历右子树了,每次回溯到父节点后都要将搜索到的点与上一次搜索的最近点比较距离大小,将小的留下

示例:

以上面的例子为例:比如查找(6,3)的最近点

1.从根节点(5,4)出发,x 轴分裂,6>5,进入右子树(7,2)。

2.(7,2)是 y 轴分裂,3>2,进入右子树(9,6),记录最近点为(9,6)(距离√[(6-9)²+(3-6)²]=√18)。

3.回溯到(7,2),计算 y 轴分裂超平面距离为 | 3-2|=1 < √18,检查左子树(4,7)和(8,1)。
在左子树中,(8,1)距离为√[(6-8)²+(3-1)²]=√8,更近,更新最近点。

4.回溯到根节点(5,4),计算 x 轴分裂超平面距离为 | 6-5|=1 < √8,检查左子树(2,3),距离√[(6-2)²+(3-3)²]=4 > √8,不更新。最终最近点为 (8,1)。

2.范围搜索

目标:找到所有在k维超矩形区域内的点。

这个方法是先设置一个距离,然后递归遍历树,若当前节点的分裂轴到我们查询点的距离超过了我们设置的距离,则直接剪枝就是不去遍历这个节点以后的点了,如果这个节点在查询区域内则加入结果集,继续搜索子树

四、KD树的优化与变种:

1.BBF算法:使用有线队列优化最近邻搜索,减少回溯次数

2.Ball树:用超球体代替超矩形,更高效处理高维数据(普通KD树在维度>20时性能明显下降)

3.k-d-B树:结合KD树和B树,支持动态插入和删除

五、KD树代码:

import numpy as np
from collections import dequefrom sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerclass KDNode:def __init__(self,point,left=None,right=None,axis=None):self.point=point  # 数据点[]self.left=leftself.right=rightself.axis=axisclass KDTree:def __init__(self,data,labels):self.data=np.c_[data,labels]self.root=self.build_tree(self.data)def build_tree(self,points,depth=0):if len(points)==0:return Nonek=points.shape[1]-1axis=depth%ksorted_points=points[points[:,axis].argsort()]median_idx=len(sorted_points)//2median_point=sorted_points[median_idx]left=self.build_tree(sorted_points[:median_idx],depth+1)right=self.build_tree(sorted_points[median_idx+1:],depth+1)return KDNode(median_point,left,right,axis)def query_knn(self, target, k):best_candidates = []  # 保存最近的k个邻居(按距离倒序存储)candidates = deque()  # 使用双端队列实现非递归遍历candidates.append((self.root, False))  # (当前节点, 是否已访问)while candidates:node, visited = candidates.pop()if node is None:continueif not visited:# 计算当前节点到目标的欧氏距离(排除标签列)distance = np.sqrt(np.sum((node.point[:-1] - target)  ** 2))# 维护长度为k的优先队列(使用负距离实现最大堆)if len(best_candidates) < k:best_candidates.append((-distance, node.point))best_candidates.sort(reverse=True)  # 按距离从大到小排序else:if distance < -best_candidates[0][0]:best_candidates.pop()  # 移除最远候选best_candidates.append((-distance, node.point))best_candidates.sort(reverse=True)# 根据切分维度决定搜索路径(类似二叉搜索树)axis = node.axisif target[axis] < node.point[axis]:candidates.append((node, True))  # 标记当前节点已访问candidates.append((node.left, False))  # 先搜索左子树else:candidates.append((node, True))candidates.append((node.right, False))  # 先搜索右子树else:# 回溯检查另一侧子树是否需要搜索(剪枝优化)axis = node.axisworst_dist = -best_candidates[0][0] if best_candidates else np.inf# 判断目标点到分割超平面的距离是否小于当前最远邻居距离if (len(best_candidates) < k) or \(abs(target[axis] - node.point[axis]) < worst_dist):if target[axis] < node.point[axis]:candidates.append((node.right, False))  # 搜索右子树else:candidates.append((node.left, False))  # 搜索左子树# 返回前k个邻居的标签(按距离从近到远排序)return [point[-1] for (dist, point) in sorted(best_candidates, reverse=True)]class KNNWithKDTree:def __init__(self, k=5):self.k = k  # 最近邻数量Kself.kdtree = None  # 存储构建好的KD树def fit(self, X, y):# 构建KD树(将训练数据和标签传入)self.kdtree = KDTree(X, y)def predict(self, X_test):predictions = []for x in X_test:# 获取当前测试样本的K个最近邻标签neighbors = self.kdtree.query_knn(x, self.k)# 多数投票(取出现次数最多的类别)most_common = max(set(neighbors), key=neighbors.count)predictions.append(most_common)return np.array(predictions)if __name__ == '__main__':# 加载鸢尾花数据集iris = load_iris()X, y = iris.data, iris.target# 数据标准化(消除量纲影响)scaler = StandardScaler()X = scaler.fit_transform(X)# 划分训练集和测试集(70%训练,30%测试)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 初始化KNN分类器(K=5)knn = KNNWithKDTree(k=5)knn.fit(X_train, y_train)  # 训练模型(构建KD树)# 预测测试集结果y_pred = knn.predict(X_test)# 计算准确率accuracy = np.sum(y_pred == y_test) / len(y_test)print(f"准确率: {accuracy:.4f}")  # 输出如:准确率: 0.9778

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词