深度学习 CNN CUDA 版本2
作者:zhxfl
邮箱:zhxfl##mail.ustc.edu.cn
主页:http://www.cnblogs.com/zhxfl/p/4155236.html
第1个版本blog在这里:http://www.cnblogs.com/zhxfl/p/4134834.html
第2个版本github:https://github.com/zhxfl/CUDA-CNN
欢迎fork,在第一个版本的时候,我们只是针对手写数字,也就是黑白图片。在第二个版本中,我加入了很多东西。
第二个版本的特性
1、支持rgb图片格式和rgbd图片格式(带有深度信息的图片)训练,带有深度信息的图片可以来源于Kinect。
参考论文Anddrew Y.Ng的论文:Convolutional-Recursive Deep Learning for 3D Object Classification,你可以找到对应的带有深度信息的数据集。
4d的图片不是这个版本的主要目的,但是你确实可以用这个代码来训练4D的数据集。(我在不久的未来会让这个版本更好的支持4D数据的训练)
2、第二个比较突出的特性是你可以看到配置文件的参数更加复杂了,我会对所有参数一一做说明。
#Comment# IS_GRADIENT_CHECKING = false; #is true when debug#
BATCH_SIZE = 100; #test image size should be divided with no remainder#
NON_LINEARITY = NL_RELU; #NON_LINEARITY CAN = NL_SIGMOID , NL_TANH , NL_RELU#
CHANNELS = 3; #1, 3, 4#
CROP = 0.0; #0<= crop <=imgSize#
SCALE = 0.0; #ImgSize from -13.0 to 13.0#
ROTATION = 0.0; #angle from -13.0 to 13.0#
DISTORTION = 0.0; #just for mnist#
SHOWIMAGE = false; #show the images after transformation# [
LAYER = CONV;
KERNEL_SIZE = 5;
KERNEL_AMOUNT = 7;
WEIGHT_DECAY = 1e-6;
POOLING_DIM = 2;
] [
LAYER = CONV;
KERNEL_SIZE = 5;
KERNEL_AMOUNT = 9;
WEIGHT_DECAY = 1e-6;
POOLING_DIM = 2;
] [
LAYER = FC;
NUM_HIDDEN_NEURONS = 256;
WEIGHT_DECAY = 1e-6;
DROPOUT_RATE = 0.5;
] [
LAYER = FC;
NUM_HIDDEN_NEURONS = 256;
WEIGHT_DECAY = 1e-6;
DROPOUT_RATE = 0.5;
] [
LAYER = SOFTMAX;
NUM_CLASSES = 10;
WEIGHT_DECAY = 1e-6;
]
1)IS_GRADIENT_CHECKING 这是一个debug选项(其原理可以参考斯坦福深度学习的教程)。如果你修改了代码,建议你设置为true。你必须确保(g(s + delta) - g(s - delta)) / 2 约等于g(s)。他可以辅助你判断目前的代码是否存在bug。
2)BASH_SIZE,我们训练的方法是mini-batch,这个数值的设置对于收敛的结果和速度都是有影响的。建议可以尝试50,100,150,200等,你会得到不同的试验结果。 接下来的几个参数都是用来克服overfitting的,对于深度学习而言,训练样本越多,效果会越好。所以我们对于训练数据必须加以扩展。 3)CROP是裁剪参数,假设图像大小为ImgSize,那么是最终训练的数据应该是ImgSize-CROP,裁剪的窗口起点是随机的,也就是一张图片已经变成了CROP*CROP张图片了。
4)ROTATION是旋转,这步操作对于手写数字非常有效,但是你必须确保旋转的角度不要过大,比如13度,那么最后代码训练的图片都会被随机的旋转角度[-13,13],这是一个区间。
5)DISTORTION又称为畸变,这个也比较适合手写数字,参数越大,图片变化越大,从大量实验看,针对手写数据集,设置为3.4是比较合适的,原理参考论文Best Practices for Convolutional Neural Networks Applied to Visual Document Analysis
6)SHOWIMAGE这是一个debug选项,3)-5)都是对图片做一些变化,如果你想知道变化的效果,那么可以把这个参数设置为true,这样你就可以看到变化之后的效果。方便你更好的调整3)-5)这些参数。 目前试验结果
1、对CIFAR-10数据集进行了比较短时间的训练(没有对数据进行变化),测试准确率是81.37%,接近于https://code.google.com/p/cuda-convnet/ 的初步结果,这样一个试验结果已经足够说明代码的正确性了。
我最初的代码是参考http://eric-yuan.me/cnn3/,Eric加入了不少东西,但是针对CIFAR-10他只是得到了71%的正确率,我能够等到更高的正确率归功于CUDA加速,使得我可以设置规模更大的网络,仅此而已。
当然,在我的第三个大版本中,我会确保针对cifar-10数据集,我能够得到接近于所有公开结果中最好的实验结果。
2、针对mnist数据集,依然可以轻易的实现99%以上的正确率。 第3个版本的主要任务。
1、在实现第二个版本的时候,我fix了大量的bug,你要清楚,一个大型项目不可能没有bug的,只要他不影响工作,目前从试验效果看,第二个版本已经稳定了。
2、目前我的网络结构依然太单一了,第3个版本的核心任务就是加入如下两个特性:
1)参考Notes on Convolutional Neural Networks第3.3节,Learning Conbinations of Feature Maps。
2)参考ImageNet Classification with Deep Convolutional Neural Networks第3.3节,Local Response Normalization。
这两个特性是非常重要,可以非常显著提升数据集CIFAR-10的准确率,你会在第3个版本看到这两个特性,并且通过配置文件决定是否使用它们进行训练(因为针对mnist你并不需要这么复杂的特性,加入会降低运算效率)。
深度学习 CNN CUDA 版本2的更多相关文章
- 深度学习-CNN+RNN笔记
以下叙述只是简单的叙述,CNN+RNN(LSTM,GRU)的应用相关文章还很多,而且研究的方向不仅仅是下文提到的1. CNN 特征提取,用于RNN语句生成图片标注.2. RNN特征提取用于CNN内容分 ...
- 深度学习-使用cuda加速卷积神经网络-手写数字识别准确率99.7%
源码和运行结果 cuda:https://github.com/zhxfl/CUDA-CNN C语言版本参考自:http://eric-yuan.me/ 针对著名手写数字识别的库mnist,准确率是9 ...
- 深度学习——CNN
整理自: https://blog.csdn.net/woaidapaopao/article/details/77806273?locationnum=9&fps=1 思想 filter尺寸 ...
- 小刘的深度学习---CNN
前言: 前段时间我在树莓派上通过KNN,SVM等机器学习的算法实现了门派识别的项目,所用到的数据集是经典的MNIST.可能是因为手写数字与印刷体存在一些区别,识别率并是很不高.基于这样的情况,我打算在 ...
- 经典深度学习CNN总结 - LeNet、AlexNet、GoogLeNet、VGG、ResNet
参考了: https://www.cnblogs.com/52machinelearning/p/5821591.html https://blog.csdn.net/qq_24695385/arti ...
- 深度学习-CNN tensorflow 可视化
tf.summary模块的简介 在TensorFlow中,最常用的可视化方法有三种途径,分别为TensorFlow与OpenCv的混合编程.利用Matpltlib进行可视化.利用TensorFlow自 ...
- python数据可视化、数据挖掘、机器学习、深度学习 常用库、IDE等
一.可视化方法 条形图 饼图 箱线图(箱型图) 气泡图 直方图 核密度估计(KDE)图 线面图 网络图 散点图 树状图 小提琴图 方形图 三维图 二.交互式工具 Ipython.Ipython not ...
- 深度学习-theano-windows -cuda-环境搭建
本文将具体介绍深度学习之cuda的环境搭建 工具:支持CUDA的显卡(安装cuda6.5),VS2013.Anaconda. 步骤: 1.安装cuda6.5 这个不具体介绍,网上有很多文章.注意选择你 ...
- win10+anaconda+cuda配置dlib,使用GPU对dlib的深度学习算法进行加速(以人脸检测为例)
在计算机视觉和机器学习方向有一个特别好用但是比较低调的库,也就是dlib,与opencv相比其包含了很多最新的算法,尤其是深度学习方面的,因此很有必要学习一下.恰好最近换了一台笔记本,内含一块GTX1 ...
随机推荐
- [转载]jquery tmpl使用方法
动态请求数据来更新页面是现在非常常用的方法,比如博客评论的分页动态加载,微博的滚动加载和定时请求加载等. 这些情况下,动态请求返回的数据一般不是已拼好的 HTML 就是 JSON 或 XML,总之不在 ...
- 关闭MyEclipse的Quick Update
关闭MyEclipse的Quick Update, Windows > Preferences > MyEclipse > Community Essentials, 把选项 &qu ...
- ***CI异常记录到日志:CodeIgniter中设计一个全局exception hook
在CodeIgniter中,当发生异常时,经常要通知系统管理员,因此有必要在全局的高度上 捕捉异常,因此可以写一个hook, 比如在config目录的hook.php中,加入: $hook['pre_ ...
- MAT使用总结
最近在做项目的时候遇到一个内存泄漏,最后通过MAT定位了问题, 先介绍一下MAT的一些基本概念: Shallow Heap:对象本身占用内存的大小,不包含对其他对象的引用,也就是对象头加成员变量(不是 ...
- 115. Distinct Subsequences
题目: Given a string S and a string T, count the number of distinct subsequences of T in S. A subseque ...
- SGU 130
SGU130,用k条弦将一个圆分成k+1份的方法数. #include <iostream> #include <vector> #include <string> ...
- 对C#中的web访问mysql数据库的一些知识点进行了整理归纳总结
基本对比 使用方式 使用场合 优缺点 是否需要安装 需要的dll网址 引用方式 程序内引用 程序初期确定使用MySql,前期添加引用 大多数情况下使用在类文件内,多数使用于aspx,ashx等带有后置 ...
- c#自带压缩类实现数据库表导出到CSV压缩文件的方法
在导出大量CSV数据的时候,常常体积较大,采用C#自带的压缩类,可以方便的实现该功能,并且压缩比例很高,该方法在我的开源工具DataPie中已经经过实践检验.我的上一篇博客<功能齐全.效率一流的 ...
- IMX51+WINCE6.0平台缩写意义
1.以EPIT为例 EPIT(Enhanced Periodic Interrupt Timer)为增强型周期中断定时器,其中有CR控制寄存器,要设置CR寄存器的SWR位,代码如下: // Asser ...
- bzoj1856
这是一道无比涨姿势的题目 首先总结一下这种输入几个数的题目, 一般不是递推就是数学题 显然,这道题用递推是无法做到O(n)的复杂度的 那我们就考虑这是一道数学题了 我已开始纠结在正向思维了,正向求好像 ...