记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对。

这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码。

mnist数据集训练数据一共有28*28*60000个像素,标签有60000个。

测试数据一共有28*28*10000个,标签10000个。

这里神经网络输入层是784个像素,用了100个隐含层,最终10个输出结果。

arc代表的是神经网络结构,可以增加隐含层,不过我试了没太大效果,毕竟梯度消失。

因为是最普通的神经网络,最终识别错误率大概在5%左右。

迭代曲线:

代码如下:

 
clear all;
close all;
clc; load mnist_uint8; train_x = double(train_x) / 255;
test_x = double(test_x) / 255;
train_y = double(train_y);
test_y = double(test_y); mu=mean(train_x);
sigma=max(std(train_x),eps);
train_x=bsxfun(@minus,train_x,mu); %每个样本分别减去平均值
train_x=bsxfun(@rdivide,train_x,sigma); %分别除以标准差 test_x=bsxfun(@minus,test_x,mu);
test_x=bsxfun(@rdivide,test_x,sigma); arc = [784 100 10]; %输入784,隐含层100,输出10
n=numel(arc); W = cell(1,n-1); %权重矩阵
for i=2:n
W{i-1} = (rand(arc(i),arc(i-1)+1)-0.5) * 8 *sqrt(6 / (arc(i)+arc(i-1)));
end learningRate = 2; %训练速度
numepochs = 5; %训练5遍
batchsize = 100; %一次训练100个数据 m = size(train_x, 1); %数据总量
numbatches = m / batchsize; %一共有numbatches这么多组 %% 训练
L = zeros(numepochs*numbatches,1);
ll=1;
for i = 1 : numepochs
kk = randperm(m);
for l = 1 : numbatches
batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :); %% 正向传播
mm = size(batch_x,1);
x = [ones(mm,1) batch_x];
a{1} = x;
for ii = 2 : n-1
a{ii} = 1.7159*tanh(2/3.*(a{ii - 1} * W{ii - 1}'));
a{ii} = [ones(mm,1) a{ii}];
end a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));
e = batch_y - a{n};
L(ll) = 1/2 * sum(sum(e.^2)) / mm;
ll=ll+1;
%% 反向传播
d{n} = -e.*(a{n}.*(1 - a{n}));
for ii = (n - 1) : -1 : 2
d_act = 1.7159 * 2/3 * (1 - 1/(1.7159)^2 * a{ii}.^2); if ii+1==n
d{ii} = (d{ii + 1} * W{ii}) .* d_act;
else
d{ii} = (d{ii + 1}(:,2:end) * W{ii}).* d_act;
end
end for ii = 1 : n-1
if ii + 1 == n
dW{ii} = (d{ii + 1}' * a{ii}) / size(d{ii + 1}, 1);
else
dW{ii} = (d{ii + 1}(:,2:end)' * a{ii}) / size(d{ii + 1}, 1);
end
end %% 更新参数
for ii = 1 : n - 1
W{ii} = W{ii} - learningRate*dW{ii};
end end
end %% 测试,相当于把正向传播再走一遍
mm = size(test_x,1);
x = [ones(mm,1) test_x];
a{1} = x;
for ii = 2 : n-1
a{ii} = 1.7159 * tanh( 2/3 .* (a{ii - 1} * W{ii - 1}'));
a{ii} = [ones(mm,1) a{ii}];
end
a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}'))); [~, i] = max(a{end},[],2);
labels = i; %识别后打的标签
[~, expected] = max(test_y,[],2);
bad = find(labels ~= expected); %有哪些识别错了
er = numel(bad) / size(x, 1) %错误率 plot(L);
 

测试数据可以在这里下载到:https://pan.baidu.com/s/19YPUe9S9xnztg9JGnoXxqw

关注公众号: MATLAB基于模型的设计 (ID:xaxymaker) ,每天推送MATLAB学习最常见的问题,每天进步一点点,业精于勤荒于嬉

打开微信扫一扫哦!

matlab练习程序(神经网络识别mnist手写数据集)的更多相关文章

  1. 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

  2. TensorFlow实战第五课(MNIST手写数据集识别)

    Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...

  3. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  4. 利用卷积神经网络实现MNIST手写数据识别

    代码: import torch import torch.nn as nn import torch.utils.data as Data import torchvision # 数据库模块 im ...

  5. TensorFlow系列专题(六):实战项目Mnist手写数据集识别

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...

  6. MNIST手写数据集在运行中出现问题解决方案

    今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError: <urlopen error [SSL: CERTIF ...

  7. Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

    记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

  8. TensorFlow——MNIST手写数据集

    MNIST数据集介绍 MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集. ...

  9. keras—神经网络CNN—MNIST手写数字识别

    from keras.datasets import mnist from keras.utils import np_utils from plot_image_1 import plot_imag ...

随机推荐

  1. 像素数据YUV简介与觉存储格式介绍

    主要学习链接:博客园.51CTO 前言 照例是先废话几句,下面的内容都是在学习时从网上找来的,并非我原创,我之所以要写这篇笔记是因为网的内容都很分散,找的时候要从各个地方看,很不方便,所以就自己总结了 ...

  2. 7.Git分支-分支简介、分支创建、分支切换

    1.分支简介 几乎所有的版本控制系统都支持某种形式的分支.使用分支意味着可以把你的工作从开发主线上分离开来,以免影响开发主线.Git的分支是其必杀技,它相对于其它版本控制系统来说,具有难以置信的轻量性 ...

  3. ECMAScript 6 学习(一)generator函数

    1.ES2017标准引入async函数,那么async函数到底是个什么函数呢? async 是一个generator函数的语法糖. 2.那么generator函数到底是什么函数ne? generato ...

  4. JVM基础系列第13讲:JVM参数之追踪类信息

    我们都知道 JVM 在启动的时候会去加载类信息,那么我们怎么得知他加载了哪些类,又卸载了哪些类呢?我们这一节就来介绍四个 JVM 参数,使用它们我们就可以清晰地知道 JVM 的类加载信息. 为了方便演 ...

  5. nginx的configure流程

    configure配置 nginx的编译过程,第一步是configure.我们使用 --help可以看到configure的很多配置. configure的过程做的事情其实就是检测环境,然后根据环境生 ...

  6. Windows提权与开启远程连接

    1.提权: 建立普通用户:net user 帐户 密码 /add 提权成管理员:net localgroup administrators 帐户 /add 更改用户密码:net user 帐户 密码 ...

  7. 『Tarjan算法 无向图的双联通分量』

    无向图的双连通分量 定义:若一张无向连通图不存在割点,则称它为"点双连通图".若一张无向连通图不存在割边,则称它为"边双连通图". 无向图图的极大点双连通子图被 ...

  8. DDD实战进阶第一波(一):开发一般业务的大健康行业直销系统(概述)

    本系列文章 DDD实战进阶第一波(一):开发一般业务的大健康行业直销系统(概述) DDD实战进阶第一波(二):开发一般业务的大健康行业直销系统(搭建支持DDD的轻量级框架一) 近年来,关于如何开发基于 ...

  9. Android--SoundPool

    前言 在Android中播放音频文件经常会用到MediaPlayer,但是MediaPlayer存在一些不足的地方,如:资源占用量较高.加载延迟时间较长.不支持多个音频同时播放等.这些缺点决定了Med ...

  10. Linux~连接windows的ftp,unzip出现的问题

    在linux进行连接windows下的ftp服务器 ftp://192.168.2.71 输入用户名和密码登陆成功