人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练
人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练
MXNet 是一个轻量级、可移植、灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 Apache 的孵化器项目。尽管现在已经有很多深度学习框架,包括 TensorFlow, Keras, Torch,以及 Caffe,但 Apache MXNet 因其对多 GPU 的分布式支持而越来越受欢迎。
环境准备
1.安装 Anaconda。Anaconda 是一个用于科学计算的 Python 发行版,提供了包管理与环境管理的功能。Anaconda 利用 conda 来进行 package 和 environment 的管理,并且已经包含了 Python 和相关的配套工具。
Anaconda3-4.4下载地址: https://repo.continuum.io/archive/Anaconda3-4.4.0-Windows-x86_64.exe
2.在 conda 下安装 pip,安装命令为‘conda install pip’
3.安装 OpenCV-python 库。OpenCV-python 是一个很强大的计算机视觉库,在这个项目中可以用于处理图像。使用‘pip install openvc-python’在 Anaconda 环境下安装 OpenCV。也可以从源文件进行编译(注意:conda 安装 opencv3.0 不能运行)。
4.安装 scikit learn,一个开源的 python 机器学习科学计算库,它将用于对数据进行预处理。安装命令为‘conda install scikit-learn’。
5.安装 Jupyter Notebook,安装命令为‘conda install jupyter notebook’。
6.安装 MXNet。安装命令为‘pip install mxnet’。
------------------
数据库
使用的数据库是德国交通标志识别基准,来自论文《德国交通标志识别基准:多类别分类竞赛》( J. Stallkamp, M. Schlipsing, J. Salmen, and C. Igel. "The German Traffic Sign Recognition Benchmark: A multi-class classification competition." ),发表在 IEEE International Joint Conference on Neural Networks,2011。该数据集包含 39209 张训练样例和 12630 张测试样例,有 43 种不同的交通标志——停车标志,限速标志,各种警示标志以及其他标志。
数据库中的每张图像大小为 32×32,均为三通道彩色图。每幅图属于一种交通标志。图像种类标签由 0 到 42 的整数表示。
从一个 NumPy 阵列中下载数据,数据分为训练,验证和测试集。训练集包含 39209 张大小为 32×32,通道数为 3 的图像,所以 NumPy 阵列的维度为 39209×32×32×3。该项目中作者仅使用了训练集和验证集。作者将使用网上的真实图像来测试所构建的模型。X_train 存储图像,维度为 39209×32×32×3。Y_train 存储图像对应的类标,维度为 39209,包含 0-42 的整数,对应每张图的类标。
训练过程
1. 准备数据集
X_train 和 Y_train 组成了训练数据集。可以使用 scikit-learn 对训练数据集进行分割得到验证集,这样可以避免使用出现过的图片测试模型。代码如下:

2. 训练数据预处理
批训练
神经网络训练需要花费大量时间和内存。所以作者将数据分批训练,一批大小为 64. 不仅是为了让数据适应内存,而且它可以让 MXNet 尽量利用 GPU 的计算效率。
归一化
除此之外,图像的像素值也进行了归一化,可以使学习算法更快收敛。下面是对训练数据进行预处理的代码:

3. 构建深度网络
目前,对于图像识别这类处在探索研究热点的问题,学界已经设计了很多效果良好的网络结构。所以最好的方法是实现一个已经发表出来的网络结构,然后对其进行改进。基于 AlexNet 结构,构建了一个简化版的卷积神经网络。AlexNet 是 2012 年发表的一个经典网络,在当年取得了 ImageNet 的最好成绩。

网络共有 8 层,其中前 5 层是卷积层,后边 3 层是全连接层,在每一个卷积层中包含了激励函数 RELU 以及局部响应归一化(LRN)处理,然后再经过池化(max pooling),最后的一个全连接层的输出是具有 1000 个输出的 softmax 层,最后的优化目标是最大化平均的多元逻辑回归。
在此之后也有很多更优秀的网络结构被提出,例如 VGGNet 和 ResNet,大家可以选择更好的网络结构去实现。
由于 MXNet 的符号计算构架,该神经网络的代码十分简洁明了
4. 训练网络
训练 epoch 为 10,训练好的模型存在 JSON 文件中,并且可以通过测量训练和验证准确率来观测网络“学习”的情况。

5. 载入预训练模型
下面给出了加载第 10 个 epoch 模型(最终模型)的代码。由于将在单张图片上进行测试,所以批尺寸由 64 减到 1,数据维度也变成了 1×3×32×32。
测试过程
测试图像(32×32×3)样例:

从结果可以看出可能性最高的种类为停车标志,说明预测准确。如果需要对模型有一个更完整的衡量,还需要用测试数据库进行测试,得到最终的分类准确率。
总结
本文我们介绍了使用 MXNet 进行多目标分类任务的方法。使用 MXNet,在 AlexNet 的结构基础上构建了一个更为简单的卷积神经网络结构。网络由卷积层,激活函数层,池化层和全连接层组成,采用德国交通标志图像训练数据库对该网络进行训练,实验结果证明网络可以将交通标志进行正确的分类。介绍了如何使用 MXNet 对数据进行预处理,构建网络,以及如何加载预训练好的网络模型。可以看出,MXNet 因其在多 GPU 上进行并行训练的能力,以及网络模型构建简单灵活的特性,是一个十分优秀的深度学习框架。
人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练的更多相关文章
- Tensorflow 实战Google深度学习框架 第五章 5.2.1Minister数字识别 源代码
import os import tab import tensorflow as tf print "tensorflow 5.2 " from tensorflow.examp ...
- 吴裕雄--天生自然python Google深度学习框架:TensorFlow实现神经网络
http://playground.tensorflow.org/
- 深度学习-使用cuda加速卷积神经网络-手写数字识别准确率99.7%
源码和运行结果 cuda:https://github.com/zhxfl/CUDA-CNN C语言版本参考自:http://eric-yuan.me/ 针对著名手写数字识别的库mnist,准确率是9 ...
- TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤
一.TensorFlow实战Google深度学习框架学习 1.步骤: 1.定义神经网络的结构和前向传播的输出结果. 2.定义损失函数以及选择反向传播优化的算法. 3.生成会话(session)并且在训 ...
- TensorFlow实战Google深度学习框架-人工智能教程-自学人工智能的第二天-深度学习
自学人工智能的第一天 "TensorFlow 是谷歌 2015 年开源的主流深度学习框架,目前已得到广泛应用.本书为 TensorFlow 入门参考书,旨在帮助读者以快速.有效的方式上手 T ...
- 转:TensorFlow和Caffe、MXNet、Keras等其他深度学习框架的对比
http://geek.csdn.net/news/detail/138968 Google近日发布了TensorFlow 1.0候选版,这第一个稳定版将是深度学习框架发展中的里程碑的一步.自Tens ...
- [Tensorflow实战Google深度学习框架]笔记4
本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...
- Reading | 《TensorFlow:实战Google深度学习框架》
目录 三.TensorFlow入门 1. TensorFlow计算模型--计算图 I. 计算图的概念 II. 计算图的使用 2.TensorFlow数据类型--张量 I. 张量的概念 II. 张量的使 ...
- 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)
学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...
随机推荐
- finecms指定从第几篇文章开始调用5条记录,并调用文章所在栏目
我们在建站时可能会有具体的要求,比如从第几篇文章开始调用5篇,finecms如何实现呢?用下面一段代码就能完成:num=0,5表示从第一篇开始调用5篇,注意,0代表第一,5表示调用5篇: <a ...
- OSError:[Errno 13] Permission denied:'my_library' 问题解决方法
出现问题: 执行 rosrun rosserial_windows make_libraries.py my_library 命令时出现OSError:[Errno 13] Permission de ...
- echarts给数据视图添加表格样式
1,准备好样式 <style>.myTable {margin: 0 auto;/* height: 300px; */width: 700px;} .myTitle {backgroun ...
- 列表 list 容器类型数据(str字符串, list列表, tuple元组, set集合, dict字典)--->元组 tuple-->字符串 str
# ### 列表 list 容器类型数据(str字符串, list列表, tuple元组, set集合, dict字典) # (1)定义一个列表 listvar = [] print(listvar, ...
- WCG distribution of byteball
https://steemit.com/byteball/@punqtured/processing-for-good-get-rewarded-for-crunching-numbers-to-cu ...
- 基于jquery ajax的多文件上传进度条
效果图 前端代码,基于jquery <!DOCTYPE html> <html> <head> <title>主页</title> < ...
- Dotfuscator 使用图解教程
Dotfuscator:是.NET混淆器和压缩器,它可以帮助您防止您的应用程序被反编译.同时,它还可以使得您的应用程序更加小巧以及高效.我用的是4.9版本的Dotfuscator,Dotfuscato ...
- MySQL数据类型--与MySQL零距离接触 3-2 外键约束的要求解析
列级约束:只针对某一个字段 表级约束:约束针对2个或2个以上的字段 约束类型是按功能来划分. 外键约束:保持数据一致性,完整性.实现数据表的一对一或一对多的关系.这就是把MySQL称为关系型数据库的根 ...
- vim自动格式化
,gg 跳转到第一行 ,shift+v 转到可视模式 ,shift+g 全选 ,按下神奇的 = 你会惊奇的发现代码自动缩进了,呵呵,当然也可能是悲剧了.
- sqli-labs(十七)
第五十四关: 这关大概意思就是尝试次数不能多于十次,必须十次之类查询处特点的key. 第一次:输入单引号报错 第二次:输入双引号不报错 说明后台是单引号进行的拼凑 第三步:这里应该是判断列,用orde ...