利用 TFLearn 快速搭建经典深度学习模型

使用 TensorFlow 一个最大的好处是可以用各种运算符(Ops)灵活构建计算图,同时可以支持自定义运算符(见本公众号早期文章《TensorFlow 增加自定义运算符》)。由于运算符的粒度较小,在构建深度学习模型时,代码写出来比较冗长,比如实现卷积层:5, 9

这种方式在设计较大模型时会比较麻烦,需要程序员徒手完成各个运算符之间的连接,像一些中间变量的维度变换、运算符参数选项、多个子网络连接处极易发生问题,肉眼检查也很难发现代码中潜伏的 bug,会导致运行时出错(运气好),或者运行时不出错但运行结果不可解释(运气不好),消耗大量时间和精力。

有没有更好的实现各种经典模型的方式?

答案是肯定的!

我们今天学习一下在 TensorFlow 之上构建的高层次 API—— TFLearn【2】。

TFLearn 是一个模块化和透明的深度学习库,构建在 TensorFlow 之上。

它为 TensorFlow 提供高层次 API,目的是便于快速搭建试验环境,同时保持对 TensorFlow 的完全透明和兼容性。

TFLearn 的一些特点:

  • 容易使用和易于理解的高层次 API 用于实现深度神经网络,附带教程和例子;

  • 通过高度模块化的内置神经网络层、正则化器、优化器等进行快速原型设计;

  • 对 TensorFlow 完全透明,所有函数都是基于 tensor,可以独立于 TFLearn 使用;

  • 强大的辅助函数,训练任意 TensorFlow 图,支持多输入、多输出和优化器;

  • 简单而美观的图可视化,关于权值、梯度、特征图等细节;

  • 无需人工干预,可使用多 CPU、多 GPU;

  • 高层次 API 目前支持最近大多数深度学习模型,像卷积网络、LSTM、BiRNN、BatchNorm、PReLU、残差网络、生成网络、增强学习…… 将来会一直更新最近的深度学习技术;

心动不如行动,我们马上就体验!在一台已经安装了 TensorFlow 的机器上(安装步骤参考之前文章《TensorFlow 1.0.0rc1 入坑记》《利用 TensorFlow 集装箱快速搭建交互式开发环境》《如何在 Windows 系统玩 TensorFlow》)直接运行以下命令:pip in

检查安装成功:

为了方便运行 TFLearn 附带例程,我们需要克隆 TFLearn 源码:h

先看看如何用 TFLearn 实现 AlexNet 用于 Oxford 17 类鲜花数据集分类任务的:

上图为论文【1】 中的 AlexNet 结构。

TFLearn 例程中实现的 AlexNet 和论文【1】中相比做了一些修改:

  • 输入图像尺寸变为 227 x 227;

  • 将 2-tower 架构改为 single-tower;

  • 最后一个分类层的输出类别数从 1000 变为 17;

运行该例程:

该程序会自动下载 Oxford 17 flowers 数据集, 选了几个不同类别图片如下:

运行 AlexNet 模型训练截图如下:

在另一个命令行窗口启动 TensorBoard:

打开浏览器,输入地址:localhost:6006,打开 TensorBoard 页面,查看训练过程的准确率、loss 值变化:

AlexNet 模型可视化(之一)

(之二)

模型权值分布:

模型权值的直方图,可以看出权值训练历史:

通过今天内容,读者可以看出使用 TFLearn 高层次 API 相比直接使用 TensorFlow 实现深度学习模型具有使用更简单、构建更快速、可视化更方便等特点,从此无需手动处理各个运算符之间的连接,解放了生产力,提高了模型设计和优化效率。

作为练习,读者可以进一步学习 TFLearn 实现其他经典深度学习模型如 VGG、Inception、NIN、ResNet 等,对比原始论文学习,相信会有更大的收获。

参考文献

【1】Alex Krizhevsky, Ilya Sutskever & Geoffrey E. Hinton. ImageNet Classification with Deep Convolutional Neural Networks. NIPS, 2012.

【2】 http://tflearn.org/

利用 TFLearn 快速搭建经典深度学习模型的更多相关文章

  1. Roofline Model与深度学习模型的性能分析

    原文链接: https://zhuanlan.zhihu.com/p/34204282 最近在不同的计算平台上验证几种经典深度学习模型的训练和预测性能时,经常遇到模型的实际测试性能表现和自己计算出的复 ...

  2. Apple的Core ML3简介——为iPhone构建深度学习模型(附代码)

    概述 Apple的Core ML 3是一个为开发人员和程序员设计的工具,帮助程序员进入人工智能生态 你可以使用Core ML 3为iPhone构建机器学习和深度学习模型 在本文中,我们将为iPhone ...

  3. 用 Java 训练深度学习模型,原来可以这么简单!

    本文适合有 Java 基础的人群 作者:DJL-Keerthan&Lanking HelloGitHub 推出的<讲解开源项目> 系列.这一期是由亚马逊工程师:Keerthan V ...

  4. 利用MONAI加速医学影像学的深度学习研究

    利用MONAI加速医学影像学的深度学习研究 Accelerating Deep Learning Research in Medical Imaging Using MONAI 医学开放式人工智能网络 ...

  5. AI佳作解读系列(一)——深度学习模型训练痛点及解决方法

    1 模型训练基本步骤 进入了AI领域,学习了手写字识别等几个demo后,就会发现深度学习模型训练是十分关键和有挑战性的.选定了网络结构后,深度学习训练过程基本大同小异,一般分为如下几个步骤 定义算法公 ...

  6. 『高性能模型』Roofline Model与深度学习模型的性能分析

    转载自知乎:Roofline Model与深度学习模型的性能分析 在真实世界中,任何模型(例如 VGG / MobileNet 等)都必须依赖于具体的计算平台(例如CPU / GPU / ASIC 等 ...

  7. PyTorch如何构建深度学习模型?

    简介 每过一段时间,就会有一个深度学习库被开发,这些深度学习库往往可以改变深度学习领域的景观.Pytorch就是这样一个库. 在过去的一段时间里,我研究了Pytorch,我惊叹于它的操作简易.Pyto ...

  8. flask部署深度学习模型

    flask部署深度学习模型 作为著名Python web框架之一的Flask,具有简单轻量.灵活.扩展丰富且上手难度低的特点,因此成为了机器学习和深度学习模型上线跑定时任务,提供API的首选框架. 众 ...

  9. CUDA上深度学习模型量化的自动化优化

    CUDA上深度学习模型量化的自动化优化 深度学习已成功应用于各种任务.在诸如自动驾驶汽车推理之类的实时场景中,模型的推理速度至关重要.网络量化是加速深度学习模型的有效方法.在量化模型中,数据和模型参数 ...

随机推荐

  1. POJ 2449Remmarguts' Date 第K短路

    Remmarguts' Date Time Limit: 4000MS   Memory Limit: 65536K Total Submissions: 29625   Accepted: 8034 ...

  2. sqli-labs:7,导入导出;8-10 延时注入

    1,Load_file()导出文件 使用条件: A.必须有权限读取并且文件必须完全可读(and (select count(*) from mysql.user)>0/* 如果结果返回正常,说明 ...

  3. vue引用公用的头部和尾部文件。

    我创建了一个header.vue和fotter.vue,用来做于网站的头部和尾部,每个页面都需要引用这两个,我以组件的方式,来引用这样只需要添加注册的组件就可以了. 第一步.在components文件 ...

  4. Oracle12c的卸载

    之前电脑装了Oracle12c 现在希望删除重新安装: 参照教程: http://jingyan.baidu.com/article/642c9d34e1cbdd644a46f7de.html E:\ ...

  5. web项目目录结构

    eclipse web项目目录结构 按照 Java EE 规范的规定,一个典型的 Web 应用程序有四个部分: 1.  公开目录 ; 2. WEB-INF/web.xml 文件,发布描述符(必选) ; ...

  6. [网络]10M、100M、1000M网线的水晶头接法

    在网络维护过程中经常要自己制作网线,水晶头理论上是这样接的: 10M和100M和1000M以太网在使用网线时,对网线各自有不同的要求. 10M和100M在目前来说,连接网络的时候,只用到两对线来传输网 ...

  7. 使用tensorflow下的GPU加速神经网络训练过程

    下载CUDA8.0,安装 下载cuDNN v5.1安装.放置环境变量等. 其他版本就不装了.不用找其他版本的关系. 使用tensorflow-gpu1.0版本. 使用keras2.0版本. 有提示的. ...

  8. 求一个数的n次幂

    1.当这个数是2的多少次幂: 求(2^m)^n  =  2^(m*n) = 1<<m*n; 2.快速幂(要考虑底数a,和m的正负) int quick_mod(int a,int m){ ...

  9. 实习番外篇:解决C语言使用Makefile无法实现更好的持续集成问题

    工作中遇见的一个问题,提供项目源代码的情况下,希望对项目进行持续集成,达到一个C项目增量编译的效果.原本第一天是想通过模拟Makefile执行步骤来实现整个过程的,但是事实上发现整个Makefile显 ...

  10. 第18章:MongoDB-聚合操作--聚合管道--$sort

    ①$sort 使用“$sort”可以实现排序,设置1表示升序,设置-1表示降序. ②范例:实现排序