上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解。

一、数据集

MNIST是一个非常经典的数据集,下载链接:http://yann.lecun.com/exdb/mnist/

下载下来的文件如下:

该手写数字数据库具有60,000个示例的训练集和10,000个示例的测试集。它是NIST提供的更大集合的子集。数字已经过尺寸标准化,并以固定尺寸的图像为中心。

手写数字识别是一个比较简单的任务,它是一个10分类问题,(0-9),之所以选这个数据集,是因为识别难度低,计算量小,数据容易获得。

二、模型搭建

1、网络节点的确定

对于不同的目的,网络的选择也是不一样的。一般来说,网络容量和数据集大小是对应的。一个小型数据集也只需要一个小型的网络。

这里有一个经验值:

1)model_size=sqrt(in_size*out_size)

2)model_size=log(in_size)

3)  model_size=sqrt(in_size*out_size)

model_size:网络的节点量

in_size:输入的节点量

out_size输出的节点量

2、导入pytorch包

import torch
import torchvision
import trochvision import datasets
import trochvision import transforms
from torch.autograd import Variable

3、获取训练集和测试集

#root用于指定数据集下载后的存放路径
#transform用于指定导入数据集需要对数据进行变换操作
#train指定在数据集下载后需要载入哪部分数据,true为训练集,false为测试集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True)
data_test=datasets.MNIST(root='./data/',transform=transform,train=False)

4、数据预览和装载

#数据装载,可以理解为对图片的处理
#处理完成后,将图片送给模型训练,装载就是打包的过程
#dataset 用于指定载入的数据集名称
#batch_size设置了每个包的图片数据数据个数
#shuffle 装载过程将数据随机打乱并打包
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)

  

如何入门Pytorch之四:搭建神经网络训练MNIST的更多相关文章

  1. 使用pytorch快速搭建神经网络实现二分类任务(包含示例)

    使用pytorch快速搭建神经网络实现二分类任务(包含示例) Introduce 上一篇学习笔记介绍了不使用pytorch包装好的神经网络框架实现logistic回归模型,并且根据autograd实现 ...

  2. 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

  3. TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)

    from __future__ import print_function from tensorflow.examples.tutorials.mnist import input_data #加载 ...

  4. mxnet卷积神经网络训练MNIST数据集测试

    mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...

  5. 使用一层神经网络训练mnist数据集

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  6. 如何入门Pytorch之二:如何搭建实用神经网络

    上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等内容. 在本节中,我们将介绍如何使用Pytorch来搭建一个经典的分类神经网络. 搭建一个神经网络并训练,大致有这么四个部分: 1 ...

  7. 如何入门Pytorch之三:如何优化神经网络

    在上一节中,我们介绍了如何使用Pytorch来搭建一个经典的分类神经网络.一般情况下,搭建完模型后训练不会一次就能达到比较好的效果,这样,就需要不断的调整和优化模型的各个部分.从而引出了本文的主旨:如 ...

  8. 【pytorch】学习笔记(四)-搭建神经网络进行关系拟合

    [pytorch学习笔记]-搭建神经网络进行关系拟合 学习自莫烦python 目标 1.创建一些围绕y=x^2+噪声这个函数的散点 2.用神经网络模型来建立一个可以代表他们关系的线条 建立数据集 im ...

  9. keras搭建神经网络快速入门笔记

    之前学习了tensorflow2.0的小伙伴可能会遇到一些问题,就是在读论文中的代码和一些实战项目往往使用keras+tensorflow1.0搭建, 所以本次和大家一起分享keras如何搭建神经网络 ...

随机推荐

  1. 浅谈:C#中的非泛型集合

    1.首先:ArrayList:非泛型集合 List:泛型集合 集合跟数组比较我们更容易理解.数组:1,长度固定2,数据类型预先声明 集合:1,长度可变2,数据类型预先声明的为泛型集合,数据类型不限定为 ...

  2. 题解 BZOJ4709

    题目描述 一道简单DP优化调了好久qwq 首先分析题目,发现每次从一边取贝壳是完全没用的,此题本质就是将区间分成数个区间,使区间价值和最大. 可以发现一个性质,那就是最优解的每个区间的两端点一定相同且 ...

  3. Jmeter(二十一) - 从入门到精通 - JMeter断言 - 上篇(详解教程)

    1.简介 最近由于宏哥在搭建自己的个人博客可能更新的有点慢.断言组件用来对服务器的响应数据做验证,常用的断言是响应断言,其支持正则表达式.虽然我们的通过响应断言能够完成绝大多数的结果验证工作,但是JM ...

  4. servlet的生命周期和工作原理介绍

    一.servlet生命周期 Servlet生命周期分为三个阶段: 1)初始化阶段: 调用init()方法 2)响应客户请求阶段:调用service()方法 3)终止阶段:调用destroy()方法 T ...

  5. Eclipse导入项目后JSP页面出现报红

    Multiple annotations found at this line:- javax.servlet.jsp.JspException cannot be resolved to a typ ...

  6. 2020.5.21 第一篇 Scrum冲刺博客

    Team:银河超级无敌舰队 Project:招新通 项目冲刺集合贴:链接 目录 一.Alpha 阶段成员任务安排 二.明日任务安排 三.预期的任务量 四.敏捷开发前的感想 五.团队期望 一.Alpha ...

  7. Jupyter Notebook 入门指南

    https://www.jianshu.com/p/061c6e5c4b0d cmd输入 :jupyter notebook

  8. Python中print()函数不换行的方法以及分隔符替换

    一.让print()函数不换行 在Python中,print()函数默认是换行的.但是,在很多情况下,我们需要不换行的输出(比如在算法竞赛中).那么,在Python中如何做到这一点呢? 其实很简单.只 ...

  9. [CSP-S2019]Emiya 家今天的饭 题解

    CSP-S2 2019 D2T1 很不错的一题DP,通过这道题学到了很多. 身为一个对DP一窍不通的蒟蒻,在考场上还挣扎了1h来推式子,居然还有几次几乎推出正解,然而最后还是只能打个32分的暴搜滚粗 ...

  10. JavaScript对象原型链的学习

    1.构造函数和原型 1.1对象的三种创建方式 字面量方式 var obj = {}; new关键字 var obj = new Object(); 构造函数方式 function Person(nam ...