TensorFlow实现多层感知机MINIST分类

TensorFlow 支持自动求导,可以使用 TensorFlow 优化器来计算和使用梯度。使用梯度自动更新用变量定义的张量。本文将使用 TensorFlow 优化器来训练网络。



前面定义了层、权重、损失、梯度以及通过梯度更新权重。用公式实现可以帮助我们更好地理解,但随着网络层数的增加,这可能非常麻烦。



使用 TensorFlow 的一些强大功能,如 Contrib(层)来定义神经网络层及使用 TensorFlow 自带的优化器来计算和使用梯度。



通过前面的学习,已经知道如何使用 TensorFlow 的优化器。Contrib 可以用来添加各种层到神经网络模型,如添加构建块。这里使用的一个方法是
tf.contrib.layers.fully_connected,在 TensorFlow 文档中定义如下:

这样就添加了一个全连接层。

提示:上面那段代码创建了一个称为权重的变量,表示全连接的权重矩阵,该矩阵与输入相乘产生隐藏层单元的张量。如果提供了 normalizer_fn(比如batch_norm),那么就会归一化。否则,如果 normalizer_fn 是 None,并且设置了 biases_initializer,则会创建一个偏置变量并将其添加到隐藏层单元中。最后,如果 activation_fn 不是 None,它也会被应用到隐藏层单元。

具体做法

第一步是改变损失函数,尽管对于分类任务,最好使用交叉熵损失函数。这里继续使用均方误差(MSE):



接下来,使用 GradientDescentOptimizer:



对于同一组超参数,只有这两处改变,在测试数据集上的准确率只有
61.3%。增加 max_epoch,可以提高准确性,但不能有效地发挥 TensorFlow 的能力。



这是一个分类问题,所以最好使用交叉熵损失,隐藏层使用 ReLU 激活函数,输出层使用 softmax 函数。做些必要的修改,完整代码如下所示:



解读分析

修改后的 MNIST MLP 分类器在测试数据集上只用了一个隐藏层,并且在 10 个 epoch 内,只需要几行代码,就可以得到 96% 的精度:

由此可见 TensorFlow 的强大之处。

TensorFlow实现多层感知机MINIST分类的更多相关文章

  1. TensorFlow实现多层感知机函数逼近

    TensorFlow实现多层感知机函数逼近 准备工作 对于函数逼近,这里的损失函数是 MSE.输入应该归一化,隐藏层是 ReLU,输出层最好是 Sigmoid. 下面是如何使用 MLP 进行函数逼近的 ...

  2. TensorFlow基础笔记(2) minist分类学习

    (1) 最简单的神经网络分类器 # encoding: UTF-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist i ...

  3. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  4. TensorFlow学习笔记7-深度前馈网络(多层感知机)

    深度前馈网络(前馈神经网络,多层感知机) 神经网络基本概念 前馈神经网络在模型输出和模型本身之间没有反馈连接;前馈神经网络包含反馈连接时,称为循环神经网络. 前馈神经网络用有向无环图表示. 设三个函数 ...

  5. 『TensorFlow』读书笔记_多层感知机

    多层感知机 输入->线性变换->Relu激活->线性变换->Softmax分类 多层感知机将mnist的结果提升到了98%左右的水平 知识点 过拟合:采用dropout解决,本 ...

  6. TensorFlow实现自编码器及多层感知机

    1 自动编码机简介        传统机器学习任务在很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难 ...

  7. Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理

    前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ...

  8. TensorFlow多层感知机函数逼近过程详解

    http://c.biancheng.net/view/1924.html Hornik 等人的工作(http://www.cs.cmu.edu/~bhiksha/courses/deeplearni ...

  9. [ DLPytorch ] 线性回归&Softmax与分类模型&多层感知机

    线性回归 基础知识 实现过程 学习笔记 批量读取 torch_data = Data.TensorDataset(features, labels) dataset = Data.DataLoader ...

随机推荐

  1. Mysql 8.0安装

    1. 下载安装包至/usr/local目录下 下载地址:https://cdn.mysql.com/Downloads/MySQL-8.0/mysql-8.0.16-el7-x86_64.tar.gz ...

  2. Win64 驱动内核编程-6.内核里操作注册表

    内核里操作注册表 RING0 操作注册表和 RING3 的区别也不大,同样是"获得句柄->执行操作->关闭句柄"的模式,同样也只能使用内核 API 不能使用 WIN32 ...

  3. Python练习3-XML-RPC实现简单的P2P文件共享

    XML-RPC实现简单的P2P文件共享 先来个百度百科: XML-RPC的全称是XML Remote Procedure Call,即XML(标准通用标记语言下的一个子集)远程过程调用.它是一套允许运 ...

  4. 【maven和jdk】报错:系统找不到指定的文件

    创建一个maven项目出错 问题描述 在idea.log出现如下错误(系统找不到指定的文件,但是不知道指定文件是什么) com.intellij.execution.process.ProcessNo ...

  5. Java发送邮件报错:com.sun.mail.util.LineOutputStream.<init>(Ljava/io/OutputStream;Z)V

    在练习使用Java程序发送邮件的代码 运行出现了com.sun.mail.util.LineOutputStream.<init>(Ljava/io/OutputStream;Z)V报错信 ...

  6. postman Variables变量的详解与应用

    变量 变量类型(按照作用域划分) 全局变量(全局环境里面的变量) 集合变量(请求集合里声明的变量) 自定义环境变量 数据变量(在runner时文件变量) 本地变量 变量权重类型 全局变量 < 集 ...

  7. hdu - 1716 排列2 (使用set对全排列结果去重)

    题意很简单,只是有几个细节要注意,首先就是一次只是输入四个数字.输出结果要从小到大(进行全排列之前要进行排序).题目要求千位数相同的在一行,中间使用空格隔开(第二次在输出的时候判断上一次记录的千位数是 ...

  8. C++ primer plus读书笔记——第13章 类继承

    第13章 类继承 1. 如果购买厂商的C库,除非厂商提供库函数的源代码,否则您将无法根据自己的需求,对函数进行扩展或修改.但如果是类库,只要其提供了类方法的头文件和编译后的代码,仍可以使用库中的类派生 ...

  9. OOP第一章总结

    经过了三周的OO,尽管过程不太轻松,但是有所得还是值得欣慰的事! (1)程序结构 第一次作业: UML类图如下,第一次作业在结构上并没有太多面向对象的思想,只是简单的分类,一个运行类,两个对象类,预处 ...

  10. Go - 开箱即用,WEB 界面一键安装,没有项目经验,可以拿这个练手

    安装界面 启动程序之后,会在浏览器中自动打开安装界面. 因为程序会使用到 Redis 和 MySQL,所以安装前请输入 Redis.MySQL 配置信息,点击初始化按钮,会将用到的数据表和默认数据进行 ...