在训练神经网络之前,我们必须有数据,作为资深伸手党,必须知道以下几个数据提供源:

一、CIFAR-10

CIFAR-10图片样本截图

CIFAR-10是多伦多大学提供的图片数据库,图片分辨率压缩至32x32,一共有10种图片分类,均进行了标注。适合监督式学习。CIFAR-10数据下载页面

二、ImageNet

imagenet首页

ImageNet首页

三、ImageFolder

imagefolder首页

ImageFolder首页

四、LSUN Classification

LSUN Classification

LSUN 图片下载地址

五、COCO (Captioning and Detection)

coco首页

COCO首页地址

六、我们进入正题

为了方便加载以上五种数据库的数据,pytorch团队帮我们写了一个torchvision包。使用torchvision就可以轻松实现数据的加载和预处理。

我们以使用CIFAR10为例:

导入torchvision的库:

import torchvision

import torchvision.transforms as transforms  # transforms用于数据预处理

使用datasets.CIFAR10()函数加载数据库。CIFAR10有60000张图片,其中50000张是训练集,10000张是测试集。

#训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)

下面简单讲解root、train、download、transform这四个参数

1.root,表示cifar10数据的加载的相对目录

2.train,表示是否加载数据库的训练集,false的时候加载测试集

3.download,表示是否自动下载cifar数据集

4.transform,表示是否需要对数据进行预处理,none为不进行预处理

由于美帝路途遥远,靠命令台进程下载100多M的数据速度很慢,所以我们可以自己去到cifar10的官网上把CIFAR-10 python version下载下来,然后解压为cifar-10-batches-py文件夹,并复制到相对目录./data下。(若设置download=True,则程序会自动从网上下载cifar10数据到相对目录./data下,但这样小伙伴们可能要等一个世纪了),并对训练集进行加载(train=True)。

如图所示,在脚本文件下建一个data文件夹,然后把数据集文件夹丢到里面去就好了,注意cifar-10-batches-py文件夹名字不能自己任意改。

我们在写完上面三行代码后,在写一行print一下trainset的大小看看:

print len(trainset)

#结果:50000

我们在训练神经网络时,使用的是mini-batch(一次输入多张图片),所以我们在使用一个叫DataLoader的工具为我们将50000张图分成每四张图一分,一共12500份的数据包。

#将训练集的50000张图片划分成12500份,每份4张图,用于mini-batch输入。shffule=True在表示不同批次的数据遍历时,打乱顺序(这个需要在训练神经网络时再来讲)。num_workers=2表示使用两个子进程来加载数据

import torch

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=2)

那么我们就写下了这几行代码:

print的结果为50000和12500

下面我们需要对数据进行预处理,什么是预处理?为什么要预处理?如果不知道的小盆友可以看看下面几个链接,或许对你有帮助。神经网络为什么要归一化深度学习-----数据预处理。还无法理解也没关系,只要记住,预处理会帮助我们加快神经网络的训练。

在pytorch中我们预处理用到了transforms函数:

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

compose函数会将多个transforms包在一起。

我们的transforms有好几种,例如transforms.ToTensor(), transforms.Scale()等,完整列表在。好好学习吧!

我只讲现在用到了两种:

1.ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 从0到255的值映射到0到1的范围内,并转化成Tensor格式。

2.Normalize(mean,std)是通过下面公式实现数据归一化

channel=(channel-mean)/std

那么经过上面两个转换一折腾,我们的数据中的每个值就变成了[-1,1]的数了。

1到22行,我们从硬盘中读取数据,并将数据预处理(第13行,transform=transform),然后转换成4张图为一批的数据结构。26行到47行,为我们显示出一个图片例子,可有可无,不再作代码解释。

源代码下载

更多torchvision加载其他数据库方法


作者:Zen_君
链接:http://www.jianshu.com/p/8da9b24b2fb6
來源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

超简单!pytorch入门教程(四):准备图片数据集的更多相关文章

  1. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  2. 超强、超详细Redis入门教程【转】

    这篇文章主要介绍了超强.超详细Redis入门教程,本文详细介绍了Redis数据库各个方面的知识,需要的朋友可以参考下 [本教程目录] 1.redis是什么2.redis的作者何许人也3.谁在使用red ...

  3. 超详细Redis入门教程【转】

    这篇文章主要介绍了超强.超详细Redis入门教程,本文详细介绍了Redis数据库各个方面的知识,需要的朋友可以参考下   [本教程目录] 1.redis是什么 2.redis的作者何许人也 3.谁在使 ...

  4. MongoDB最简单的入门教程之五-通过Restful API访问MongoDB

    通过前面四篇的学习,我们已经在本地安装了一个MongoDB数据库,并且通过一个简单的Spring boot应用的单元测试,插入了几条记录到MongoDB中,并通过MongoDB Compass查看到了 ...

  5. PySide——Python图形化界面入门教程(四)

    PySide——Python图形化界面入门教程(四) ——创建自己的信号槽 ——Creating Your Own Signals and Slots 翻译自:http://pythoncentral ...

  6. JasperReports入门教程(四):多数据源

    JasperReports入门教程(四):多数据源 背景 在报表使用中,一个页面需要打印多个表格,每个表格分别使用不同的数据源是很常见的一个需求.假如我们现在有一个需求如下:需要在一个报表同时打印所有 ...

  7. 无废话ExtJs 入门教程四[表单:FormPanel]

    无废话ExtJs 入门教程四[表单:FormPanel] extjs技术交流,欢迎加群(201926085) 继上一节内容,我们在窗体里加了个表单.如下所示代码区的第28行位置,items:form. ...

  8. MongoDB最简单的入门教程之二 使用nodejs访问MongoDB

    在前一篇教程 MongoDB最简单的入门教程之一 环境搭建 里,我们已经完成了MongoDB的环境搭建. 在localhost:27017的服务器上,在数据库admin下面创建了一个名为person的 ...

  9. MongoDB最简单的入门教程之三 使用Java代码往MongoDB里插入数据

    前两篇教程我们介绍了如何搭建MongoDB的本地环境: MongoDB最简单的入门教程之一 环境搭建 以及如何用nodejs读取MongoDB里的记录: MongoDB最简单的入门教程之二 使用nod ...

  10. MongoDB最简单的入门教程之四:使用Spring Boot操作MongoDB

    Spring Boot 是一个轻量级框架,可以完成基于 Spring 的应用程序的大部分配置工作.Spring Boot的目的是提供一组工具,以便快速构建容易配置的Spring应用程序,省去大量传统S ...

随机推荐

  1. Entity Framework 映射问题

    今天在数据库(mysql)新增了一个字段,但是一直以为添加字段,然后在实体模型中选择 一直是以为选择"添加",就导致有问题,原因就不说,有点蠢,人家都已经存在,还加上去干嘛,我要的 ...

  2. UVa 10285【搜索】

    UVa 10285 哇,竟然没超时!看网上有人说是记忆化搜索,其实不太懂是啥...感觉我写的就是毫无优化的dfs暴力....... 建立一个坐标方向结构体数组,每个节点dfs()往下搜就好了. #in ...

  3. php 对接java短信接口带有英文逗号就无法通过

    在对接短息接口时,对方是java接口,要求content两次编码 短信内容(Content)发起请求前必须进行URL转码.例如对于短信内容为“中文短信abc”,转码过程如下(java语言): Stri ...

  4. 洛谷P2512 [HAOI2008]糖果传递

    //不开long long见祖宗!!! #include<bits/stdc++.h> using namespace std; long long n,ans,sum; ],s[]; i ...

  5. 解析P2P金融的业务安全

    看了很多乙方同学们写的业务安全,总结下来,其出发点主要是在技术层面风险问题.另外捎带一些业务风险.今天我要谈的是甲方眼里的业务安全问题,甲方和乙方在业务安全的视野上会有一些区别和一些重合.在同一个问题 ...

  6. 2019-9-2-C#枚举中使用Flags特性

    title author date CreateTime categories C#枚举中使用Flags特性 lindexi 2019-09-02 12:57:37 +0800 2018-2-13 1 ...

  7. git比较两个版本之间的区别

    查看当前没有add 的内容修改: git diff 查看已经add 没有commit 的改动 git diff --cached 查看当前没有add和commit的改动: git diff HEAD ...

  8. 将 vue.js 获取的 html 文本转化为纯文本

    我存入数据表中的数据是使用 html  格式,获取数据是使用 vue 获取. 遇到了一个问题,就是界面上显示的数据是 html 格式的,但是我需要它显示纯文本. 怎么做呢?首先在  js  中写一个将 ...

  9. NIO 中文乱码自我解决的简单DEMO

    import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.FileOutputStrea ...

  10. 2015年NOIP普及组复赛题解

    题目涉及算法: 金币:入门题: 扫雷游戏:入门题: 求和:简单数学推导: 推销员:贪心. 金币 题目链接:https://www.luogu.org/problem/P2669 入门题,直接开一个循环 ...