目录
1. 基础写法
1.1导包
2.2加载读取数据
2.3原始数据可视化(画图显示)
2.4线性回归的(基础)分解写法
2.5定义训练过程
2.PyTorch实现 线性回归的封装写法(实际项目中的常用写法)
2.1创建线性回归模型
2.2定义损失函数
2.3定义优化器
2.4定义训练过程
1. 基础写法
1.1导包
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
2.2加载读取数据¶
data = pd.read_csv('./dataset/Income1.csv')
data
#读取数据类型为dataframe类型
输出结果截图所示(部分数据)
data.head() #查看dataframe数据的前五条数据
data.tail() #后五条数据
data.Education.head() #查看数据的Education列的前五条数据 #是一个Series
0 10.000000 1 10.401338 2 10.842809 3 11.244147 4 11.645485 Name: Education, dtype: float64
data.Education[:5] #查看数据的Education列的前五条数据
0 10.000000 1 10.401338 2 10.842809 3 11.244147 4 11.645485 Name: Education, dtype: float64
2.3原始数据可视化(画图显示)
#画散点图,观察数据Education 与 Income 是否具有线性关系
plt.scatter(data.Education, data.Income)
plt.xlabel('Education')
plt.ylabel('Income')