文章目录
- 1. 准备数据
- 2. 创建自定义 `Dataset` 类
- 3. 实例化数据集对象
- 4. 使用 `DataLoader` 加载数据
- 5. 迭代数据集
- 6. 预处理和数据增强(可选)
- 7. 多线程加载(可选)
自定义数据集的加载在
PyTorch
中主要为以下几个步骤:
1. 准备数据
- 从文件中读取数据
- 对数据进行预处理
- 给数据打上标签label
- 合并数据(根据实际而定)
- 划分训练集和验证集(可使用
sklearn.model_selection
模块的train_test_split
函数)
2. 创建自定义 Dataset
类
创建一个继承自 torch.utils.data.Dataset
的自定义类,主要作用:
- 封装数据:
ECGDataset
类封装了数据和标签,使得它们可以作为一个整体被处理。 - 提供数据访问接口:通过实现
__getitem__
方法,ECGDataset
类提供了一个标准化的方式来访问数据集中的单个样本。 - 与
DataLoader
协同工作:Dataset
类与 PyTorch 的DataLoader
类紧密集成,DataLoader
可以利用Dataset
类提供的方法来实现批量加载、打乱数据、多线程加载等功能。
这个类需要实现两个方法:__len__
和 __getitem__
。
__len__
方法返回数据集中样本的数量。__getitem__
方法根据索引返回数据集中的一个样本。
如果有一个心电(ECG)数据集,自定义 Dataset
类可以如下:
from torch.utils.data import Datasetclass ECGDataset(Dataset):def __init__(self, ecg_data, labels):self.ecg_data = ecg_dataself.labels = labels# 计算样本的数量def __len__(self):return len(self.labels)def __getitem__(self, idx):ecg_sample = self.ecg_data[idx]label = self.labels[idx]#最后返回读取到的数据,记住返回一定要是tensor的形式return ecg_sample, label
或者是以这种方式(加入transform参数):
from torch.utils.data import Datasetclass ECGDataset(Dataset):def __init__(self, ecg_data, labels,transform=None):self.ecg_data = ecg_dataself.labels = labelsself.transform = transform# 计算样本的数量def __len__(self):return len(self.labels)def __getitem__(self, idx):ecg_sample = self.ecg_data[idx]label = self.labels[idx]if self.transform:ecg_sample = self.transform(ecg_sample)#最后返回读取到的数据,记住返回一定要是tensor的形式return ecg_sample, label
ps
: transforms.Compose
通常用于图像数据的预处理,如调整大小、裁剪、翻转和归一化等操作。然而,心电信号(ECG)是一维时间序列数据,不是二维图像数据,因此不能直接应用上述为图像设计的 transforms
。
对于心电信号,我们通常会采用不同的预处理方法,一些适用于心电信号的常见预处理 transforms
:
为心电信号定义一个简单的预处理流程:
import numpy as np
from scipy.signal import butter, filtfilt# 定义心电信号预处理的 transforms
class ECGTransform:def __init__(self, sample_rate, lowcut, highcut, filter_order, segment_length):self.sample_rate = sample_rateself.lowcut = lowcutself.highcut = highcutself.filter_order = filter_orderself.segment_length = segment_lengthdef bandpass_filter(self, ecg_signal):# 定义带通滤波器参数nyq = 0.5 * self.sample_ratelow = self.lowcut / nyqhigh = self.highcut / nyqb, a = butter(self.filter_order, [low, high], btype='band')# 应用滤波器filtered_signal = filtfilt(b, a, ecg_signal)return filtered_signaldef standardize(self, ecg_signal):# 标准化信号return (ecg_signal - np.mean(ecg_signal)) / np.std(ecg_signal)def segment_signal(self, ecg_signal):# 将信号分割成固定长度的片段segments = []for start in range(0, len(ecg_signal) - self.segment_length, self.segment_length):segment = ecg_signal[start:start + self.segment_length]segments.append(segment)return np.array(segments)# 使用预处理 transforms
transform = ECGTransform(sample_rate=250, lowcut=0.5, highcut=15.0, filter_order=5, segment_length=5000)
ecg_signal = ... # 加载心电信号数据
filtered_ecg = transform.bandpass_filter(ecg_signal)
standardized_ecg = transform.standardize(filtered_ecg)
segmented_ecg = transform.segment_signal(standardized_ecg)
在这个示例中,我们创建了一个 ECGTransform
类,它包含带通滤波、标准化和信号分割的方法。
3. 实例化数据集对象
使用你的数据(特征和标签)来创建 Dataset
类的实例,来创建数据集对象。
# 假设 x 是特征数据,y 是标签数据
ecg_dataset = ECGDataset(x, y)
4. 使用 DataLoader
加载数据
使用 torch.utils.data.DataLoader
来包装你的数据集对象,创建数据加载器。DataLoader
可以提供额外的功能,如自动打乱数据、批量加载、多线程加载等。
from torch.utils.data import DataLoader# 创建 DataLoader 实例
data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True)
在这个例子中,batch_size=32
表示每次迭代返回 32 个样本的批次,shuffle=True
表示在每个 epoch 开始时打乱数据。
5. 迭代数据集
在你的训练或验证循环中,你可以迭代 DataLoader
实例来获取数据。
for epoch in range(num_epochs):for batch_idx, (ecg_samples, labels) in enumerate(data_loader):pass
在这个循环中,ecg_samples
和 labels
是从数据集中加载的批次数据和标签。
6. 预处理和数据增强(可选)
在自定义 Dataset
类中,你可以添加任何特定的预处理或数据增强步骤。这些步骤将在 __getitem__
方法中执行,确保每个样本在返回之前都经过了适当的处理。
7. 多线程加载(可选)
DataLoader
还支持多线程加载数据,可以通过设置 num_workers
参数来实现。
data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True, num_workers=4)
num_workers=4
表示使用 4 个进程来加载数据,这可以显著提高数据加载的效率。