MNIST数据集获取

MNIST数据集是入门机器学习/模式识别的最经典数据集之一。最早于1998年Yan Lecun在论文:

中提出。经典的LeNet-5 CNN网络也是在该论文中提出的。
数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。每张图片中像素值大小在0-255之间,其中0是黑色背景,255是白色前景。如下图所示:

MNIST共包含70000张手写数字图片,其中有60000张用作训练集,10000张用作测试集。原始数据集可在MNIST官网下载。

下载之后得到4个压缩文件:

train-images-idx3-ubyte.gz #60000张训练集图片
train-labels-idx1-ubyte.gz #60000张训练集图片对应的标签
t10k-images-idx3-ubyte.gz #10000张测试集图片
t10k-labels-idx1-ubyte.gz #10000张测试集图片对应的标签

将其解压,得到:

train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte

MNIST二进制文件的存储格式

解压得到的四个文件都是二进制格式,我们如何获取其中的信息呢?这得首先了解MNIST二进制文件的存储格式(官网底部有介绍),以训练集图像文件train-images-idx3-ubyte为例:

图像文件的

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;
  • 第5-8个byte存的是number of images,即图像数量60000;
  • 第9-12个byte存的是每张图片行数/高度,即28;
  • 第13-16个byte存的是每张图片的列数/宽度,即28。
  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

因为train-images-idx3-ubyte文件总共包含了60000张图片数据,按照以上的存储方式,我们算一下该文件的大小:

  • 一张图片包含28x28=784个像素点,需要784bytes的存储空间;
  • 60000张图片则需要784x60000=47040000 bytes的存储空间;
  • 此外,文件开始处使用了16个bytes用于存储magic number、图像数量、图像高度和图像宽度,因此,训练集图像文件的大小应该是47040000+16=47040016 bytes。

我们查看解压后的train-images-idx3-ubyte文件的属性:

文件实际大小和我们计算的结果一致。

类似地,我们查看训练集标签文件train-labels-idx1-ubyte的存储格式:

和图像文件类似:

  • 第1-4个byte存的是文件的magic number,对应的十进制大小是2049;
  • 第5-8个byte存的是number of items,即label数量60000;
  • 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

计算一下训练集标签文件train-labels-idx1-ubyte的文件大小:

  • 1x60000+8=60008 bytes。

与该文件实际的大小一致:

另外两个文件,即测试集图像文件、测试集标签文件的存储方式和训练图像文件、训练标签文件相似,只是图像数量由60000变为10000。

使用python访问MNIST数据集文件内容

知道了MNIST二进制文件的存储方式,下面介绍如何使用python访问文件内容。同样以训练集图像文件train-images-idx3-ubyte为例:

import numpy as np
import matplotlib.pyplot as plt

'''试验transpose()
def back (a,b):
    return a,b

if __name__ == '__main__':
    a = np.array([[1,2,3],[11,12,13],[21,22,23]])
    print(a)
    b = np.array([[31,32,33],[41,42,43],[51,52,53]])
    print(b)
    a, b = transpose(back(a,b))
    #a, b = back(a, b)
    print(a)
    print(b)
'''

# 数据加载器基类
class Loader(object):
    def __init__(self, path, count):
        '''
        初始化加载器
        path: 数据文件路径
        count: 文件中的样本个数
        '''
        self.path = path
        self.count = count

    def get_file_content(self):
        '''
        读取文件内容
        '''
        f = open(self.path, 'rb')
        content = f.read()
        f.close()
        return content

    def to_int(self, byte):
        '''
        将unsigned byte字符转换为整数
        '''
        #print(byte)
        #return struct.unpack('B', byte)[0]
        return byte

# 图像数据加载器
class ImageLoader(Loader):
    def get_picture(self, content, index):
        '''
        内部函数,从文件中获取图像
        '''
        start = index * 28 * 28 + 16
        picture = []
        for i in range(28):
            picture.append([])
            for j in range(28):
                picture[i].append(
                    self.to_int(content[start + i * 28 + j]))
        return picture

    def get_one_sample(self, picture):
        '''
        内部函数,将图像转化为样本的输入向量
        '''
        sample = []
        for i in range(28):
            for j in range(28):
                sample.append(picture[i][j])
        return sample

    def load(self):
        '''
        加载数据文件,获得全部样本的输入向量
        '''
        content = self.get_file_content()
        data_set = []
        for index in range(self.count):
            data_set.append(
                self.get_one_sample(
                    self.get_picture(content, index)))
        return data_set

# 标签数据加载器
class LabelLoader(Loader):
    def load(self):
        '''
        加载数据文件,获得全部样本的标签向量
        '''
        content = self.get_file_content()
        labels = []
        for index in range(self.count):
            labels.append(self.norm(content[index + 8]))
        return labels

    def norm(self, label):
        '''
        内部函数,将一个值转换为10维标签向量
        '''
        label_vec = []
        label_value = self.to_int(label)
        for i in range(10):
            if i == label_value:
                label_vec.append(0.9)
            else:
                label_vec.append(0.1)
        return label_vec

def get_training_data_set():
    '''
    获得训练数据集
    '''
    image_loader = ImageLoader('train-images.idx3-ubyte', 60000)
    label_loader = LabelLoader('train-labels.idx1-ubyte', 60000)
    return image_loader.load(), label_loader.load()

def get_test_data_set():
    '''
    获得测试数据集
    '''
    image_loader = ImageLoader('t10k-images.idx3-ubyte', 10000)
    label_loader = LabelLoader('t10k-labels.idx1-ubyte', 10000)
    return image_loader.load(), label_loader.load()

if __name__ == '__main__':
    train_data_set, train_labels = get_training_data_set()
    line = np.array(train_data_set[0])
    img = line.reshape((28,28))
    plt.imshow(img)
    plt.show()

输出图片如下:

参考:

https://www.jianshu.com/p/e7c286530ab9

Python读取MNIST数据集的更多相关文章

  1. python读取mnist

    python读取mnist 其实就是python怎么读取binnary file mnist的结构如下,选取train-images TRAINING SET IMAGE FILE (train-im ...

  2. mnist的格式说明,以及在python3.x和python 2.x读取mnist数据集的不同

    有一个关于mnist的一个事例可以参考,我觉得写的很好:http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html #!/usr/bin/ ...

  3. python 将Mnist数据集转为jpg,并按比例/标签拆分为多个子数据集

    现有条件:Mnist数据集,下载地址:跳转 下载后的四个.gz文件解压后放到同一个文件夹下,如:/raw Step 1:将Mnist数据集转为jpg图片(代码来自这篇博客) 1 import os 2 ...

  4. C++读取MNIST数据集

    MNIST是一个标准的手写字符测试集. Mnist数据集对应四个文件: train-images-idx3-ubyte: training set images  train-labels-idx1- ...

  5. python读取,显示,保存mnist图片

    python处理二进制 python的struct模块可以将整型(或者其它类型)转化为byte数组.看下面的代码. # coding: utf-8 from struct import * # 包装成 ...

  6. C++基于文件流和armadillo读取mnist

    发现网上大把都是用python读取mnist的,用C++大都是用opencv读取的,但我不怎么用opencv,因此自己摸索了个使用文件流读取mnist的方法,armadillo仅作为储存矩阵的一种方式 ...

  7. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

  8. 利用Python读取外部数据文件

      不论是数据分析,数据可视化,还是数据挖掘,一切的一切全都是以数据作为最基础的元素.利用Python进行数据分析,同样最重要的一步就是如何将数据导入到Python中,然后才可以实现后面的数据分析.数 ...

  9. MNIST数据集转化为二维图片

    #coding: utf-8 from tensorflow.examples.tutorials.mnist import input_data import scipy.misc import o ...

随机推荐

  1. EAC3 mantissa quantization(VQ & GAQ)

    EAC3基于hebap来决定mantissa的quantizer. hebap如下: mantissa 使用VQ(vector quantization) 和GAQ(gain adaptive qua ...

  2. 浅谈Power Signoff

    Power Analysis是芯片设计实现中极重要的一环,因为它直接关系到芯片的性能和可靠性.Power Analysis 需要Timing Analysis 产生包含频率.transition 等时 ...

  3. ASP.NET Core 使用过滤器移除重复代码

    USING ACTIONFILTERS TO REMOVE DUPLICATED CODE ASP.NET Core 的过滤器可以让我们在请求管道的特定状态之前或之后运行一些代码.因此如果我们的 ac ...

  4. Bugku-CTF之江湖魔头(学会如来神掌应该就能打败他了吧)

    Day39     江湖魔头 200 http://123.206.31.85:1616/ 学会如来神掌应该就能打败他了吧

  5. 左偏树(p3377)

    题目描述 如题,一开始有N个小根堆,每个堆包含且仅包含一个数.接下来需要支持两种操作: 操作1: 1 x y 将第x个数和第y个数所在的小根堆合并(若第x或第y个数已经被删除或第x和第y个数在用一个堆 ...

  6. curl模拟提交

    function curl_post($url, $post){ $options = array( CURLOPT_RETURNTRANSFER =>true, CURLOPT_HEADER ...

  7. 【C语言】将两个字符串连接起来

    #include<stdio.h> int main() { ] = "I "; ] = "am a student"; int i, j, n; ...

  8. CSS学习(4)常见样式声明

    1.文本 color 文字颜色 预设值:定义好的单词,如red blue 光学的三原色(红,绿,蓝),如 rgb(32,45,255) HEX十六进制,如#008CFF(#112233可以简写为#12 ...

  9. knn 算法 k个相近邻居

    # 一个最基本的例子 #样本数据的封装 feature = [[170,70,42],[166,56,39],[188,90,44],[165,88,40],[170,66,40],[176,80,4 ...

  10. MBA 报考

      1. 作者:MBA薛老师链接:https://www.zhihu.com/question/277811289/answer/397083199来源:知乎著作权归作者所有.商业转载请联系作者获得授 ...