使用PyTorch构建神经网络模型进行手写识别

PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯度为主要内容进行学习。

How can we install Torch?

Torch在Linux,Windows,Mac等开发环境下都有特定的安装方法,首先搜索官方网页https://pytorch.org/,由下图所示我们可以根据自己适合的环境进行选择,我使用的是1.9.0版本Windows环境下conda包Python语言,CPU计算平台的安装。



安装过程需要打开Anaconda命令行输入下方所给提示命令指引,



安装好Torch后打开常用的编辑器进行测试



OK,我们可以看到已经成功的在电脑上安装了Torch

下列代码均在Jupyter NoteBook编辑,conda等安装方式不在此文章说明

在“PYTORCH”中定义神经网络

深度学习算法即为神经网络算法,它是由多层互连计算单元组成的计算系统。通过这些相互连接的单元传递数据,神经网络能够学习如何近似将输入转换位输出所需的计算。在Torch中可以使用torch.nn包构建神经网络。

最常听说的也是最基础的MNIST数据集也就是手写识别数据,定义用于MNIST数据集的神经网络需要如下步骤

1.导入库

2.定义初始化神经网络

3.指定数据集构建模型

4.通过模型传递数据进行测试

将从应用角度出发,下述内容神经网络名词定义不做过多叙述。

导入相关库加载数据

构建神经网络所需库为torch.nn以及torch.nn.functional

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

通过上述模块和类,torch.nn帮助我们创建和训练神经网络,包含forward(input),返回output

定义,初始化神经网络

我们定义的神经网络将帮助我们识别图像,将使用PyTorch内置的卷积。卷积过程将图像的每个元素添加到local neighbors,由内核或小型矩阵权重配比,将有助于我们从输入图像中提取某些特征(边缘检测,锐度,模糊度等)。

定义Net模型的类有两个要求。第一个是编写一个__init__引用nn.Moudle。这个函数是你在你神经网络中定义全连接层的地方。

使用卷积,我们从构建的神经网络模型输出一个图像通道,输出匹配数字从0到9的10个标签的目标,下列构建传统的MNIST算法

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__() # First 2D convolutional layer, taking in 1 input channel (image),
# outputting 32 convolutional features, with a square kernel size of 3
self.conv1 = nn.Conv2d(1, 32, 3, 1)
# Second 2D convolutional layer, taking in the 32 input layers,
# outputting 64 convolutional features, with a square kernel size of 3
self.conv2 = nn.Conv2d(32, 64, 3, 1) # Designed to ensure that adjacent pixels are either all 0s or all active
# with an input probability
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5) # First fully connected layer
self.fc1 = nn.Linear(9216, 128)
# Second fully connected layer that outputs our 10 labels
self.fc2 = nn.Linear(128, 10) my_nn = Net()
print(my_nn)

如代码所示,构建的三层神经网络,第一个二维接收层,输入图像数据,输出32个特征,平方核大小为3,第二个二维convolutional 层输入32组数据得到64个特征平方核大小为3

通过指定数据传递进行训练

我们已经完成了神经网络的定义,下面将使用数据进行训练,在使用PyTorch构建模型只需要定义foward函数,将数据传递到计算图中,将代表我们的前馈算法。

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10) # x represents our data
def forward(self, x):
# Pass data through conv1
x = self.conv1(x)
# Use the rectified-linear activation function over x
x = F.relu(x) x = self.conv2(x)
x = F.relu(x) # Run max pooling over x
x = F.max_pool2d(x, 2)
# Pass data through dropout1
x = self.dropout1(x)
# Flatten x with start_dim=1
x = torch.flatten(x, 1)
# Pass data through fc1
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x) # Apply softmax to x
output = F.log_softmax(x, dim=1)
return output

参考开发文档:https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

推荐阅读

使用PyTorch构建神经网络模型进行手写识别的更多相关文章

  1. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

  2. TensorFlow 入门之手写识别CNN 三

    TensorFlow 入门之手写识别CNN 三 MNIST 卷积神经网络 Fly 多层卷积网络 多层卷积网络的基本理论 构建一个多层卷积网络 权值初始化 卷积和池化 第一层卷积 第二层卷积 密集层连接 ...

  3. 77、tensorflow手写识别基础版本

    ''' Created on 2017年4月20日 @author: weizhen ''' #手写识别 from tensorflow.examples.tutorials.mnist import ...

  4. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  5. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  6. 使用tensorflow实现mnist手写识别(单层神经网络实现)

    import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import n ...

  7. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

  8. 10分钟教你用python 30行代码搞定简单手写识别!

    欲直接下载代码文件,关注我们的公众号哦!查看历史消息即可! 手写笔记还是电子笔记好呢? 毕业季刚结束,眼瞅着2018级小萌新马上就要来了,老腊肉小编为了咱学弟学妹们的学习,绞尽脑汁准备编一套大学秘籍, ...

  9. 【Win 10 应用开发】手写识别

    记得前面(忘了是哪天写的,反正是前些天,请用力点击这里观看)老周讲了一个14393新增的控件,可以很轻松地结合InkCanvas来完成涂鸦.其实,InkCanvas除了涂鸦外,另一个大用途是墨迹识别, ...

随机推荐

  1. 大数据学习day17------第三阶段-----scala05------1.Akka RPC通信案例改造和部署在多台机器上 2. 柯里化方法 3. 隐式转换 4 scala的泛型

    1.Akka RPC通信案例改造和部署在多台机器上  1.1 Akka RPC通信案例的改造(主要是把一些参数不写是) Master package com._51doit.akka.rpc impo ...

  2. SpringCloud微服务实战——搭建企业级开发框架(三十二):代码生成器使用配置说明

    一.新建数据源配置 因考虑到多数据源问题,代码生成器作为一个通用的模块,后续可能会为其他工程生成代码,所以,这里不直接读取系统工程配置的数据源,而是让用户自己维护. 参数说明 数据源名称:用于查找区分 ...

  3. clickhouse安装数据导入及查询测试

    官网 https://clickhouse.tech/ quick start ubantu wget https://repo.yandex.ru/clickhouse/deb/lts/main/c ...

  4. HTTPS及流程简析

    [序] 在我们在浏览某些网站的时候,有时候浏览器提示需要安装根证书,可是为什么浏览器会提示呢?估计一部分人想也没想就直接安装了,不求甚解不好吗? 那么什么是根证书呢?在大概的囫囵吞枣式的百度之后知道了 ...

  5. 关于为了一时方便,使用@Scheduled注解定时踩的坑

    摘要: 事情是这样的前两周在做项目的时候碰到一个需求---要求每天晚上执行一个任务,公司统一使用的是 xxl-job 写定时任务的,我当时为了方便自己,然后就简单的使用了Spring的那个@Sched ...

  6. Centos7源码部署Redis3.2.9

    目录 一.环境准备 二.安装 三.测试 四.编写启动脚本 一.环境准备 [Redis-Server] 主机名 = host-1 系统 = centos-7.3 地址 = 1.1.1.1 软件 = re ...

  7. 2020ACTF pwn writeup

    为了打2021的ACTF,想着把2020年的pwn题做一做吧,发现2020年的pwn题质量还挺高的.反倒是2021年的题目质量不太高,好像是没有专门的pwn师傅出题,可以理解,毕竟办校赛,说白了就是用 ...

  8. Element-UI 使用 class 方式和 css 方式引入图标

    今天在使用 vxe-table 时,需要引入 Element UI的图标,顺便就找了下这些组件库中图标的引用方式. 我们知道 Element .Ant Design.Font Awesome 等很多组 ...

  9. 自定义函数(Power Query 之 M 语言)

    数据源: 任意工作簿 目标: 使用自定义函数实现将数据源导入Power Query编辑器 操作过程: PowerQuery编辑器>主页>新建源>其他源>空查询 编辑栏内写入公式 ...

  10. SP1798 ASSIST - Assistance Required 题解

    Content 有一个足够长的数列 \(a\),是一个首项为 \(2\),公差为 \(1\) 的等差递增数列.另有一个初始为空的数列 \(b\). 重复进行如下操作: 假设当前数列 \(a\) 第一项 ...