DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络。我使用的是由MatConvNet 提供的DAGNN API。选择这套API作为工具的原因有三点,第一:这是matlab的API,相对其他语言我对Matlab比较熟悉;第二:有向图非循环的网络可以实现RPN,Network in Network 等较为复杂的功能,可以随意的引出各层的输入和输出,有利于针对三维视觉任务改造网络结构。MNIST 是手写数字的图片集,也是机器学习网络最简单的试金石。

1、定义层

 conv_layer1 = dagnn.Conv('size',single([5,5,1,30]),'hasBias',true);
relu_layer2 = dagnn.ReLU(); conv_layer3 = dagnn.Conv('size',single([5,5,30,16]),'hasBias',true);
relu_layer4 = dagnn.ReLU();
pooling_layer5 = dagnn.Pooling('poolSize',[2,2],'stride',[2 2]); fullConnet_layer6 = dagnn.Conv('size',single([4,4,16,256]),'hasBias',true);
relu_layer7 = dagnn.ReLU();
fullConnet_layer8 = dagnn.Conv('size',single([1,1,256,10]),'hasBias',true);
SoftMat_layer9 = dagnn.SoftMax();
Loss_layer = dagnn.Loss();

首先是利用API构造各层网络,定义网络结构类型。所有的Layer 都继承自dagnn.Layer类,子类中定义了输入输出,前向传播,反向传播的行为。

其中包括卷积层,激活层,池化层,Softmax 分类层,以及计算Loss层。值得注意的是全连接层是通过大卷积层来实现的。本质上全连接就是“输入的等尺寸卷积”。全连接层的作用是将卷积层提取的特征进行高度非线性的映射,将其映射到输出空间中。

2、定义网络

 mynet = dagnn.DagNN();
mynet.addLayer('conv1',conv_layer1,{'input'},{'x2'},{'filters_conv1','bias_conv1'});
mynet.addLayer('relu1',relu_layer2,{'x2'},{'x3'});
mynet.addLayer('pool1',pooling_layer5,{'x3'},{'x4'}); mynet.addLayer('conv2',conv_layer3,{'x4'},{'x5'},{'filters_conv2','bias_conv2'});
mynet.addLayer('relu2',relu_layer4,{'x5'},{'x6'});
mynet.addLayer('pool2',pooling_layer5,{'x6'},{'x7'}); mynet.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
mynet.addLayer('relu3',relu_layer7,{'x8'},{'x9'});
mynet.addLayer('full2',fullConnet_layer8,{'x9'},{'x10'},{'filters_fc2','bias_fc2'});
mynet.addLayer('Cls1',SoftMat_layer9,{'x10'},{'pred'});
mynet.addLayer('Loss',Loss_layer,{'pred','label'},{'loss'});
mynet.initParams();
mynet.meta.inputs = {'data',[28,28,1,1]};
mynet.meta.classes.name = {1,2,3,4,5,6,7,8,9,10};
mynet.meta.normalization.imageSize = [28,28,1,1];
mynet.meta.interpolation = 'bicubic';

定义网络调用了addLayer方法,与其他API的网络构建方法不同的是,DAGNN的API需要针对每层定义输入和输出,以及网络中的待求得参数。当然,作为初学者我先实现了链式网络,在下周的工作中会尝试实现Faster R-CNN。

net.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});

以此为例,代表该层的名字是full1 , 该层的结构是fullConnect_layer6,输入为x7、输出x8,参数名为filters_fc1 和 bias_fc1。其中loss 层最为特殊,其具有来自softmax层的pred 和 label (ground truth) 两种输入。

最重要的是一定要initParams()!!!!这会生成初始参数。

3、定义数据输入函数

为了训练网络,我们需要定义一个输入函数。数据量小,可存在内存中,但当数据量大的时候全部存在内存里是不现实的,这就需要一个数据输入函数来对你定义的数据库进行操作。本例中我仅使用5000幅图片进行训练,所以可以把图片放在内存中。getBatch函数如下所示:

 function inputs = getBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ; % images = gpuArray(images) ; inputs = {'input', images, 'label', labels} ;

其中 imdb 是image data base. 其中包括:

imdb.images.data  图片 W*H*C*N 的4-D  single Array

imdb.images.label  标签 N*1  的 single Array

imdb.images.data_mean 图片平均值 用于预处理时去中心

imdb.images.set    集合号 N*1 的 single Array, 其中1 代表训练集 2 代表测试集 3 代表验证集

imdb.meta 存放类型名称等和训练关系不太密切的东西

4、开始训练

直接调用 cnn_train_dag 的API 开始对整个集合进行训练,注意getBatch 输入的是函数句柄。

cnn_train_dag(mynet,imdb_sub,@getBatch);

  训练了30个epoch,但是learningRate好像给太高了,掉局部最小里了。。。。。。。不过结果不错,在验证集中拿到了4998/5000.

机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET的更多相关文章

  1. 机器学习&深度学习基础(目录)

    从业这么久了,做了很多项目,一直对机器学习的基础课程鄙视已久,现在回头看来,系统的基础知识整理对我现在思路的整理很有利,写完这个基础篇,开始把AI+cv的也总结完,然后把这么多年做的项目再写好总结. ...

  2. [转载]机器学习&深度学习经典资料汇总,全到让人震惊

    自学成才秘籍!机器学习&深度学习经典资料汇总 转自:中国大数据: http://www.thebigdata.cn/JiShuBoKe/13299.html [日期:2015-01-27] 来 ...

  3. 机器学习&深度学习资料

    机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 1) 机器学习(Machine Learning)&深度学习(Deep Lea ...

  4. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  5. 机器学习&深度学习基础(机器学习基础的算法概述及代码)

    参考:机器学习&深度学习算法及代码实现 Python3机器学习 传统机器学习算法 决策树.K邻近算法.支持向量机.朴素贝叶斯.神经网络.Logistic回归算法,聚类等. 一.机器学习算法及代 ...

  6. 最全的机器学习&深度学习入门视频课程集

    资源介绍 链接:http://pan.baidu.com/s/1kV6nWJP 密码:ryfd     链接:http://pan.baidu.com/s/1dEZWlP3 密码:y82m 更多资源 ...

  7. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  8. 机器学习&深度学习经典资料汇总,data.gov.uk大量公开数据

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. 机器学习&深度学习基础(tensorflow版本实现的算法概述0)

    tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 代码集:https://github.com/ageron/handson-ml 监督学习 1)决策树(Decision Tre ...

随机推荐

  1. C_关于递归算法的几个例子

    1.递归算法的定义: 2.递归与迭代的优劣 eg1:斐波那契数列:斐波那契数列(Fibonacci sequence),又称黄金分割数列.因数学家列昂纳多·斐波那契(Leonardoda Fibona ...

  2. GMA Round 1 YGGDRASIL

    传送门 YGGDRASIL 在YGGDRASIL世界,一年有213天. Demiurge推广种植了一种植物,姑且称之为“黄金果”,它第一期生长需要140天,此后第i期生长需要的天数$a_i$满足$a_ ...

  3. poj2376 Cleaning Shifts(区间贪心,理解题意)

    https://vjudge.net/problem/POJ-2376 题意理解错了!!真是要仔细看题啊!! 看了poj的discuss才发现,如果前一头牛截止到3,那么下一头牛可以从4开始!!! # ...

  4. Java知识回顾 (10) 线程

    再次声明,正如(1)中所描述的,本资料来自于runoob,略有修改. 一条线程指的是进程中一个单一顺序的控制流,一个进程中可以并发多个线程,每条线程并行执行不同的任务. Java 给多线程编程提供了内 ...

  5. Lua MD5加密字符串

    function md5_sumhexa(k) local md5_core = require "md5.core" k = md5_core.sum(k) return (st ...

  6. 转: 关于linux用户时间与系统时间的说明

    写的很不错的一篇. https://blog.csdn.net/mmshixing/article/details/51307853

  7. python pip 安装包报 编码问题

    好久不玩 TF 了, 今天尝试了一个案例,发现要安装module , 就搞了一下, 发现要先安装 base , 安装过程有遇到好多问题, 就写写, 将其中解决过程记录下来. 1. 保存,编码问题 Un ...

  8. 深入理解Java String类(综合)

    在Java语言了中,所有类似“ABC”的字面值,都是String类的实例:String类位于java.lang包下,是Java语言的核心类,提供了字符串的比较.查找.截取.大小写转换等操作:Java语 ...

  9. 安装Docker,Asp.net core

    升级项目到.NET Core 2.0,在Linux上安装Docker,并成功部署 Docker从入门到实践 一.安装Docker a).设置Docker仓库 1.按惯例,首先更新Ubuntu的包索引 ...

  10. 构建自己的 Smart Life 私有云(二)-> 连通 IFTTT & Slack

    博客搬迁至https://blog.wangjiegulu.com RSS订阅:https://blog.wangjiegulu.com/feed.xml 原文链接:https://blog.wang ...