如何入门Pytorch之四:搭建神经网络训练MNIST
上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解。
一、数据集
MNIST是一个非常经典的数据集,下载链接:http://yann.lecun.com/exdb/mnist/
下载下来的文件如下:

该手写数字数据库具有60,000个示例的训练集和10,000个示例的测试集。它是NIST提供的更大集合的子集。数字已经过尺寸标准化,并以固定尺寸的图像为中心。
手写数字识别是一个比较简单的任务,它是一个10分类问题,(0-9),之所以选这个数据集,是因为识别难度低,计算量小,数据容易获得。
二、模型搭建
1、网络节点的确定
对于不同的目的,网络的选择也是不一样的。一般来说,网络容量和数据集大小是对应的。一个小型数据集也只需要一个小型的网络。
这里有一个经验值:
1)model_size=sqrt(in_size*out_size)
2)model_size=log(in_size)
3) model_size=sqrt(in_size*out_size)
model_size:网络的节点量
in_size:输入的节点量
out_size输出的节点量
2、导入pytorch包
import torch
import torchvision
import trochvision import datasets
import trochvision import transforms
from torch.autograd import Variable
3、获取训练集和测试集
#root用于指定数据集下载后的存放路径
#transform用于指定导入数据集需要对数据进行变换操作
#train指定在数据集下载后需要载入哪部分数据,true为训练集,false为测试集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True)
data_test=datasets.MNIST(root='./data/',transform=transform,train=False)
4、数据预览和装载
#数据装载,可以理解为对图片的处理
#处理完成后,将图片送给模型训练,装载就是打包的过程
#dataset 用于指定载入的数据集名称
#batch_size设置了每个包的图片数据数据个数
#shuffle 装载过程将数据随机打乱并打包
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)
如何入门Pytorch之四:搭建神经网络训练MNIST的更多相关文章
- 使用pytorch快速搭建神经网络实现二分类任务(包含示例)
使用pytorch快速搭建神经网络实现二分类任务(包含示例) Introduce 上一篇学习笔记介绍了不使用pytorch包装好的神经网络框架实现logistic回归模型,并且根据autograd实现 ...
- 用Kersa搭建神经网络【MNIST手写数据集】
MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...
- TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)
from __future__ import print_function from tensorflow.examples.tutorials.mnist import input_data #加载 ...
- mxnet卷积神经网络训练MNIST数据集测试
mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...
- 使用一层神经网络训练mnist数据集
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- 如何入门Pytorch之二:如何搭建实用神经网络
上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等内容. 在本节中,我们将介绍如何使用Pytorch来搭建一个经典的分类神经网络. 搭建一个神经网络并训练,大致有这么四个部分: 1 ...
- 如何入门Pytorch之三:如何优化神经网络
在上一节中,我们介绍了如何使用Pytorch来搭建一个经典的分类神经网络.一般情况下,搭建完模型后训练不会一次就能达到比较好的效果,这样,就需要不断的调整和优化模型的各个部分.从而引出了本文的主旨:如 ...
- 【pytorch】学习笔记(四)-搭建神经网络进行关系拟合
[pytorch学习笔记]-搭建神经网络进行关系拟合 学习自莫烦python 目标 1.创建一些围绕y=x^2+噪声这个函数的散点 2.用神经网络模型来建立一个可以代表他们关系的线条 建立数据集 im ...
- keras搭建神经网络快速入门笔记
之前学习了tensorflow2.0的小伙伴可能会遇到一些问题,就是在读论文中的代码和一些实战项目往往使用keras+tensorflow1.0搭建, 所以本次和大家一起分享keras如何搭建神经网络 ...
随机推荐
- 2020-06-16:Redis hgetall时间复杂度?
福哥答案2020-06-16: 时间复杂度是O(N).时间复杂度:O(N) where N is the size of the hash.
- Flutter 容器(5) - SizedBox
SizedBox: 两种用法:一是可用来设置两个widget之间的间距,二是可以用来限制子组件的大小. import 'package:flutter/material.dart'; class Au ...
- 利用BeautifulSoup去除HTML指定标签和去除注释
去除指定标签 from bs4 import BeautifulSoup #去除属性ul [s.extract() for s in soup("ul")] # 去除属性svg [ ...
- 封装react antd的form表单组件
form表单在我们日常的开发过程中被使用到的概率还是很大的,比如包含了登录.注册.修改个人信息.新增修改业务数据等的公司内部管理系统.而在使用时这些表单的样式如高度.上下边距.边框.圆角.阴影.高亮等 ...
- Node学习基础之安装node以及配置环境变量
第一步去node官网下载nodejs 我放在D盘 接着在cmd输入node -v 就能得到node的版本号 还有npm -v 下来进入安装好的目录 nodejs目录 创建两个文件夹 node_cach ...
- 朋友国企干了5年java,居然不知道Dubbo是做什么呢?我真信了
点赞再看,养成习惯,微信搜一搜[三太子敖丙]关注这个喜欢写情怀的程序员. 本文 GitHub https://github.com/JavaFamily 已收录,有一线大厂面试完整考点.资料以及我的系 ...
- python 01 print input int
学过c语言与c语言的数据结构与算法后再来学习python,感觉编程的核心内容没有变,但每个编程语言都有自己的特点.本次学习的目标是理解python的特点与用法,把学过的bif(内置函数)用法记录下来, ...
- JavaScript学习系列博客_4_JavaScript中的数据类型
JavaScript中有6种数据类型 一.基本数据类型 - String 字符串 JS中的字符串需要使用引号引起来双引号或单引号都行 但是要注意的是某种引号嵌套使用的话,需要加上 \ 转义.比如说我们 ...
- layui上传同一张图片第二次时choose没有反应
将上传文件的input的val设置为空 $("#test11").parent().find("input").val('');
- Rethinking the performance comparison between SNNS and ANNS
郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! Abstract ANN是通向AI的一种流行方法,它已经通过成熟的模型,各种基准,开源数据集和强大的计算平台获得了非凡的成功.SNN是一类 ...