PyTorch 核心详解
目录
- PyTorch 核心详解
- 目录
- 1. PyTorch 简介
- 2. 张量(Tensor)操作
- 基本操作
- 常用函数
- 3. 自动微分(Autograd)
- 基本用法
- 禁用梯度跟踪
- 4. 神经网络模块(nn.Module)
- 定义模型
- 常用层
- 5. 数据加载与预处理
- 自定义数据集
- 数据增强
- 6. 模型训练与验证
- 训练流程
- 验证流程
- 7. 模型保存与加载
- 保存模型参数(推荐)
- 加载模型
- 8. GPU 加速
- 设备设置
1. PyTorch 简介
PyTorch 是一个基于 Python 的开源深度学习框架,由 Facebook 的 AI 研究团队开发。其核心特点包括:
- 动态计算图:实时构建计算图,适合调试和复杂模型。
- GPU 加速:无缝支持 CUDA,可高效利用 GPU。
- 模块化设计:通过
torch.nn
等模块快速搭建模型。
2. 张量(Tensor)操作
张量是 PyTorch 的核心数据结构,类似于 NumPy 的多维数组,但支持 GPU 加速。
基本操作
import torch# 创建张量
x = torch.tensor([[1, 2], [3, 4]]) # 从数据创建
y = torch.zeros(2, 3) # 全零张量
z = torch.rand(3, 3