多层感知机

输入->线性变换->Relu激活->线性变换->Softmax分类

多层感知机将mnist的结果提升到了98%左右的水平

知识点

过拟合:采用dropout解决,本质是bagging方法,相当于集成学习,注意dropout训练时设置为0~1的小数,测试时设置为1,不需要关闭节点

学习率难以设定:Adagrad等自适应学习率方法

深层网络梯度弥散:Relu激活取代sigmoid激活,不过输出层仍然使用sigmoid激活

对于ReLU激活函数,常用截断正态分布,避免0梯度和完全对称

对于Softmax分类(也就是sigmoid激活),由于对0附近最敏感,所以采用全0初始权重

代码如下

# Author : Hellcat
# Time : 2017/12/7 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('../../../Mnist_data',one_hot=True)
sess = tf.InteractiveSession() in_units = 784
h1_units = 300 # 对于ReLU激活函数,常用截断正态分布,避免0梯度和完全对称
# 对于Softmax分类(也就是sigmoid激活),由于对0附近最敏感,所以采用全0初始权重
W1 = tf.Variable(tf.truncated_normal([in_units, h1_units],stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units], dtype=tf.float32))
W2 = tf.Variable(tf.zeros([h1_units, 10], dtype=tf.float32))
b2 = tf.Variable(tf.zeros([10], dtype=tf.float32)) x = tf.placeholder(tf.float32, [None, in_units])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) hidden1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
y = tf.nn.softmax(tf.add(tf.matmul(hidden1_drop, W2), b2)) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=1))
# train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y,axis=1), tf.argmax(y_,axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.global_variables_initializer().run()
for i in range(3000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x:batch_xs, y_:batch_ys, keep_prob:0.5})
if i % 100 == 0:
print('当前迭代次数{0},当前准确率{1:.3f}'.
format(i,accuracy.eval({x:batch_xs, y_:batch_ys, keep_prob:1.0})))
print(accuracy.eval({x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

输出如下,

当前迭代次数0,当前准确率0.350
当前迭代次数100,当前准确率0.950
当前迭代次数200,当前准确率0.960
当前迭代次数300,当前准确率0.940
当前迭代次数400,当前准确率0.940
当前迭代次数500,当前准确率0.980
当前迭代次数600,当前准确率0.990
当前迭代次数700,当前准确率0.990
当前迭代次数800,当前准确率1.000
当前迭代次数900,当前准确率0.970
当前迭代次数1000,当前准确率0.980
当前迭代次数1100,当前准确率0.960
当前迭代次数1200,当前准确率1.000
当前迭代次数1300,当前准确率0.970
当前迭代次数1400,当前准确率0.990
当前迭代次数1500,当前准确率1.000
当前迭代次数1600,当前准确率1.000
当前迭代次数1700,当前准确率1.000
当前迭代次数1800,当前准确率0.980
当前迭代次数1900,当前准确率0.980
当前迭代次数2000,当前准确率1.000
当前迭代次数2100,当前准确率1.000
当前迭代次数2200,当前准确率1.000
当前迭代次数2300,当前准确率0.990
当前迭代次数2400,当前准确率1.000
当前迭代次数2500,当前准确率1.000
当前迭代次数2600,当前准确率0.990
当前迭代次数2700,当前准确率0.980
当前迭代次数2800,当前准确率0.990
当前迭代次数2900,当前准确率0.980
0.9778

有意思的是,使用tf.train.AdagradOptimizer()优化器时偶尔会出错,使用梯度下降优化器之后再修改回来就没问题了,可能是我的解释器出问题了。

『TensorFlow』读书笔记_多层感知机的更多相关文章

  1. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  2. 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上

    完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...

  3. 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_下

    数据读取部分实现 文中采用了tensorflow的从文件直接读取数据的方式,逻辑流程如下, 实现如下, # Author : Hellcat # Time : 2017/12/9 import os ...

  4. 『TensorFlow』读书笔记_简单卷积神经网络

    如果你可视化CNN的各层级结构,你会发现里面的每一层神经元的激活态都对应了一种特定的信息,越是底层的,就越接近画面的纹理信息,如同物品的材质. 越是上层的,就越接近实际内容(能说出来是个什么东西的那些 ...

  5. 『TensorFlow』读书笔记_VGGNet

    VGGNet网络介绍 VGG系列结构图, 『cs231n』卷积神经网络工程实践技巧_下 1,全部使用3*3的卷积核和2*2的池化核,通过不断加深网络结构来提升性能. 所有卷积层都是同样大小的filte ...

  6. 『TensorFlow』读书笔记_ResNet_V2

    『PyTorch × TensorFlow』第十七弹_ResNet快速实现 要点 神经网络逐层加深有Degradiation问题,准确率先上升到饱和,再加深会下降,这不是过拟合,是测试集和训练集同时下 ...

  7. 『TensorFlow』读书笔记_AlexNet

    网络结构 创新点 Relu激活函数:效果好于sigmoid,且解决了梯度弥散问题 Dropout层:Alexnet验证了dropout层的效果 重叠的最大池化:此前以平均池化为主,最大池化避免了平均池 ...

  8. 『TensorFlow』读书笔记_Inception_V3_下

    极为庞大的网络结构,不过下一节的ResNet也不小 线性的组成,结构大体如下: 常规卷积部分->Inception模块组1->Inception模块组2->Inception模块组3 ...

  9. 『TensorFlow』读书笔记_TFRecord学习

    一.程序介绍 1.包导入 # Author : Hellcat # Time : 17-12-29 import os import numpy as np np.set_printoptions(t ...

随机推荐

  1. Git命令行基本操作

    Git--- download网址:https://git-scm.com/downloads 0. 安装Git 网上有很多Git安装教程,如果需要图形界面,windows下建议使用TortoiseG ...

  2. java框架之Hibernate(3)-一对多和多对多关系操作

    一对多 例:一个班级可以有多个学生,而一个学生只能属于一个班级. 模型 package com.zze.bean; import java.util.HashSet; import java.util ...

  3. Cocos Creator 加载和切换场景(官方文档摘录)

    Cocos Creator 加载和切换场景(官方文档摘录) 在 Cocos Creator 中,我们使用场景文件名( 可以不包含扩展名)来索引指代场景.并通过以下接口进行加载和切换操作: cc.dir ...

  4. idea快捷键使用

    idea                                 eclipse project                           workspace module     ...

  5. jQuery-form实现文件分步上传

    分步上传:当你需要提交两个及以上的文件,在一个文件成功后再提交另一个文件,并且最后需要提交所有文件的地址组成的数据 HTML: <form id="uploadVideoForm&qu ...

  6. python_函数式编程

    函数式编程是一种编程范式 (而面向过程编程和面向对象编程也都是编程范式).在面向过程编程中,我们见到过函数(function):在面向对象编程中,我们见过对象(object).函数和对象的根本目的是以 ...

  7. linq 表分组后关联查询

    测试linq,获取有教师名额的学校.比如学校有5个教师名额,teacher数量没超过5个,发现有空额 var query = (from teacher in _repositoryTeacher.T ...

  8. count列表中字符出现的次数

    如何count列表中字符出现的次数?可以将其生成一个字典.key是列表中的字符串,value是出现的次数 例如gen = [2, 3, 4, 5, 6, 7, 3, 4, 5, 6, 7, 8, 4, ...

  9. lua语言中的假

    [1]测试及结论 (1)代码 local var_false = false local var_nil = nil if var_zero then print('var_zero : true') ...

  10. java String 类型总结

    java中String是个对象,是引用类型?,基础类型与引用类型的区别是,基础类型只表示简单的字符或数字,引用类型可以是任何复杂的数据结构,基本类型仅表示简单的数据类型,引用类型可以表示复杂的数据类型 ...