欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > digit_eye开发记录(2): Python读取MNIST数据集

digit_eye开发记录(2): Python读取MNIST数据集

2024/11/30 11:18:20 来源:https://blog.csdn.net/baiyu33/article/details/144097058  浏览:    关键词:digit_eye开发记录(2): Python读取MNIST数据集

在上一篇博客 digit_eye开发记录(1): C++读取MNIST数据集 中解读了 IDX 文件格式,并使用 C++ 语言完成了 MNIST 数据集的解析,第6小节给出的完整代码有146行之多。使用 Python 读取则可以省略70%的代码,只用不到50行代码完成相同功能。

读取 buffer

np.frombuffer(buf, dtype, count, offset)

说明:

  • buf: buffer,从文件读出来的
  • dtype: 从buf读取时,按什么类型读取数据,或者说,读取的基本单位是什么
  • count: 从buf读取时,读取多少个基本单位
  • offset: 从buf读取时,指针首先偏移多少个字节

读取 magic number

magic number 是 mnist 文件的前4个字节。 以二进制形式打开后,读取4字节即可:

import numpy as npwith open(filename, 'rb') as fin:buf = bytearray(fin.read())
magic = np.frombuffer(buf, np.uint8, count=4)
print(magic)

读取维度信息

回忆一下 magic numbers 的构成: 前两个字节是0,第三个字节是类型,第四个字节是维度数量 num_dims。
根据 num_dims 的取值,读取对应数量的字节,得到对应的维度信息。每个维度都是一个 int32 大小。

注意 MSB 到 LSB 的转换,通过 dtype=np.dtype('>u4') 指定, >u4 意思是:以MSB序,读取4个byte.

对于图像数据:

num_dims = magic[3]
dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)

对于label数据:

dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_labels = dims[0]

读取图像像素

很容易想到使用 OOP 方式,定义 DataSet 类,在成员 self.images 中保存图像;于是乎,很“毛躁”的写出如下糟糕代码:

class DataSet:def __init__(self):self.images = []self.labels = []def load_images(self, filename):...for i in range(num_images):self.images.append(...)

存在的问题:

  • self.images 的类型一定是 list 吗?其实可以是 numpy 数组
  • self.images 的每个元素,和其他元素,一定是独立的吗? 可以是同一个内存上连续的分布
  • self.images 的每个元素,内存可以和读取文件得到的 buffer 复用吗?可以!
class DataSet:def __init__(self):self.images = Noneself.labels = Noneself.buf = Nonedef load_images(self, filename):with open(filename, 'rb') as fin:self.buf = bytearray(fin.read())magic = np.frombuffer(self.buf, np.uint8, count=4)num_dims = magic[3]dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)num_images, rows, cols = dimsself.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)...train_set = DataSet()
train_set.load_images('data/train-images.idx3-ubyte')
print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))

解释:self.buf 的类型,如果直接用 fin.read() 则得到 bytes 类型,是不可变的;转为 bytearray 类型后,是可变的,就可以保持和 self.images( ) 共享。

遗憾的是, self.buf = bytearray(fin.read()) 这句本身就发生了内存拷贝。

改进 - 避免内存拷贝

with open(filename, 'rb') as fin:  self.buf = bytearray(fin.read())  # 当前实现,存在两次内存分配  

改为

with open(filename, 'rb') as fin:  self.buf = fin.read()  # 读取为 bytes  self.buf = memoryview(self.buf)  # 直接使用 memoryview  

就可以避免 bytes 对象的中间拷贝过程。

完整代码

import numpy as np
import cv2class DataSet:def __init__(self):self.images = Noneself.labels = Noneself.buf = Nonedef load_images(self, filename):with open(filename, 'rb') as fin:#self.buf = bytearray(fin.read())self.buf = fin.read()self.buf = memoryview(self.buf)magic = np.frombuffer(self.buf, np.uint8, count=4)num_dims = magic[3]dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)num_images, rows, cols = dimsself.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)def load_labels(self, filename):with open(filename, 'rb') as fin:buf = fin.read()magic = np.frombuffer(buf, np.uint8, count=4)num_dims = magic[3]dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)num_labels = dims[0]assert num_labels == len(self.images)self.labels = np.frombuffer(buf, dtype=np.uint8, offset=4+4*num_dims)def show_image(self, index):cv2.imshow('image', self.images[index])print('label:', self.labels[index])cv2.waitKey(0)cv2.destroyAllWindows()def main():train_set = DataSet()train_set.load_images('data/train-images.idx3-ubyte')train_set.load_labels('data/train-labels.idx1-ubyte')# train_set.show_image(0)# train_set.show_image(2)# train_set.show_image(5)print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))if __name__ == '__main__':main()

总结

在前一篇,我们解析了MNIST数据集的IDX格式并用C++做了文件读取的实现,在本篇则切换到 Python 语言,在降低70%代码量的情况下实现了相同功能,并且避免了不必要的内存拷贝。这份工程之美,建立在对 IDX 格式有所了解的前提之下,对于 Python 的熟悉也是必不可少的,对于C++的经验也促使了复用内存这一条件的达成。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com