————————————————————————————————————

写在开头:此文参照莫烦python教程(墙裂推荐!!!)

————————————————————————————————————

dropout解决overfitting问题

  • overfitting:当机器学习学习得太好了,就会出现过拟合(overfitting)问题。所以,我们就要采取一些措施来避免过拟合的问题。此实验就来看一下dropout对于解决过拟合问题的效果。
  • 例子实验内容:识别手写数字。此实验的步骤和上一篇的识别手写数字步骤很相似。
  • 例子实验的数据集:sklearn中的datasets

  • 主要运用的函数tf.nn.dropout()

  • 主要参数keep_prob。keep_prob表示留下来的结果的百分比,比如你要drop0.4,那么keep_prob就为0.6
import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #加载数据
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y) #把数字变成1x10的向量
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = .3) #把数据分成train数据和test数据 #定义添加层
def add_layer(inputs,in_size,out_size,activation_function=None):
#定义添加层内容,返回这层的outputs
Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
Wx_plus_b = tf.matmul(inputs,Weights)+biases #预测
#实现dropout,keep_drop为丢弃后剩下的百分比
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
if activation_function is None: #如果没有激励函数,那么outputs就是预测值
outputs = Wx_plus_b
else: #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
outputs = activation_function(Wx_plus_b)
return outputs #定义计算正确率的函数
def t_accuracy(t_xs,t_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:t_xs,keep_prob:1})#测试结果不dropout
correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys,keep_prob:1})
return result #定义输入输出值,和keep_drop值
keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #添加层
l1 = add_layer(xs, 64, 50,activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10,activation_function=tf.nn.softmax) #误差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1])) # loss #训练
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #开始训练
sess = tf.Session()
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()
sess.run(init)
for i in range(1000):
# 设置keep_drop为1,即不进行dropout
sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})
if i % 50 == 0:
# 输出正确率
print (t_accuracy(X_test,y_test))
0.20925926
0.7574074
0.81296295
0.8388889
0.85555553
0.8537037
0.84814817
0.8537037
0.85555553
0.8537037
0.85555553
0.8537037
0.8574074
0.85555553
0.8574074
0.8574074
0.8611111
0.8574074
0.85925925
0.8611111
for i in range(1000):
# 设置keep_drop为0.5
sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})
if i % 50 == 0:
# 输出正确率
print (t_accuracy(X_test,y_test))
0.86851853
0.89444447
0.91481483
0.9166667
0.91481483
0.9222222
0.9259259
0.9222222
0.9296296
0.94074076
0.94074076
0.9351852
0.9351852
0.9351852
0.9351852
0.93333334
0.94074076
0.9351852
0.93703705
0.9351852

由上面的结果可知,当dropout为0.5时,效果明显比一点儿也不丢弃的好!


*点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*


4 TensorFlow入门之dropout解决overfitting问题的更多相关文章

  1. tensorflow学习之(八)使用dropout解决overfitting(过拟合)问题

    #使用dropout解决overfitting(过拟合)问题 #如果有dropout,在feed_dict的参数中一定要加入dropout的值 import tensorflow as tf from ...

  2. TensorFlow实战第七课(dropout解决overfitting)

    Dropout 解决 overfitting overfitting也被称为过度学习,过度拟合.他是机器学习中常见的问题. 图中的黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲 ...

  3. tensorflow用dropout解决over fitting-【老鱼学tensorflow】

    在机器学习中可能会存在过拟合的问题,表现为在训练集上表现很好,但在测试集中表现不如训练集中的那么好. 图中黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲线很精确的区分了所有的训 ...

  4. tensorflow用dropout解决over fitting

    在机器学习中可能会存在过拟合的问题,表现为在训练集上表现很好,但在测试集中表现不如训练集中的那么好. 图中黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲线很精确的区分了所有的训 ...

  5. #tensorflow入门(1)

    tensorflow入门(1) 关于 TensorFlow TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操 ...

  6. TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...

  7. 转:TensorFlow入门(六) 双端 LSTM 实现序列标注(分词)

    http://blog.csdn.net/Jerr__y/article/details/70471066 欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @cr ...

  8. TensorFlow 入门之手写识别CNN 三

    TensorFlow 入门之手写识别CNN 三 MNIST 卷积神经网络 Fly 多层卷积网络 多层卷积网络的基本理论 构建一个多层卷积网络 权值初始化 卷积和池化 第一层卷积 第二层卷积 密集层连接 ...

  9. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

随机推荐

  1. I帧、B帧、P帧、NALU类型

    i帧 i frame,即内部画面 intra picture,通常是GOP的第一个帧(即IDR)I帧是最大去除图像空间冗余信息而压缩得到的帧,自带全部信息,不参考其他帧可独立解码,称为帧内编码帧所有视 ...

  2. 【实验一 】Spring Boot 集成 hibernate & JPA

    转眼间,2018年的十二分之一都快过完了,忙于各类事情,博客也都快一个月没更新了.今天我们继续来学习Springboot对象持久化. 首先JPA是Java持久化API,定义了一系列对象持久化的标准,而 ...

  3. php 判断时间是否超过

    $str="2014-10-11"; echo "".strtotime($str); echo "<br/>"; echo & ...

  4. 数据挖据之GeoHash核心原理解析

    引子 机机是个好动又好学的孩子,平日里就喜欢拿着手机地图点点按按来查询一些好玩的东西.某一天机机到北海公园游玩,肚肚饿了,于是乎打开手机地图,搜索北海公园附近的餐馆,并选了其中一家用餐. 饭饱之后机机 ...

  5. petrozavodsk summer 2018 游记&&总结

    day0: 出发前训了一场比较水bapc2017保持手感(恢复信心),成功AK了,不过罚时略高.然后三人打车从紫金港到杭州东站,坐高铁到上海虹桥,再坐机场快线到浦东机场(傻乎乎的jsb帮爸爸付了钱,然 ...

  6. 【BZOJ】1687: [Usaco2005 Open]Navigating the City 城市交通(bfs)

    http://www.lydsy.com/JudgeOnline/problem.php?id=1687 bfs后然后逆向找图即可.因为题目保证最短路唯一 #include <cstdio> ...

  7. iOS开发-你真的会用SDWebImage?

    SDWebImage作为眼下最受欢迎的图片下载第三方框架,使用率非常高.可是你真的会用吗?本文接下来将通过例子分析怎样合理使用SDWebImage. 使用场景:自己定义的UITableViewCell ...

  8. WinCC7.3 Win764位系统安装教程

    WinCC7.3 Win764位安装教程 (1)将ISO文件解压缩. (2)编辑Setup.ini文件 watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/fo ...

  9. 卡友pos机使用流程

    Q: pos机正常使用步骤 A: 1. 按开机键开机2. 输入“01”进行签到3. 系统提示输入密码,密码为“0000”4. 系统提示“请刷卡”,可正常刷卡消费首次使用请务必登陆商户后台核对结算收款账 ...

  10. 快速开发微信小程序

    image.png 最近婷主在做微信小程序.自己的微信公众号也需要添加点料,乘着这次放假,把微信小程序研究了下.虽然没有做什么很强大的功能,不过好歹自己的公众号也有了微信小程序.够用即可. 1.需要先 ...