TensorDataset
是 PyTorch 提供的一个工具类,用于将多个张量(Tensors)打包成一个数据集(Dataset),便于配合 DataLoader
进行批量加载和数据管理。
一、概念
作用:将多个张量(如特征张量、标签张量)按样本对齐,合并为一个数据集对象。
适用场景:监督学习任务、多输入、多输出模型、简化数据加载流程,兼容DataLoader的批处理、打乱数据等操作
二、使用步骤
1.导入库
import torch
from torch.utils.data import TensorDataset, DataLoader
2.准备数据
创建一些特征张量和标签,数据类型可以是numpy或list列表
特征张量:(样本数,特征维度)
标签张量:(样本数,)
# 示例数据(100个样本,每个样本5个特征)
features = torch.randn(100, 5) # 随机生成特征
labels = torch.randint(0, 2, (100,)) # 二分类标签(0或1)
3。创建TensorDataser,将特征张量和标签合并成一个大张量
dataset = TensorDataset(features, labels)
4.使用DataLoader加载数据,将合并好的张量按照指定大小来切分批次
dataloader = DataLoader(dataset,batch_size=32,shuffle=True, # 训练时打乱数据num_workers=2 # 多进程加载数据(可选)
)# 遍历批次
for batch_features, batch_labels in dataloader:# 在此处执行模型的前向传播、损失计算等操作print("Batch features shape:", batch_features.shape)print("Batch labels shape:", batch_labels.shape)
三、注意事项
- 张量对齐:所有传入
TensorDataset
的张量的第一个维度(样本数)必须相同。 - 数据类型:确保张量的数据类型(
dtype
)与模型输入要求一致(如float32
或int64
)。 - 设备位置:若使用 GPU,需将张量放在 GPU 上(通过
tensor.to(device)
),或让DataLoader
自动处理。 - 内存限制:数据量过大时,优先使用
Dataset
的子类(如IterableDataset
)动态加载数据,避免内存溢出。
四、完整代码
import torch
from torch.utils.data import TensorDataset, DataLoader# 1. 生成模拟数据
num_samples = 1000
features = torch.randn(num_samples, 10) # 10维特征
labels = torch.randn(num_samples) # 连续值标签# 2. 创建 TensorDataset
dataset = TensorDataset(features, labels)# 3. 定义 DataLoader
dataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4
)# 4. 遍历数据
for batch_idx, (batch_features, batch_labels) in enumerate(dataloader):print(f"Batch {batch_idx}:")print(" Features shape:", batch_features.shape) # [64, 10]print(" Labels shape: ", batch_labels.shape) # [64]