利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下:

其具体实现代码如下所示:
import torch
import matplotlib.pyplot as plt
def plot_curve(data): #曲线输出函数构建
fig=plt.figure()
plt.plot(range(len(data)),data,color="blue")
plt.legend(["value"],loc="upper right")
plt.xlabel("step")
plt.ylabel("value")
plt.show() def plot_image(img,label,name): #输出二维图像灰度图
fig=plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
plt.title("{}:{}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label,depth=10): #根据分类结果的数目将结果转换为一定的矩阵形式[n,1],n为分类结果的数目
out=torch.zeros(label.size(0),depth)
idx=torch.LongTensor(label).view(-1,1)
out.scatter_(dim=1,index=idx,value=1)
return out batch_size=512
import torch
from torch import nn #完成神经网络的构建包
from torch.nn import functional as F #包含常用的函数包
from torch import optim #优化工具包
import torchvision #视觉工具包
import matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset 加载数据包
train_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y) #构建神经网络结构
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
#xw+b
self.fc1=nn.Linear(28*28,256)
self.fc2=nn.Linear(256,64)
self.fc3=nn.Linear(64,10)
def forward(self, x):
#x:[b,1,28,28]
#h1=relu(xw1+b1)
x=F.relu(self.fc1(x))
#h2=relu(h1w2+b2)
x=F.relu(self.fc2(x))
#h3=h2w3+b3
x=(self.fc3(x))
return x net=Net()
#[w1,b1,w2,b2,w3,b3]
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
train_loss=[]
for epoch in range(3):
for batch_idx,(x,y) in enumerate(train_loader):
#x:[b,1,28,28],y:[512]
x=x.view(x.size(0),28*28)
# => [b,10]
out =net(x)
# [b,10]
y_onehot=one_hot(y)
#loss=mse(out,y_onehot)
loss= F.mse_loss(out,y_onehot) optimizer.zero_grad()
loss.backward()
#w'=w-lr*grad
optimizer.step()
train_loss.append(loss.item()) if batch_idx %10==0:
print(epoch,batch_idx,loss.item()) #输出其预测loss损失函数的变化曲线
plot_curve(train_loss)
#get optimal [w1,b1,w2,b2,w3,b3] total_correct=0
for x,y in test_loader:
x=x.view(x.size(0),28*28)
out=net(x)
pred=out.argmax(dim=1)
correct=pred.eq(y).sum().float().item()
total_correct+=correct
total_num=len(test_loader.dataset)
acc=total_correct/total_num
print("test.acc:",acc) #输出整体预测的准确度 x,y=next(iter(test_loader))
out=net(x.view(x.size(0),28*28))
pred=out.argmax(dim=1)
plot_image(x,pred,"test")
实现结果如下所示:


pytorch深度学习神经网络实现手写字体识别的更多相关文章

  1. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  4. 深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是“hello world”,在深度学习领域,这个“hello world”程序就是手写字体识别程序. 这次我们详细的分析下手写字体识别程序,从而可以对深度学习建立一 ...

  5. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  6. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. 利用c++编写bp神经网络实现手写数字识别详解

    利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...

  9. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

随机推荐

  1. Arrays类的概述和常用的方法

    1.  2.为了防止外界创造对象,系统把Arrays的无参构造方法设为私有: 并且再其类方法用静态修饰,强制你用类名调用方法,另外math和system也是如此

  2. 安装RabbitMQ,一直提示Erlang版本过低

    1.背景 windows系统,控制面板卸载Erlang后,重新安装Erlang成功,当再安装RabbitMQ时,报如下提示: 意思就是说Erlang版本过低,请安装更高的版本. 出现上面问题的原因,是 ...

  3. 学习笔记(26)- plato-端到端模型-定闹钟

    今天用了定闹钟的场景语料,在plato框架尝试了端到端的模型. 本文先记录英文的训练过程,然后记录中文的训练过程. 训练端到端的模型 发现使用英文的模型,还是显示有中文,所以,新建目录,重新训练 1. ...

  4. contextField 键盘只允许输入数字和小数点,并且现在小数点后位数

    - (BOOL)textField:(UITextField *)textField shouldChangeCharactersInRange:(NSRange)range replacementS ...

  5. 计算机二级-C语言-程序修改题-190113记录-对指定字符串的大小写变换处理。

    //给定程序中fun函数的功能是:将p所指的字符串中每个单词的最后一个字母改成大写.(这里的“单词”是指由空格隔开的字符串) //重难点:指针对数组的遍历.大小写转换的方法.第一种使用加减32 得到, ...

  6. 【笔记6-支付及订单模块】从0开始 独立完成企业级Java电商网站开发(服务端)

    支付模块 实际开发工作中经常会遇见如下场景,一个支付模块,一个订单模块,有一定依赖,一个同事负责支付模块,另一个同事负责订单模块,但是开发支付模块的时候要依赖订单模块的相关类 ,方法,或者工具类,这些 ...

  7. 【Go语言系列】1.1、GO语言简介:什么是GO语言

    一.Go的起源 Go语言的所有设计者都说,设计Go语言是因为 C++ 给他们带来了挫败感.在 Google I/O 2012 的 Go 设计小组见面会上,Rob Pike 是这样说的: 我们做了大量的 ...

  8. 为 git 设定 socks5 代理

    为 git 设定 socks5 代理 查看当前设定 git config --global -l 为 git 设定全局代理 git config --global http.proxy socks5h ...

  9. Python:时间日历基本处理

    time 模块 提供了处理时间和表示之间转换的功能 获取当前时间戳 时间戳:从0时区的1970年1月1日0时0分0秒,到所给定日期时间的时间,浮点秒数,或者毫秒整数 获取方式: import time ...

  10. IPSec的高可用性技术

    IPSec VPN的高可用性技术:①.DPD(Dead Peer Detection)对等体检测                      ——旨在检查有问题的IPSec VPN网络,并快速的切换到备 ...