文章目录
- 张量拼接操作
- 1. torch.cat 函数的使用
- 1.1. torch.cat 定义
- 1.2. 语法
- 1.3. 关键规则
- 1.4. 示例代码
- 1.4.1. 沿行拼接(dim=0)
- 1.4.2. 沿列拼接(dim=1)
- 1.4.3. 高维拼接(dim=2)
- 1.5. 错误场景分析
- 1.5.1. 维度数不一致
- 1.5.2. 非拼接维度大小不匹配
- 1.5.3. 设备或数据类型不一致
- 1.6. 与 torch.stack 的区别
- 1.7. 高级用法
- 1.7.1. 批量拼接(Batch-wise Concatenation)
- 1.7.2. 自动广播支持
- 1.8. 总结
- 2. torch.stack 函数的使用
- 2.1. 函数定义
- 2.2. 核心规则
- 2.3. 使用示例
- 2.4. 与 torch.cat 的对比
- 2.4. 常见错误与调试
- 2.5. 工程实践技巧
- 2.7. 性能优化建议
- 2.8. 总结
张量拼接操作
1. torch.cat 函数的使用
在 PyTorch 中,torch.cat 是用于沿指定维度拼接多个张量的核心函数
1.1. torch.cat 定义
功能: 将多个张量沿指定维度(dim)拼接,生成新张量。
输入要求:
所有输入张量的 维度数必须相同。
非拼接维度的大小必须一致。
张量必须位于 同一设备 且 数据类型相同。
1.2. 语法
torch.cat(tensors, dim=0, *, out=None) → Tensor
参数:
tensors (sequence of Tensors):需拼接的张量序列(列表或元组)。
dim (int, optional):拼接的维度索引,默认为 0。
out (Tensor, optional):可选输出张量。
1.3. 关键规则
规则 | 示例 |
---|---|
输入张量维度数必须相同 | 不允许将 2D 张量与 3D 张量拼接 |
非拼接维度大小必须一致 | 若 dim=1,所有张量的 dim=0、dim=2 等大小必须相同 |
拼接维度大小可以不同 | 沿 dim=0 拼接形状为 (2, 3) 和 (3, 3) 的张量,结果为 (5, 3) |
输出维度数与输入相同 | 输入均为 3D 张量,输出仍为 3D 张量 |
1.4. 示例代码
1.4.1. 沿行拼接(dim=0)
import torchA = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
B = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
C = torch.cat([A, B], dim=0) # shape: (4, 2)
print(C)
# 输出:
# tensor([[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]])
1.4.2. 沿列拼接(dim=1)
D = torch.tensor([[9], [10]]) # shape: (2, 1)
E = torch.cat([A, D], dim=1) # shape: (2, 3)
print(E)
# 输出:
# tensor([[ 1, 2, 9],
# [ 3, 4, 10]])
1.4.3. 高维拼接(dim=2)
F = torch.randn(2, 3, 4) # shape: (2, 3, 4)
G = torch.randn(2, 3, 5) # shape: (2, 3, 5)
H = torch.cat([F, G], dim=2) # shape: (2, 3, 9)
1.5. 错误场景分析
1.5.1. 维度数不一致
A_2D = torch.randn(2, 3)
B_3D = torch.randn(2, 3, 4)
try:torch.cat([A_2D, B_3D], dim=0) # 报错:维度数不同
except RuntimeError as e:print("错误:", e)
1.5.2. 非拼接维度大小不匹配
A = torch.randn(2, 3)
B = torch.randn(3, 3) # dim=0 大小不同
try:torch.cat([A, B], dim=1) # 报错:非拼接维度大小不一致
except RuntimeError as e:print("错误:", e)
1.5.3. 设备或数据类型不一致
if torch.cuda.is_available():A_cpu = torch.randn(2, 3)B_gpu = torch.randn(2, 3).cuda()try:torch.cat([A_cpu, B_gpu], dim=0) # 报错:设备不一致except RuntimeError as e:print("错误:", e)
1.6. 与 torch.stack 的区别
函数 | 输入维度 | 输出维度 | 核心用途 |
---|---|---|---|
torch.cat | 所有张量维度相同 | 维度数与输入相同 | 沿现有维度扩展张量 |
torch.stack | 所有张量形状严格相同 | 新增一个维度 | 创建新维度合并张量 |
示例对比:
A = torch.tensor([1, 2]) # shape: (2)
B = torch.tensor([3, 4]) # shape: (2)# cat 沿 dim=0
C_cat = torch.cat([A, B]) # shape: (4)# stack 沿 dim=0
C_stack = torch.stack([A, B]) # shape: (2, 2)
1.7. 高级用法
1.7.1. 批量拼接(Batch-wise Concatenation)
# 批量数据拼接(batch_size=2)
batch_A = torch.randn(2, 3, 4) # shape: (2, 3, 4)
batch_B = torch.randn(2, 3, 5) # shape: (2, 3, 5)
batch_C = torch.cat([batch_A, batch_B], dim=2) # shape: (2, 3, 9)
1.7.2. 自动广播支持
torch.cat 不支持广播,必须显式匹配形状:
A = torch.randn(3, 1) # shape: (3, 1)
B = torch.randn(1, 3) # shape: (1, 3)
try:torch.cat([A, B], dim=1) # 报错:非拼接维度大小不一致
except RuntimeError as e:print("错误:", e)
1.8. 总结
适用场景:合并同维度的特征、批量数据拼接等。
核心规则:
1、输入张量维度数相同。2、非拼接维度大小严格一致。3、设备与数据类型一致。
优先使用 torch.cat:当需要在现有维度扩展时;需新增维度时选择 torch.stack。
2. torch.stack 函数的使用
2.1. 函数定义
torch.stack(tensors, dim=0, *, out=None) → Tensor
功能:将多个张量沿新维度堆叠(非拼接),要求所有输入张量形状严格相同。
- 输入:
- tensors (sequence of Tensors):形状相同的张量序列(列表/元组)。
- dim (int):新维度的插入位置(支持负数索引)。
- 输出:
- 比输入张量多一维的新张量。
2.2. 核心规则
规则 | 示例 |
---|---|
输入张量形状必须完全相同 | (3, 4) 只能与 (3, 4) 堆叠,不能与 (3, 5) 堆叠 |
输出维度 = 输入维度 + 1 | 输入(3, 4) → 输出 (n, 3, 4)(n为堆叠数量) |
新维度大小 = 张量数量 | 堆叠3个张量 → 新维度大小为3 |
设备/数据类型必须一致 | 所有张量需在同一设备(CPU/GPU)且 dtype 相同 |
2.3. 使用示例
(1) 基础用法
import torch
# 定义两个相同形状的张量
A = torch.tensor([1, 2, 3]) # shape: (3,)
B = torch.tensor([4, 5, 6]) # shape: (3,)# 沿新维度0堆叠
C = torch.stack([A, B]) # shape: (2, 3)
print(C)
# tensor([[1, 2, 3],
# [4, 5, 6]])# 沿新维度1堆叠
D = torch.stack([A, B], dim=1) # shape: (3, 2)
print(D)
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
(2) 高维张量堆叠
# 形状为 (2, 3) 的张量
X = torch.randn(2, 3)
Y = torch.randn(2, 3)# 沿dim=0堆叠(新增最外层维度)
Z0 = torch.stack([X, Y]) # shape: (2, 2, 3)# 沿dim=1堆叠(插入到第二维)
Z1 = torch.stack([X, Y], dim=1) # shape: (2, 2, 3)# 沿dim=-1堆叠(插入到最后一维)
Z2 = torch.stack([X, Y], dim=-1) # shape: (2, 3, 2)
(3) 批量数据构建
# 模拟批量图像数据(单张图像shape: (3, 32, 32))
image1 = torch.randn(3, 32, 32)
image2 = torch.randn(3, 32, 32)
image3 = torch.randn(3, 32, 32)# 构建batch维度(batch_size=3)
batch = torch.stack([image1, image2, image3]) # shape: (3, 3, 32, 32)
2.4. 与 torch.cat 的对比
特性 torch.stack torch.cat
输入要求 所有张量形状严格相同 仅需非拼接维度相同
输出维度 比输入多1维 与输入维度相同
内存开销 更高(新增维度) 更低(复用现有维度)
典型场景 构建batch、新增序列维度 合并特征、扩展现有维度
示例对比:
A = torch.tensor([1, 2])
B = torch.tensor([3, 4])# stack -> 新增维度
stacked = torch.stack([A, B]) # shape: (2, 2)# cat -> 沿现有维度扩展
concatenated = torch.cat([A, B]) # shape: (4)
2.4. 常见错误与调试
(1) 形状不匹配
A = torch.randn(2, 3)
B = torch.randn(2, 4) # 第二维不同
try:torch.stack([A, B])
except RuntimeError as e:print("Error:", e) # Sizes of tensors must match
(2) 设备不一致
A_cpu = torch.randn(3, 4)
B_gpu = torch.randn(3, 4).cuda()
try:torch.stack([A_cpu, B_gpu])
except RuntimeError as e:print("Error:", e) # Expected all tensors to be on the same device
(3) 空张量处理
empty_tensors = [torch.tensor([]) for _ in range(3)]
try:torch.stack(empty_tensors) # 可能引发未定义行为
except RuntimeError as e:print("Error:", e)
2.5. 工程实践技巧
(1) 批量数据预处理
# 从数据加载器中逐批读取数据并堆叠
batch_images = []
for image in dataloader:batch_images.append(image)if len(batch_images) == batch_size:batch = torch.stack(batch_images) # shape: (batch_size, C, H, W)process_batch(batch)batch_images = []
(2) 序列建模中的时间步堆叠
# RNN输入序列构建(T个时间步,每个步长特征dim=D)
time_steps = [torch.randn(1, D) for _ in range(T)]
input_seq = torch.stack(time_steps, dim=1) # shape: (1, T, D)
(3) 多任务输出合并
# 多任务学习中的输出堆叠
task1_out = torch.randn(batch_size, 10)
task2_out = torch.randn(batch_size, 5)
multi_out = torch.stack([task1_out, task2_out], dim=1) # shape: (batch_size, 2, ...)
2.7. 性能优化建议
避免循环中频繁堆叠:优先在内存中收集所有张量后一次性堆叠。
# 低效做法
result = None
for x in data_stream:if result is None:result = x.unsqueeze(0)else:result = torch.stack([result, x.unsqueeze(0)])# 高效做法
tensor_list = [x for x in data_stream]
result = torch.stack(tensor_list)
显存不足时考虑分块处理:
chunk_size = 1000
for i in range(0, len(big_list), chunk_size):chunk = torch.stack(big_list[i:i+chunk_size])process(chunk)
2.8. 总结
核心用途:构建batch、新增维度、多任务输出整合。
关键检查点:
- 输入张量形状完全一致。
- 设备与数据类型统一。
- 合理选择 dim 参数控制维度扩展位置。
优先选择场景:当需要显式创建新维度时使用;若仅需扩展现有维度,用 torch.cat 更高效。