← 返回题库
初级

第01章 统计学习方法概论 - 使用最小二乘法拟和曲线 - 可视化

未完成
初级参考 完整示例代码供参考,建议自己理解后重新输入
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import leastsq

def real_func(x):
    return np.sin(2*np.pi*x)

def fit_func(p, x):
    f = np.poly1d(p)
    return f(x)

def residuals_func(p, x, y):
    return fit_func(p, x) - y

def residuals_func_regularization(p, x, y):
    regularization = 0.0001
    ret = fit_func(p, x) - y
    ret = np.append(ret, np.sqrt(0.5 * regularization * np.square(p)))
    return ret

x = np.linspace(0, 1, 10)
x_points = np.linspace(0, 1, 1000)
y_ = real_func(x)
y = [np.random.normal(0, 0.1) + y1 for y1 in y_]

p_init_9 = np.random.rand(9 + 1)
p_lsq_9 = leastsq(residuals_func, p_init_9, args=(x, y))

p_init_reg = np.random.rand(9 + 1)
p_lsq_regularization = leastsq(residuals_func_regularization, p_init_reg, args=(x, y))

plt.plot(x_points, real_func(x_points), label='real')
plt.plot(x_points, fit_func(p_lsq_9[0], x_points), label='fitted curve')
plt.plot(x_points, fit_func(p_lsq_regularization[0], x_points), label='regularization')
plt.plot(x, y, 'bo', label='noise')
plt.legend()
Python 代码 🔒 登录后使用
🔒

登录后即可练习

注册免费账号,在浏览器中直接运行 Python 代码