欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 资讯 > PyTorch张量基础操作与数据表示完全指南

PyTorch张量基础操作与数据表示完全指南

2025/2/23 22:31:39 来源:https://blog.csdn.net/abc666_666/article/details/145646889  浏览:    关键词:PyTorch张量基础操作与数据表示完全指南

一、张量基础认知

1.1 什么是张量?

张量(Tensor)是PyTorch中最核心的数据结构,可以理解为多维数组的扩展:

  • 0维张量:标量(Scalar)

  • 1维张量:向量(Vector)

  • 2维张量:矩阵(Matrix)

  • 3维张量:时间序列数据

  • 4维张量:图像数据(批量大小×通道×高×宽)

  • 5维张量:视频数据

1.2 张量的关键属性

import torchtensor = torch.randn(3, 4)
print(f"Shape: {tensor.shape}")      # 张量维度
print(f"Data type: {tensor.dtype}")  # 数据类型
print(f"Device: {tensor.device}")    # 存储设备
print(f"Layout: {tensor.layout}")    # 内存布局
print(f"Requires grad: {tensor.requires_grad}")  # 梯度追踪

二、张量创建全解析

2.1 基础创建方法

# 从Python列表创建
data_tensor = torch.tensor([[1., 2], [3, 4]])# 使用NumPy转换
import numpy as np
numpy_array = np.array([[1, 2], [3, 4]])
tensor_from_np = torch.from_numpy(numpy_array)# 特殊初始化方法
zeros_tensor = torch.zeros(2, 3)         # 全零张量
ones_tensor = torch.ones(2, 3)           # 全一张量
eye_tensor = torch.eye(3)                # 单位矩阵
rand_tensor = torch.rand(2, 3)           # [0,1)均匀分布
randn_tensor = torch.randn(2, 3)         # 标准正态分布
linspace_tensor = torch.linspace(0, 10, 5)  # 线性空间

2.2 高级初始化技巧

# 内存共享初始化
shared_tensor = torch.as_tensor(numpy_array)  # 共享内存# 指定设备初始化
gpu_tensor = torch.rand(3, 3, device='cuda')# 类型转换初始化
int_tensor = torch.tensor([1, 2], dtype=torch.int32)# 克隆已有张量
clone_tensor = tensor.clone().detach()

三、张量操作大全

3.1 形状操作

tensor = torch.rand(2, 3, 4)# 改变形状
reshaped = tensor.view(2, 12)    # 保持内存连续性
flattened = tensor.flatten()     # 展平为1D
squeezed = tensor.squeeze()     # 去除大小为1的维度
unsqueezed = tensor.unsqueeze(0) # 增加维度# 维度交换
permuted = tensor.permute(2, 0, 1)
transposed = tensor.transpose(1, 2)

3.2 索引与切片

tensor = torch.rand(5, 4, 3)# 基础索引
print(tensor[0])               # 第一维度索引
print(tensor[:, 1:3, :])       # 多维切片
print(tensor[..., 0])          # 省略号语法# 高级索引
mask = tensor > 0.5
print(tensor[mask])            # 布尔掩码索引
indices = torch.tensor([0, 2])
print(tensor[:, indices, :])   # 整数数组索引

四、数学运算详解

4.1 逐元素运算

a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])# 基础运算
add = a + b            # 加法
sub = a - b            # 减法
mul = a * b            # 乘法
div = a / b            # 除法
pow = a ** 2           # 幂运算# 比较运算
eq = a == b            # 等于
gt = a > b             # 大于# 数学函数
sin = torch.sin(a)
log = torch.log(a)
exp = torch.exp(a)

4.2 矩阵运算

matrix1 = torch.rand(3, 4)
matrix2 = torch.rand(4, 5)# 矩阵乘法
matmul = torch.mm(matrix1, matrix2)        # 2D矩阵乘
bmm = torch.bmm(matrix1.unsqueeze(0), matrix2.unsqueeze(0))  # 批量矩阵乘# 降维操作
sum_all = matrix1.sum()          # 全局求和
sum_dim0 = matrix1.sum(dim=0)    # 沿维度0求和
mean_dim1 = matrix1.mean(dim=1)  # 沿维度1求平均

五、广播机制解析

5.1 广播规则

  • 从尾部维度开始对齐

  • 维度大小相等或其中一个为1

  • 缺失维度视为1

# 合法广播示例
a = torch.rand(3, 1, 4)
b = torch.rand(   2, 4)
result = a + b  # 最终形状:(3, 2, 4)# 非法广播示例
c = torch.rand(3, 4)
d = torch.rand(2, 4)
try:c + d  # 触发错误
except RuntimeError as e:print(e)

5.2 显式广播控制

# 手动扩展维度
a = torch.rand(3, 1)
b = torch.rand(1, 4)
a_expanded = a.expand(3, 4)
b_expanded = b.expand(3, 4)# 使用广播函数
a = torch.rand(3, 1)
b = torch.rand(4)
result = a + b.reshape(1, 4)

六、数据转换技巧

6.1 类型转换

tensor = torch.tensor([1, 2, 3])# 安全转换
float_tensor = tensor.float()        # 转为float32
double_tensor = tensor.double()     # 转为float64
int_tensor = tensor.short()         # 转为int16# 类型检查
if tensor.dtype != torch.float32:tensor = tensor.to(torch.float32)

6.2 与NumPy互操作

# Tensor转NumPy
tensor = torch.rand(3, 3)
numpy_array = tensor.numpy()        # CPU张量共享内存
safe_numpy = tensor.detach().cpu().numpy()  # 安全转换# NumPy转Tensor
numpy_data = np.random.rand(3, 3)
tensor_shared = torch.as_tensor(numpy_data)  # 共享内存
tensor_copy = torch.tensor(numpy_data)      # 数据拷贝

七、高级操作实战

7.1 条件选择

# torch.where三元选择
condition = torch.rand(3, 3) > 0.5
x = torch.ones(3, 3)
y = torch.zeros(3, 3)
result = torch.where(condition, x, y)# masked_select过滤
mask = torch.bernoulli(torch.ones(3,3)*0.5).bool()
selected = torch.masked_select(tensor, mask)

7.2 张量收集

# 创建示例数据
tensor = torch.arange(1, 13).view(3, 4)
indices = torch.tensor([[0, 1], [2, 3]])# 使用gather
dim1_result = torch.gather(tensor, 1, indices)  # 沿列收集# 使用index_select
selected_rows = torch.index_select(tensor, 0, torch.tensor([0, 2]))

八、GPU加速指南

8.1 设备管理

# 设备检测
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 张量迁移
cpu_tensor = torch.rand(3,3)
gpu_tensor = cpu_tensor.to(device)          # 转到GPU
back_cpu = gpu_tensor.cpu()                 # 转回CPU# 设备一致性检查
try:tensor1 = torch.rand(3,3).cuda()tensor2 = torch.rand(3,3).cpu()result = tensor1 + tensor2  # 触发错误
except RuntimeError as e:print("设备不一致错误:", e)

8.2 性能优化技巧

# 使用非阻塞传输
with torch.cuda.stream(torch.cuda.Stream()):gpu_tensor = cpu_tensor.to(device, non_blocking=True)# 固定内存加速传输
pinned_tensor = torch.rand(1024, 1024).pin_memory()# 混合精度训练
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():outputs = model(inputs)loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

九、最佳实践总结

9.1 常见错误与修复

# 维度不匹配
try:a = torch.rand(3,4)b = torch.rand(3,5)c = a + b
except RuntimeError as e:print("维度错误:", e)# 原地操作问题
a = torch.rand(3,3)
b = a.add_(1)  # 原地操作会影响梯度计算

9.2 性能优化建议

  1. 优先使用内置函数而非Python循环

  2. 合理使用广播减少内存占用

  3. 避免频繁的CPU-GPU数据传输

  4. 使用原地操作节省内存

  5. 选择合适的批处理大小

  6. 定期使用memory_profiler分析内存使用

# 内存分析示例
from torch import memory_profiler@memory_profiler.profile
def process_data():data = torch.randn(1000, 1000).cuda()result = data @ data.Treturn result.cpu()process_data()

十、常见问题解答

Q1:view和reshape有什么区别?

  • view要求内存连续,reshape会自动处理连续性

  • 优先使用reshape保证代码安全

Q2:如何克隆张量?

  • 使用tensor.clone().detach()同时复制数据和断开计算图

Q3:为什么需要detach()?

  • 阻断梯度传播,常用于冻结网络层

Q4:如何高效处理大型数据集?

  • 使用Dataset和DataLoader

  • 采用内存映射文件

  • 使用多进程加载

# Dataset示例
from torch.utils.data import Dataset, DataLoaderclass CustomDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return torch.tensor(self.data[idx])dataset = CustomDataset(np.random.rand(1000, 10))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

通过系统学习张量操作,您已经掌握了PyTorch的核心数据操作技能。接下来可以继续学习自动微分、神经网络模块等进阶内容。如果本文对您有帮助,欢迎点赞收藏!

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词