欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 社会 > 自定义数据集 使用paddlepaddle框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

自定义数据集 使用paddlepaddle框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

2025/2/2 3:09:52 来源:https://blog.csdn.net/qq_63603839/article/details/145407192  浏览:    关键词:自定义数据集 使用paddlepaddle框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

代码:

# 导入必要的库
import numpy as np
import paddle
import paddle.nn as nn# 设置随机种子,确保实验可重复
seed = 1
paddle.seed(seed)# 数据集:一组二维数据,包含x和y的对应关系
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2],[-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6],[0.4, 34.0], [0.8, 62.3]]# 将数据转为NumPy数组格式
data = np.array(data)# 分离特征(x)和标签(y),x是输入,y是输出
x_data = data[:, 0]
y_data = data[:, 1]# 将数据转为Paddle张量类型,paddle.to_tensor用于转换
x_train = paddle.to_tensor(x_data, dtype=paddle.float32)
y_train = paddle.to_tensor(y_data, dtype=paddle.float32)# 定义线性回归模型,继承自paddle.nn.Layer
class LinearModel(nn.Layer):def __init__(self):# 初始化时,定义一个线性层(1个输入特征和1个输出特征)super(LinearModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):# 前向传播,输入x通过线性层计算输出x = self.linear(x)return x# 实例化模型对象
model = LinearModel()# 定义损失函数,这里使用均方误差(MSE)
criterion = paddle.nn.MSELoss()# 定义优化器,这里使用SGD(随机梯度下降),学习率设置为0.01
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())# 训练的迭代次数
epochs = 500
final_checkpoint = {}  # 保存最终训练结果# 训练循环,迭代epochs次
for epoch in range(1, epochs + 1):# 通过模型计算预测值,模型输入x_train需要增加一个维度(因为nn.Linear需要二维输入)y_prd = model(x_train.unsqueeze(1))# 计算损失,y_prd是预测值,y_train是实际值loss = criterion(y_prd.squeeze(1), y_train)# 清空梯度optimizer.clear_grad()# 反向传播计算梯度loss.backward()# 更新模型参数optimizer.step()# 每10个epoch输出一次损失if epoch % 10 == 0 or epoch == 1:print(f"epoch:{epoch},loss:{float(loss)}")# 在最后一个epoch保存模型的状态if epoch == epochs:final_checkpoint['epoch'] = epochfinal_checkpoint['loss'] = loss# 保存模型参数到文件,方便之后加载
paddle.save(model.state_dict(), './model.params')# 加载保存的模型参数
model.load_dict(paddle.load('./model.params'))
model.eval()  # 设置为评估模式(例如,关闭Dropout等)# 使用训练后的模型进行预测
x_test = paddle.to_tensor([[1.8]], dtype=paddle.float32)
y_test = model(x_test)# 打印预测结果
print(f'y_test:{y_test}')

结果:

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com