利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下:

其具体实现代码如下所示:
import torch
import matplotlib.pyplot as plt
def plot_curve(data): #曲线输出函数构建
fig=plt.figure()
plt.plot(range(len(data)),data,color="blue")
plt.legend(["value"],loc="upper right")
plt.xlabel("step")
plt.ylabel("value")
plt.show() def plot_image(img,label,name): #输出二维图像灰度图
fig=plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
plt.title("{}:{}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label,depth=10): #根据分类结果的数目将结果转换为一定的矩阵形式[n,1],n为分类结果的数目
out=torch.zeros(label.size(0),depth)
idx=torch.LongTensor(label).view(-1,1)
out.scatter_(dim=1,index=idx,value=1)
return out batch_size=512
import torch
from torch import nn #完成神经网络的构建包
from torch.nn import functional as F #包含常用的函数包
from torch import optim #优化工具包
import torchvision #视觉工具包
import matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset 加载数据包
train_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y) #构建神经网络结构
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
#xw+b
self.fc1=nn.Linear(28*28,256)
self.fc2=nn.Linear(256,64)
self.fc3=nn.Linear(64,10)
def forward(self, x):
#x:[b,1,28,28]
#h1=relu(xw1+b1)
x=F.relu(self.fc1(x))
#h2=relu(h1w2+b2)
x=F.relu(self.fc2(x))
#h3=h2w3+b3
x=(self.fc3(x))
return x net=Net()
#[w1,b1,w2,b2,w3,b3]
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
train_loss=[]
for epoch in range(3):
for batch_idx,(x,y) in enumerate(train_loader):
#x:[b,1,28,28],y:[512]
x=x.view(x.size(0),28*28)
# => [b,10]
out =net(x)
# [b,10]
y_onehot=one_hot(y)
#loss=mse(out,y_onehot)
loss= F.mse_loss(out,y_onehot) optimizer.zero_grad()
loss.backward()
#w'=w-lr*grad
optimizer.step()
train_loss.append(loss.item()) if batch_idx %10==0:
print(epoch,batch_idx,loss.item()) #输出其预测loss损失函数的变化曲线
plot_curve(train_loss)
#get optimal [w1,b1,w2,b2,w3,b3] total_correct=0
for x,y in test_loader:
x=x.view(x.size(0),28*28)
out=net(x)
pred=out.argmax(dim=1)
correct=pred.eq(y).sum().float().item()
total_correct+=correct
total_num=len(test_loader.dataset)
acc=total_correct/total_num
print("test.acc:",acc) #输出整体预测的准确度 x,y=next(iter(test_loader))
out=net(x.view(x.size(0),28*28))
pred=out.argmax(dim=1)
plot_image(x,pred,"test")
实现结果如下所示:


pytorch深度学习神经网络实现手写字体识别的更多相关文章

  1. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  4. 深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是“hello world”,在深度学习领域,这个“hello world”程序就是手写字体识别程序. 这次我们详细的分析下手写字体识别程序,从而可以对深度学习建立一 ...

  5. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  6. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. 利用c++编写bp神经网络实现手写数字识别详解

    利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...

  9. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

随机推荐

  1. 【转】docker配置参数详解---/etc/docker/daemon.json完整参数

    docker-daemon.json各配置详解 { “api-cors-header”:"", ——————在引擎API中设置CORS标头 “authorization-plugi ...

  2. java 如何快速的获取浏览量

    最近公司做了一个类似 于发帖,交友圈一个这样的功能 在如何精确快速的获取用户的浏览量,且及时的更新显示,最初我是这样想,把每条帖子内容浏览量放到reids 里面,但是redis只是用来存零时数据,想想 ...

  3. 使用Eclipse工具开发Servlet(新建web项目->创建Servlet->部署和访问Servlet)

    在Eclipse工具栏中的[File]->[New]->[Other],打开如下菜单栏,选择Dynamic Web Project 点击下一步,如下图所示: 这里Dynamic web m ...

  4. Vue-表单提交

    template <form @submit.prevent="submitFrom"> <!-- 注册提交事件 .prevent 阻止表单的默认提交行为 --& ...

  5. oracle用户表字段注释

    SELECT C.TABLE_NAME,NUM_ROWS,(select COMMENTS from user_tab_comments WHERE TABLE_NAME=C.TABLE_NAME) ...

  6. windows 下安装chrome 调试iphone插件 ios-webkit-debug-proxy

    必备: 1. .NET Framework 4.5 及以上版本 2.powershell 5.1及以上版本 3.可正常访问  https://raw.githubusercontent.com/luk ...

  7. php redis 集群扩展类文件

    <?php /** * redis集群驱动 */ namespace Common\Api; class RedisCluster{ protected $servers=array( '192 ...

  8. 使用阿里的EasyExcel遇到的一些坑(NoSuchMethodError异常)

    引入eayexcel依赖的时候已经包含了poi依赖

  9. ubuntu mysql允许root用户远程登录

    有两种方法: 一. 1.mysql>GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' IDENTIFIED BY '123456' WITH GRANT OPT ...

  10. ZOJ4110 Strings in the Pocket(2019浙江省赛)

    给出两个字符串,询问有多少种反转方法可以使字符串1变成字符串2. 如果两个串相同,就用马拉车算法找回文串的数量~ 如果两个串不同,从前往后找第一个不同的位置l,从后往前找第二个不同的位置r,反转l和r ...