8. Dropout
```python
# Requires TensorFlow 1.x (tf.placeholder, tf.Session and
# tensorflow.examples.tutorials were removed in TensorFlow 2).
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each batch
batch_size = 64
# Total number of batches per epoch
n_batch = mnist.train.num_examples // batch_size

# Define three placeholders
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)  # probability of keeping each unit in dropout

# Network structure: 784-1000-500-10
W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
b1 = tf.Variable(tf.zeros([1000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
b2 = tf.Variable(tf.zeros([500]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy: softmax_cross_entropy applies softmax internally,
# so it must receive the raw logits, not the softmax output
loss = tf.losses.softmax_cross_entropy(y, logits)
# Use gradient descent
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
# Initialize variables
init = tf.global_variables_initializer()

# Results are stored in a list of booleans;
# argmax returns the index of the largest value along axis 1
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Compute accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Training: keep each hidden unit with probability 0.5
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.5})
        # Evaluation: keep_prob=1.0 disables dropout
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))
```
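In TensorFlow 1.x, `tf.nn.dropout(x, keep_prob)` keeps each element with probability `keep_prob`, scales the kept elements by `1 / keep_prob`, and sets the rest to 0, so the expected value of every activation is unchanged. That is why the training loop can feed `keep_prob: 0.5` while training and `keep_prob: 1.0` while evaluating without rescaling any weights. The pure-Python sketch below mimics that behavior on a flat list of activations; `inverted_dropout` is a hypothetical helper written for illustration, not a TensorFlow function.

```python
import random

def inverted_dropout(activations, keep_prob, rng=random):
    """Hypothetical helper mimicking tf.nn.dropout(x, keep_prob) from TF 1.x.

    Each element survives with probability keep_prob and is scaled by
    1 / keep_prob; the others become 0.0, so the expected value is unchanged.
    """
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    return [a / keep_prob if rng.random() < keep_prob else 0.0
            for a in activations]

rng = random.Random(0)                           # fixed seed so the run is repeatable
hidden = [0.5] * 10000                           # hypothetical activations, all 0.5
train_out = inverted_dropout(hidden, 0.5, rng)   # like keep_prob: 0.5 in training
eval_out = inverted_dropout(hidden, 1.0, rng)    # like keep_prob: 1.0 in evaluation

print(sum(v == 0.0 for v in train_out) / len(train_out))  # fraction dropped, close to 0.5
print(sum(train_out) / len(train_out))                    # mean stays close to 0.5
print(eval_out == hidden)                                 # True: nothing dropped or rescaled
```

Scaling the survivors during training (instead of multiplying the weights by the keep probability at test time, as in the original dropout papers) is what lets the evaluation pass use the network unchanged.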
```text
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
```
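`tf.losses.softmax_cross_entropy(onehot_labels, logits)` applies softmax to its second argument itself, so it expects the raw, pre-softmax scores. Feeding it the output of `tf.nn.softmax` applies softmax twice: with 10 classes the loss is then confined to roughly 1.46 (perfect prediction) to 2.46 (completely wrong), which weakens the training signal. The pure-Python sketch below illustrates the difference; `log_softmax`, `softmax` and `cross_entropy_from_logits` are hypothetical helpers that compute the per-example cross-entropy, not TensorFlow's implementation, and the logit values are made up for the example.

```python
import math

def log_softmax(z):
    """Numerically stable log(softmax(z)) using the log-sum-exp trick."""
    m = max(z)
    log_total = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_total for v in z]

def softmax(z):
    """softmax(z), computed via log_softmax for stability."""
    return [math.exp(lp) for lp in log_softmax(z)]

def cross_entropy_from_logits(onehot, logits):
    """Per-example cross-entropy: -sum(onehot * log(softmax(logits)))."""
    return -sum(t * lp for t, lp in zip(onehot, log_softmax(logits)))

# True class is 0 out of 10 classes (hypothetical example values).
onehot = [1.0] + [0.0] * 9
good = [10.0] + [0.0] * 9        # confident and correct
bad = [0.0, 10.0] + [0.0] * 8    # confident and wrong

for name, z in [("correct", good), ("wrong", bad)]:
    from_logits = cross_entropy_from_logits(onehot, z)
    double_softmax = cross_entropy_from_logits(onehot, softmax(z))
    print(f"{name}: from logits {from_logits:.4f}, from softmax output {double_softmax:.4f}")
```

Computed from the logits, the loss ranges from about 0.0004 for the correct prediction to about 10.0 for the wrong one; computed from the softmax output, both cases land inside the narrow band of roughly 1.46 to 2.46.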