利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下:

其具体实现代码如下所示:
import torch
import matplotlib.pyplot as plt
def plot_curve(data): #曲线输出函数构建
fig=plt.figure()
plt.plot(range(len(data)),data,color="blue")
plt.legend(["value"],loc="upper right")
plt.xlabel("step")
plt.ylabel("value")
plt.show() def plot_image(img,label,name): #输出二维图像灰度图
fig=plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
plt.title("{}:{}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label,depth=10): #根据分类结果的数目将结果转换为一定的矩阵形式[n,1],n为分类结果的数目
out=torch.zeros(label.size(0),depth)
idx=torch.LongTensor(label).view(-1,1)
out.scatter_(dim=1,index=idx,value=1)
return out batch_size=512
import torch
from torch import nn #完成神经网络的构建包
from torch.nn import functional as F #包含常用的函数包
from torch import optim #优化工具包
import torchvision #视觉工具包
import matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset 加载数据包
train_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y) #构建神经网络结构
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
#xw+b
self.fc1=nn.Linear(28*28,256)
self.fc2=nn.Linear(256,64)
self.fc3=nn.Linear(64,10)
def forward(self, x):
#x:[b,1,28,28]
#h1=relu(xw1+b1)
x=F.relu(self.fc1(x))
#h2=relu(h1w2+b2)
x=F.relu(self.fc2(x))
#h3=h2w3+b3
x=(self.fc3(x))
return x net=Net()
#[w1,b1,w2,b2,w3,b3]
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
train_loss=[]
for epoch in range(3):
for batch_idx,(x,y) in enumerate(train_loader):
#x:[b,1,28,28],y:[512]
x=x.view(x.size(0),28*28)
# => [b,10]
out =net(x)
# [b,10]
y_onehot=one_hot(y)
#loss=mse(out,y_onehot)
loss= F.mse_loss(out,y_onehot) optimizer.zero_grad()
loss.backward()
#w'=w-lr*grad
optimizer.step()
train_loss.append(loss.item()) if batch_idx %10==0:
print(epoch,batch_idx,loss.item()) #输出其预测loss损失函数的变化曲线
plot_curve(train_loss)
#get optimal [w1,b1,w2,b2,w3,b3] total_correct=0
for x,y in test_loader:
x=x.view(x.size(0),28*28)
out=net(x)
pred=out.argmax(dim=1)
correct=pred.eq(y).sum().float().item()
total_correct+=correct
total_num=len(test_loader.dataset)
acc=total_correct/total_num
print("test.acc:",acc) #输出整体预测的准确度 x,y=next(iter(test_loader))
out=net(x.view(x.size(0),28*28))
pred=out.argmax(dim=1)
plot_image(x,pred,"test")
实现结果如下所示:


pytorch深度学习神经网络实现手写字体识别的更多相关文章

  1. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  4. 深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是“hello world”,在深度学习领域,这个“hello world”程序就是手写字体识别程序. 这次我们详细的分析下手写字体识别程序,从而可以对深度学习建立一 ...

  5. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  6. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. 利用c++编写bp神经网络实现手写数字识别详解

    利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...

  9. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

随机推荐

  1. SQLite3创建表及操作

    SQLite 创建表 SQLite 的 CREATE TABLE 语句用于在任何给定的数据库创建一个新表.创建基本表,涉及到命名表.定义列及每一列的数据类型. 语法 CREATE TABLE 语句的基 ...

  2. 【笔记7-部署发布】从0开始 独立完成企业级Java电商网站开发(服务端)

    阿里云服务 购买 连接 购买域名 域名备案 域名解析 源配置步骤 资源地址 http://learning.happymmall.com/ 配置阿里云的yum源 1.备份 mv /etc/yum.re ...

  3. HDU1285-确定比赛名次(拓扑+优先队列)

    对于拓扑排序,每次能入队的只有入度为0的点,所以用优先队列即可. 以及,第一组数据日常卡OJ,这组数据跳了一个点,我的程序这个版本也过不了(其实写了另一个版的),稍微改改更正确. #include & ...

  4. ubuntu 允许root用户登录到ssh

    ubuntu的系统太太太麻烦了,我喜欢centos,但是还是要用ubuntu做东西,讨厌,装完系统以后,因为他不让你用root,我新建了一个wqz的用户名. 1.首先更新root的密码 sudo pa ...

  5. pdf.js的使用(2)新的需求已经出现,怎么能够停止不前(迪迦奥特曼主题曲)哈哈哈。^_^

    来,咱们看图说事 按钮1,2是pdf.js自带的,分别对应顺时针旋转90度,逆时针旋转90度.于是乎又要我做一个旋转180度的按钮,诺!按钮3来了. 1.别怂,干!首先顺藤摸瓜,看按钮1,2的html ...

  6. 转入软工后第一节java课的作业

    这个作业,鸽了好久.本来大家都在中秋前发了,我摸摸索索加上各种缓慢的学习,终于是将他做完了. 做完之后,java最基本的输入输出功能都基本学习到了.下面附上代码: import java.util.* ...

  7. C语言-define 与do{}while(0)

    问题引出: 我们都知道宏定义#define只是简单替换,所以遇到复杂的带参数宏,必须很小心的为需要的参数加上括号“()”:同样碰到复杂的多条语句替代,虽然加{}可以将其封装成一个整体,但同时又有另一个 ...

  8. Python基础语法笔记2

    ------------------------------------------------------------------------------- 常量和Pylint的规范 1.常量:常量 ...

  9. [运维] 请求 nginx 出现 502 Bad Gateway 的解决方案!

    环境: 云服务器镜像 Linux CentOS 7.6 已经安装并成功配置 SSL 的 nginx 1.16.1 成功安装并且可以正常运行的 apache-tomcat-9.0.26 遇到的问题: 在 ...

  10. SpringBoot实现restuful风格的CRUD

    restuful风格: 百度百科: RESTFUL是一种网络应用程序的设计风格和开发方式,基于HTTP,可以使用XML格式定义或JSON格式定义.RESTFUL适用于移动互联网厂商作为业务使能接口的场 ...