← 返回题库
初级

第03章 k近邻法 - 习题3.3 - 代码实现

未完成
初级参考 完整示例代码供参考,建议自己理解后重新输入
import numpy as np
from collections import namedtuple

class Node(namedtuple("Node", "location left_child right_child")):
    pass

class KdTree:
    def __init__(self, k=1):
        self.k = k
        self.root = None
    
    def _fit(self, X, depth=0):
        if len(X) == 0:
            return None
        axis = depth % self.k
        X = X[X[:, axis].argsort()]
        median = len(X) // 2
        return Node(location=X[median],
                    left_child=self._fit(X[:median], depth + 1),
                    right_child=self._fit(X[median + 1:], depth + 1))
    
    def fit(self, X):
        self.k = X.shape[1]
        self.root = self._fit(X)
        return self

train_data = np.array([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
tree = KdTree()
tree.fit(train_data)
Python 代码 🔒 登录后使用
🔒

登录后即可练习

注册免费账号,在浏览器中直接运行 Python 代码