欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 社会 > Pytorch:Dataset的加载

Pytorch:Dataset的加载

2025/3/17 13:53:14 来源:https://blog.csdn.net/nisidjndn/article/details/146290865  浏览:    关键词:Pytorch:Dataset的加载

文章目录

      • 1. 准备数据
      • 2. 创建自定义 `Dataset` 类
      • 3. 实例化数据集对象
      • 4. 使用 `DataLoader` 加载数据
      • 5. 迭代数据集
      • 6. 预处理和数据增强(可选)
      • 7. 多线程加载(可选)

自定义数据集的加载在 PyTorch 中主要为以下几个步骤:

1. 准备数据

  1. 从文件中读取数据
  2. 对数据进行预处理
  3. 给数据打上标签label
  4. 合并数据(根据实际而定)
  5. 划分训练集和验证集(可使用sklearn.model_selection模块的train_test_split函数)

2. 创建自定义 Dataset

创建一个继承自 torch.utils.data.Dataset 的自定义类,主要作用:

  1. 封装数据ECGDataset 类封装了数据和标签,使得它们可以作为一个整体被处理。
  2. 提供数据访问接口:通过实现 __getitem__ 方法,ECGDataset 类提供了一个标准化的方式来访问数据集中的单个样本。
  3. DataLoader 协同工作Dataset 类与 PyTorch 的 DataLoader 类紧密集成,DataLoader 可以利用 Dataset 类提供的方法来实现批量加载、打乱数据、多线程加载等功能。

这个类需要实现两个方法:__len____getitem__

  • __len__ 方法返回数据集中样本的数量。
  • __getitem__ 方法根据索引返回数据集中的一个样本。

如果有一个心电(ECG)数据集,自定义 Dataset 类可以如下:

from torch.utils.data import Datasetclass ECGDataset(Dataset):def __init__(self, ecg_data, labels):self.ecg_data = ecg_dataself.labels = labels# 计算样本的数量def __len__(self):return len(self.labels)def __getitem__(self, idx):ecg_sample = self.ecg_data[idx]label = self.labels[idx]#最后返回读取到的数据,记住返回一定要是tensor的形式return ecg_sample, label

或者是以这种方式(加入transform参数):

from torch.utils.data import Datasetclass ECGDataset(Dataset):def __init__(self, ecg_data, labels,transform=None):self.ecg_data = ecg_dataself.labels = labelsself.transform = transform# 计算样本的数量def __len__(self):return len(self.labels)def __getitem__(self, idx):ecg_sample = self.ecg_data[idx]label = self.labels[idx]if self.transform:ecg_sample = self.transform(ecg_sample)#最后返回读取到的数据,记住返回一定要是tensor的形式return ecg_sample, label

pstransforms.Compose 通常用于图像数据的预处理,如调整大小、裁剪、翻转和归一化等操作。然而,心电信号(ECG)是一维时间序列数据,不是二维图像数据,因此不能直接应用上述为图像设计的 transforms

对于心电信号,我们通常会采用不同的预处理方法,一些适用于心电信号的常见预处理 transforms

为心电信号定义一个简单的预处理流程:

import numpy as np
from scipy.signal import butter, filtfilt# 定义心电信号预处理的 transforms
class ECGTransform:def __init__(self, sample_rate, lowcut, highcut, filter_order, segment_length):self.sample_rate = sample_rateself.lowcut = lowcutself.highcut = highcutself.filter_order = filter_orderself.segment_length = segment_lengthdef bandpass_filter(self, ecg_signal):# 定义带通滤波器参数nyq = 0.5 * self.sample_ratelow = self.lowcut / nyqhigh = self.highcut / nyqb, a = butter(self.filter_order, [low, high], btype='band')# 应用滤波器filtered_signal = filtfilt(b, a, ecg_signal)return filtered_signaldef standardize(self, ecg_signal):# 标准化信号return (ecg_signal - np.mean(ecg_signal)) / np.std(ecg_signal)def segment_signal(self, ecg_signal):# 将信号分割成固定长度的片段segments = []for start in range(0, len(ecg_signal) - self.segment_length, self.segment_length):segment = ecg_signal[start:start + self.segment_length]segments.append(segment)return np.array(segments)# 使用预处理 transforms
transform = ECGTransform(sample_rate=250, lowcut=0.5, highcut=15.0, filter_order=5, segment_length=5000)
ecg_signal = ...  # 加载心电信号数据
filtered_ecg = transform.bandpass_filter(ecg_signal)
standardized_ecg = transform.standardize(filtered_ecg)
segmented_ecg = transform.segment_signal(standardized_ecg)

在这个示例中,我们创建了一个 ECGTransform 类,它包含带通滤波、标准化和信号分割的方法。

3. 实例化数据集对象

使用你的数据(特征和标签)来创建 Dataset 类的实例,来创建数据集对象。

# 假设 x 是特征数据,y 是标签数据
ecg_dataset = ECGDataset(x, y)

4. 使用 DataLoader 加载数据

使用 torch.utils.data.DataLoader 来包装你的数据集对象,创建数据加载器。DataLoader 可以提供额外的功能,如自动打乱数据、批量加载、多线程加载等。

from torch.utils.data import DataLoader# 创建 DataLoader 实例
data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True)

在这个例子中,batch_size=32 表示每次迭代返回 32 个样本的批次,shuffle=True 表示在每个 epoch 开始时打乱数据。

5. 迭代数据集

在你的训练或验证循环中,你可以迭代 DataLoader 实例来获取数据。

for epoch in range(num_epochs):for batch_idx, (ecg_samples, labels) in enumerate(data_loader):pass

在这个循环中,ecg_sampleslabels 是从数据集中加载的批次数据和标签。

6. 预处理和数据增强(可选)

在自定义 Dataset 类中,你可以添加任何特定的预处理或数据增强步骤。这些步骤将在 __getitem__ 方法中执行,确保每个样本在返回之前都经过了适当的处理。

7. 多线程加载(可选)

DataLoader 还支持多线程加载数据,可以通过设置 num_workers 参数来实现。

data_loader = DataLoader(ecg_dataset, batch_size=32, shuffle=True, num_workers=4)

num_workers=4 表示使用 4 个进程来加载数据,这可以显著提高数据加载的效率。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词