TensorFlow实现多层感知机MINIST分类
TensorFlow实现多层感知机MINIST分类
TensorFlow 支持自动求导,可以使用 TensorFlow 优化器来计算和使用梯度。使用梯度自动更新用变量定义的张量。本文将使用 TensorFlow 优化器来训练网络。
前面定义了层、权重、损失、梯度以及通过梯度更新权重。用公式实现可以帮助我们更好地理解,但随着网络层数的增加,这可能非常麻烦。
使用 TensorFlow 的一些强大功能,如 Contrib(层)来定义神经网络层及使用 TensorFlow 自带的优化器来计算和使用梯度。
通过前面的学习,已经知道如何使用 TensorFlow 的优化器。Contrib 可以用来添加各种层到神经网络模型,如添加构建块。这里使用的一个方法是
tf.contrib.layers.fully_connected,在 TensorFlow 文档中定义如下:

这样就添加了一个全连接层。
提示:上面那段代码创建了一个称为权重的变量,表示全连接的权重矩阵,该矩阵与输入相乘产生隐藏层单元的张量。如果提供了 normalizer_fn(比如batch_norm),那么就会归一化。否则,如果 normalizer_fn 是 None,并且设置了 biases_initializer,则会创建一个偏置变量并将其添加到隐藏层单元中。最后,如果 activation_fn 不是 None,它也会被应用到隐藏层单元。
具体做法
第一步是改变损失函数,尽管对于分类任务,最好使用交叉熵损失函数。这里继续使用均方误差(MSE):

接下来,使用 GradientDescentOptimizer:

对于同一组超参数,只有这两处改变,在测试数据集上的准确率只有
61.3%。增加 max_epoch,可以提高准确性,但不能有效地发挥 TensorFlow 的能力。
这是一个分类问题,所以最好使用交叉熵损失,隐藏层使用 ReLU 激活函数,输出层使用 softmax 函数。做些必要的修改,完整代码如下所示:

解读分析
修改后的 MNIST MLP 分类器在测试数据集上只用了一个隐藏层,并且在 10 个 epoch 内,只需要几行代码,就可以得到 96% 的精度:

由此可见 TensorFlow 的强大之处。
TensorFlow实现多层感知机MINIST分类的更多相关文章
- TensorFlow实现多层感知机函数逼近
TensorFlow实现多层感知机函数逼近 准备工作 对于函数逼近,这里的损失函数是 MSE.输入应该归一化,隐藏层是 ReLU,输出层最好是 Sigmoid. 下面是如何使用 MLP 进行函数逼近的 ...
- TensorFlow基础笔记(2) minist分类学习
(1) 最简单的神经网络分类器 # encoding: UTF-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist i ...
- gluon 实现多层感知机MLP分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...
- TensorFlow学习笔记7-深度前馈网络(多层感知机)
深度前馈网络(前馈神经网络,多层感知机) 神经网络基本概念 前馈神经网络在模型输出和模型本身之间没有反馈连接;前馈神经网络包含反馈连接时,称为循环神经网络. 前馈神经网络用有向无环图表示. 设三个函数 ...
- 『TensorFlow』读书笔记_多层感知机
多层感知机 输入->线性变换->Relu激活->线性变换->Softmax分类 多层感知机将mnist的结果提升到了98%左右的水平 知识点 过拟合:采用dropout解决,本 ...
- TensorFlow实现自编码器及多层感知机
1 自动编码机简介 传统机器学习任务在很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难 ...
- Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理
前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ...
- TensorFlow多层感知机函数逼近过程详解
http://c.biancheng.net/view/1924.html Hornik 等人的工作(http://www.cs.cmu.edu/~bhiksha/courses/deeplearni ...
- [ DLPytorch ] 线性回归&Softmax与分类模型&多层感知机
线性回归 基础知识 实现过程 学习笔记 批量读取 torch_data = Data.TensorDataset(features, labels) dataset = Data.DataLoader ...
随机推荐
- 解决小程序中Data.parse()获取时间戳IOS不兼容
由于与后台接口必须对比时间戳所以首先得前台获取时间戳.刚开始是获取手机本地时间,但用户改了时间就废了..... 后来就从服务器上获取个时间再转换为时间戳(是不是很操蛋,先从服务器上获取在TM的自己比较 ...
- 【转】如何用MTR诊断网络问题
MTR 是一个强大的网络诊断工具,管理员能够用它诊断和隔离网络错误,并向上游提供商提供有关网络状态的有用报告.MTR 通过更大的采样来跟踪路由,就像 traceroute + ping 命令的组合.本 ...
- SSM框架MavenWeb项目的测试
由于SSM项目的类都是由Spring容器托管,所以直接进行用new对象调用方法进行测试是不行不通的,会出现空指针异常NullPointExpection. 因为我们的对象由spring进行托管,调用的 ...
- 自动化测试面试官:登录或注册时有验证码怎么处理?OCR图像识别技术大揭秘!
本节大纲 读取cookie实现免登陆 pytesseract+tesseract-ocr实现图像识别 Pillow库对验证码截图 API接口实现图像识别 今天的这个技术点,为什么要给大家分享一下呢? ...
- 认识WPF
新开一节WPF桌面开发的讲解,这节先初步认识一下什么是WPF. 1.简介 WPF是 Windows Presentation Foundation 的英文缩写,意为"窗体呈现基础" ...
- 初识Vue2(一):表单输入绑定(附Demo)
在线演示 http://demo.xiongze.net/ 下载地址 https://gitee.com/xiongze/Vue2.git js引用 <!--这里可以自己下载下来引用,也可以使用 ...
- Uva 642 - Word Amalgamation sort qsort
Word Amalgamation In millions of newspapers across the United States there is a word game called J ...
- Scrum Meeting 3
Basic Info where:三号教学楼 when:2020/4/27 target: 简要汇报一下已完成任务,下一步计划与遇到的问题 Progress Team Member Position ...
- 在ZOHO企业网盘中如何快速搜索文件?
现在越来越多的企业采用企业网盘来存储文档和资料,而且现在市面上的企业网盘各种各样.在使用企业网盘过程中,很多用户会问到企业网盘中如何快速搜索文件的问题.但是无论是"标签"功能还是普 ...
- 屌炸天的3D引擎OpenCASCADE的用法及案例(转载之处:)
What CASCADE? Open CASCADE(简称OCC)平台是由法国Matra Datavision公司开发的CAD/CAE/CAM软件平台,可以说是世界上最重要的几何造型基础软件平台之一. ...