Beginner

Step Functions on the Wage Dataset

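Beginner reference: the complete example code below is provided for reference. Make sure you understand it, then type it in yourself rather than copying it.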
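import pandas as pd
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


def solve():
    # Load the Wage data: age is the predictor, wage is the response
    wage = pd.read_csv('https://liangdaima.com/static/data/statistics/Wage.csv')
    X = wage['age']
    y = wage['wage']
    for cuts in [2, 4, 6, 8]:
        # An integer bins argument splits the age range into `cuts` equal-width intervals
        X_cut = pd.cut(X, cuts)
        # One 0/1 float column per interval; together they play the role of the
        # intercept, so no constant column is added
        X_dummies = pd.get_dummies(X_cut, dtype=float)
        # 50/50 train/test split with a fixed seed so the MSE is reproducible
        X_train, X_test, y_train, y_test = train_test_split(X_dummies, y, test_size=0.5, random_state=0)
        # Fit the step function by ordinary least squares on the interval indicators
        model = sm.OLS(y_train, X_train).fit()
        pred = model.predict(X_test)
        # Mean squared error on the held-out half
        mse = mean_squared_error(y_test, pred)
        print(f'MSE with {cuts} intervals: {mse:.2f}')

Example

Input
solve()
Expected output
MSE with 2 intervals: 1778.57
MSE with 4 intervals: 1674.19
MSE with 6 intervals: 1673.36
MSE with 8 intervals: 1649.97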
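Each expected-output line is the held-out MSE of a step function. In the reference code, the integer passed to `pd.cut` is read by pandas as a number of equal-width intervals, so k intervals have k - 1 interior cut points. The sketch below is a minimal illustration on synthetic data: the ages, the wages (`y`), and the position-based 50/50 split are assumptions standing in for Wage.csv and `train_test_split`, and numpy's `lstsq` replaces `sm.OLS`. For each k it prints the interior cut points `pd.cut` chooses, and it checks that least squares on the full set of interval indicators, with no constant, predicts exactly the mean training response of each test point's interval.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the Wage data (assumption: ages and wages are made up)
rng = np.random.default_rng(1)
n = 400
age = pd.Series(rng.uniform(18, 80, size=n))
y = 50 + 0.8 * age.to_numpy() + rng.normal(0, 10, size=n)

# Simple 50/50 split by position, standing in for train_test_split (assumption)
train = np.arange(n) < n // 2
test = ~train

results = {}
for k in [2, 4, 6, 8]:
    # Integer bins: k equal-width intervals over all ages; retbins=True returns the edges
    binned, edges = pd.cut(age, k, retbins=True)
    codes = binned.cat.codes.to_numpy()  # interval number 0..k-1 for each observation

    # Route 1: least squares on one 0/1 column per interval, no separate intercept
    X = pd.get_dummies(binned, dtype=float).to_numpy()
    coef, *_ = np.linalg.lstsq(X[train], y[train], rcond=None)
    pred_ls = X[test] @ coef

    # Route 2: predict each test point with the mean training y of its interval
    bin_means = np.array([y[train & (codes == j)].mean() for j in range(k)])
    pred_mean = bin_means[codes[test]]

    # Held-out mean squared error, the quantity the exercise prints
    mse_ls = np.mean((y[test] - pred_ls) ** 2)
    mse_mean = np.mean((y[test] - pred_mean) ** 2)
    results[k] = (edges, coef, bin_means, mse_ls, mse_mean)
    print(f"{k} intervals, {k - 1} interior cut points {np.round(edges[1:-1], 1)}: "
          f"MSE {mse_ls:.2f} (interval means: {mse_mean:.2f})")
```

To run the same check on the exercise data, substitute `wage['age']`, `wage['wage']`, and a real train/test split. The interval-mean route only works when every interval has at least one training observation; an empty interval would produce a NaN mean.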