62 lines
2.2 KiB
Python
Executable File
62 lines
2.2 KiB
Python
Executable File
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import gzip
|
||
|
||
|
||
def read_idx3(filename):
    """Read the image part of a gzip-compressed IDX3 dataset (e.g. MNIST).

    The file starts with a header of four big-endian 32-bit signed integers:
    magic number, number of images, number of rows, number of columns.
    The pixel data follows immediately, one unsigned byte per pixel.

    :param filename: path of the file; its extension is '.gz'.
    :return: tuple ``(images, num_images)``. ``images`` is a read-only uint8
        array of shape ``(num_images, num_rows, num_cols)``.
    :raises ValueError: if the magic number is not 2051 (i.e. the file is not
        an unsigned-byte, 3-dimensional IDX file), or if the file is shorter
        than its header claims.
    """
    with gzip.open(filename) as fo:
        print('Reading images...')
        buf = fo.read()

    offset = 0  # byte offset into the decompressed buffer
    # '>i' = big-endian ('>') 32-bit signed integer ('i'); read 4 of them
    # starting at `offset`. This is the file header (metadata).
    header = np.frombuffer(buf, dtype='>i', count=4, offset=offset)
    print(header)
    # Convert to Python ints so the element-count product below cannot
    # overflow int32 on unusually large files.
    magic_number, num_images, num_rows, num_cols = (int(v) for v in header)
    print("\tmagic number: {}, number of images: {}, number of rows: {}, number of columns: {}"
          .format(magic_number, num_images, num_rows, num_cols))

    # The magic number encodes the data type and rank: 0x00000803 (2051)
    # means unsigned-byte data with 3 dimensions. Anything else (e.g. a
    # label file, 2049) would be silently misparsed, so reject it.
    if magic_number != 2051:
        raise ValueError(
            "Invalid IDX3 magic number {} in {!r} (expected 2051); "
            "is this an image file?".format(magic_number, filename))

    # Skip past the header: element count times bytes per element (16 bytes).
    offset += header.size * header.itemsize

    # '>B' = unsigned byte (uint8), one per pixel. np.frombuffer raises
    # ValueError if the buffer holds fewer bytes than requested.
    data = np.frombuffer(buf, '>B', num_images * num_rows * num_cols, offset)
    # Rebuild the flat pixel stream as (images, rows, cols).
    data = data.reshape((num_images, num_rows, num_cols))

    return data, num_images
|
||
|
||
|
||
def read_idx1(filename):
    """Read the label part of a gzip-compressed IDX1 dataset (e.g. MNIST).

    The file starts with a header of two big-endian 32-bit signed integers:
    magic number and number of labels. The labels follow immediately, one
    unsigned byte each.

    :param filename: path of the file; its extension is '.gz'.
    :return: read-only uint8 array of shape ``(num_labels,)``.
    :raises ValueError: if the magic number is not 2049 (i.e. the file is not
        an unsigned-byte, 1-dimensional IDX file), or if the file is shorter
        than its header claims.
    """
    with gzip.open(filename) as fo:
        print('Reading labels...')
        buf = fo.read()

    offset = 0  # byte offset into the decompressed buffer
    # '>i' = big-endian 32-bit signed integer; read the 2-integer header.
    header = np.frombuffer(buf, '>i', 2, offset)
    # Python ints keep later arithmetic free of int32 overflow.
    magic_number, num_labels = (int(v) for v in header)
    print("\tmagic number: {}, number of labels: {}"
          .format(magic_number, num_labels))

    # 0x00000801 (2049) = unsigned-byte data with 1 dimension. Reject other
    # files (e.g. an image file, 2051) instead of returning misparsed bytes.
    if magic_number != 2049:
        raise ValueError(
            "Invalid IDX1 magic number {} in {!r} (expected 2049); "
            "is this a label file?".format(magic_number, filename))

    # Skip past the header (8 bytes).
    offset += header.size * header.itemsize

    # '>B' = unsigned byte (uint8), one per label. np.frombuffer raises
    # ValueError if the buffer holds fewer bytes than requested.
    data = np.frombuffer(buf, '>B', num_labels, offset)
    return data