多层感知机

输入->线性变换->Relu激活->线性变换->Softmax分类

多层感知机将mnist的结果提升到了98%左右的水平

知识点

过拟合:采用dropout解决,本质是bagging方法,相当于集成学习,注意dropout训练时设置为0~1的小数,测试时设置为1,不需要关闭节点

学习率难以设定:Adagrad等自适应学习率方法

深层网络梯度弥散:Relu激活取代sigmoid激活,不过输出层仍然使用sigmoid激活

对于ReLU激活函数,常用截断正态分布,避免0梯度和完全对称

对于Softmax分类(也就是sigmoid激活),由于对0附近最敏感,所以采用全0初始权重

代码如下

# Author : Hellcat
# Time : 2017/12/7 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('../../../Mnist_data',one_hot=True)
sess = tf.InteractiveSession() in_units = 784
h1_units = 300 # 对于ReLU激活函数,常用截断正态分布,避免0梯度和完全对称
# 对于Softmax分类(也就是sigmoid激活),由于对0附近最敏感,所以采用全0初始权重
W1 = tf.Variable(tf.truncated_normal([in_units, h1_units],stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units], dtype=tf.float32))
W2 = tf.Variable(tf.zeros([h1_units, 10], dtype=tf.float32))
b2 = tf.Variable(tf.zeros([10], dtype=tf.float32)) x = tf.placeholder(tf.float32, [None, in_units])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) hidden1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
y = tf.nn.softmax(tf.add(tf.matmul(hidden1_drop, W2), b2)) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=1))
# train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y,axis=1), tf.argmax(y_,axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.global_variables_initializer().run()
for i in range(3000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x:batch_xs, y_:batch_ys, keep_prob:0.5})
if i % 100 == 0:
print('当前迭代次数{0},当前准确率{1:.3f}'.
format(i,accuracy.eval({x:batch_xs, y_:batch_ys, keep_prob:1.0})))
print(accuracy.eval({x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

输出如下,

当前迭代次数0,当前准确率0.350
当前迭代次数100,当前准确率0.950
当前迭代次数200,当前准确率0.960
当前迭代次数300,当前准确率0.940
当前迭代次数400,当前准确率0.940
当前迭代次数500,当前准确率0.980
当前迭代次数600,当前准确率0.990
当前迭代次数700,当前准确率0.990
当前迭代次数800,当前准确率1.000
当前迭代次数900,当前准确率0.970
当前迭代次数1000,当前准确率0.980
当前迭代次数1100,当前准确率0.960
当前迭代次数1200,当前准确率1.000
当前迭代次数1300,当前准确率0.970
当前迭代次数1400,当前准确率0.990
当前迭代次数1500,当前准确率1.000
当前迭代次数1600,当前准确率1.000
当前迭代次数1700,当前准确率1.000
当前迭代次数1800,当前准确率0.980
当前迭代次数1900,当前准确率0.980
当前迭代次数2000,当前准确率1.000
当前迭代次数2100,当前准确率1.000
当前迭代次数2200,当前准确率1.000
当前迭代次数2300,当前准确率0.990
当前迭代次数2400,当前准确率1.000
当前迭代次数2500,当前准确率1.000
当前迭代次数2600,当前准确率0.990
当前迭代次数2700,当前准确率0.980
当前迭代次数2800,当前准确率0.990
当前迭代次数2900,当前准确率0.980
0.9778

有意思的是,使用tf.train.AdagradOptimizer()优化器时偶尔会出错,使用梯度下降优化器之后再修改回来就没问题了,可能是我的解释器出问题了。

『TensorFlow』读书笔记_多层感知机的更多相关文章

  1. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  2. 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上

    完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...

  3. 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_下

    数据读取部分实现 文中采用了tensorflow的从文件直接读取数据的方式,逻辑流程如下, 实现如下, # Author : Hellcat # Time : 2017/12/9 import os ...

  4. 『TensorFlow』读书笔记_简单卷积神经网络

    如果你可视化CNN的各层级结构,你会发现里面的每一层神经元的激活态都对应了一种特定的信息,越是底层的,就越接近画面的纹理信息,如同物品的材质. 越是上层的,就越接近实际内容(能说出来是个什么东西的那些 ...

  5. 『TensorFlow』读书笔记_VGGNet

    VGGNet网络介绍 VGG系列结构图, 『cs231n』卷积神经网络工程实践技巧_下 1,全部使用3*3的卷积核和2*2的池化核,通过不断加深网络结构来提升性能. 所有卷积层都是同样大小的filte ...

  6. 『TensorFlow』读书笔记_ResNet_V2

    『PyTorch × TensorFlow』第十七弹_ResNet快速实现 要点 神经网络逐层加深有Degradiation问题,准确率先上升到饱和,再加深会下降,这不是过拟合,是测试集和训练集同时下 ...

  7. 『TensorFlow』读书笔记_AlexNet

    网络结构 创新点 Relu激活函数:效果好于sigmoid,且解决了梯度弥散问题 Dropout层:Alexnet验证了dropout层的效果 重叠的最大池化:此前以平均池化为主,最大池化避免了平均池 ...

  8. 『TensorFlow』读书笔记_Inception_V3_下

    极为庞大的网络结构,不过下一节的ResNet也不小 线性的组成,结构大体如下: 常规卷积部分->Inception模块组1->Inception模块组2->Inception模块组3 ...

  9. 『TensorFlow』读书笔记_TFRecord学习

    一.程序介绍 1.包导入 # Author : Hellcat # Time : 17-12-29 import os import numpy as np np.set_printoptions(t ...

随机推荐

  1. SQL SERVER查询的临时文件路径

    C:\Users\用户\Documents\SQL Server Management Studio\Backup Files C:\Users\用户\AppData\Local\Temp\

  2. 在Windows Server 2008 R2 Server中,连接其他服务器的数据库遇到“未启用当前数据库的 SQL Server Service Broker,因此查询通知不受支持。如果希望使用通知,请为此数据库启用 Service Broker ”

    项目代码和数据库部署在不同的Windows Server 2008 R2 Server中,错误日志显示如下: "未启用当前数据库的 SQL Server Service Broker,因此查 ...

  3. SpringBoot整合Swagger2

    相信各位在公司写API文档数量应该不少,当然如果你还处在自己一个人开发前后台的年代,当我没说,如今为了前后台更好的对接,还是为了以后交接方便,都有要求写API文档. 手写Api文档的几个痛点: 文档需 ...

  4. svn 目录

    svn介绍 SVN与Git的区别 SVN服务的模式和多种访问方式 多种访问原理图解与优缺点 SVN安装部署 svn 部署 配置 配置svn用户及密码 配置svn用户及权限 svn 启动命令讲解 svn ...

  5. cocos2d-x JS 富文本

    var str1 = "兑换成功后,系统会生成“";var str2 = "红包兑换码";var str3 = "”,请复制该兑换码,并粘贴在&quo ...

  6. Cisco Packet Tracer

    ---恢复内容开始--- 1.简单局域网组建 交换机:2960  s1 终端设备:generic  pc 配置 pc1    单击>>Descktop>>IP configur ...

  7. CentOS 7 源码编译MariaDB

    下载源码包 安装 SCL  devtoolset-7 SCL(Software Collections)可以让你在同一个操作系统上安装和使用多个版本的软件,而不会影响整个系统的安装包.SCL为社区的以 ...

  8. 无需激活直接同步登入discuz,php代码(直接可用)

    <?php /** * 抽奖 * @param int $total */ function getReward($total=1000) { $win1 = floor((0.12*$tota ...

  9. app埋点

    目前APP埋点的主流有两种方式: 第一类是预先设定好想要获取的目标数据,让程序员撰写代码把“采集器”埋到相应的页面上,用于追踪和记录的用户的行为,并把实时数据传送到后台数据库或者客户端. 第二类方法是 ...

  10. nginx与PHP配置

    一.安装依赖包 yum -y install  libxml2  libxml2-devel  openssl  openssl-devel  curl  curl-devel libjpeg  li ...