初级
第03章 k近邻法 - 习题3.3 - 实现Node类
未完成
初级参考
完整示例代码供参考,建议自己理解后重新输入
# 构建kd树,搜索待预测点所属区域
from collections import namedtuple
import numpy as np
# 建立节点类
class Node(namedtuple("Node", "location left_child right_child")):
def __repr__(self):
return str(tuple(self))
# kd tree类
class KdTree():
def __init__(self, k=1):
self.k = k
self.kdtree = None
# 构建kd tree
def _fit(self, X, depth=0):
try:
k = self.k
except IndexError as e:
return None
# 这里可以展开,通过方差选择axis
axis = depth % k
X = X[X[:, axis].argsort()]
median = X.shape[0] // 2
try:
X[median]
except IndexError:
return None
return Node(location=X[median],
left_child=self._fit(X[:median], depth + 1),
right_child=self._fit(X[median + 1:], depth + 1))
def _search(self, point, tree=None, depth=0, best=None):
if tree is None:
return best
k = self.k
# 更新 branch
if point[0][depth % k] < tree.location[depth % k]:
next_branch = tree.left_child
else:
next_branch = tree.right_child
if not next_branch is None:
best = next_branch.location
return self._search(point,
tree=next_branch,
depth=depth + 1,
best=best)
def fit(self, X):
self.kdtree = self._fit(X)
return self.kdtree
def predict(self, X):
res = self._search(X, self.kdtree)
return res
👑
升级 VIP
解锁全部题目,畅通无阻地学习
- ✓ 解锁全部训练包所有题目
- ✓ 查看完整参考代码和提示
- ✓ 浏览器内直接运行 Python 代码
- ✓ 自动批改 + 进度追踪
30天
¥18
1年
¥99
2年
¥158
3年
¥199