不加Dropout,训练数据的准确率高,基本上可以接近100%,但是,对于测试集来说,效果并不好;

加上Dropout,训练数据的准确率可能变低,但是,对于测试集来说,效果更好了,所以说Dropout可以防止过拟合。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) # 创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob) W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob) W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob) W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) #argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(31):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

04Dropout的更多相关文章

随机推荐

  1. css内容过长显示省略号的几种解决方法

    单行文本(方法一): 语法: text-overflow : clip | ellipsis 参数: clip : 不显示省略标记(...),而是简单的裁切 (clip这个参数是不常用的!) elli ...

  2. ionic使用自定义icon

    参考文档:https://www.jianshu.com/p/5346fee9fd80  angular+ionic 自定义图标 注意: 这里不用name 用class类名显示出来 最后出来图标是个小 ...

  3. SpringBoot:配置文件及自动配置原理

    西部开源-秦疆老师:基于SpringBoot 2.1.6 的博客教程 秦老师交流Q群号: 664386224 未授权禁止转载!编辑不易 , 转发请注明出处!防君子不防小人,共勉! SpringBoot ...

  4. CSS基础-background的那些属性

    background的那些属性 background:背景的意思常用的六个属性 1.background-color:背景颜色 2.background-image:背景图像 3.background ...

  5. 20175214 《Java程序设计》第11周学习总结

    20175214 <Java程序设计>第11周学习总结 本周学习任务总结 1.根据<java2实用教程>和蓝墨云学习视频学习第十三章: 2.尝试将课本重点内容用自己的话复述手打 ...

  6. scipy几乎实现numpy的所有函数

    NumPy和SciPy的关系?   numpy提供了数组对象,面向的任何使用者.scipy在numpy的基础上,面向科学家和工程师,提供了更为精准和广泛的函数.scipy几乎实现numpy的所有函数, ...

  7. netflow-module

    https://www.elastic.co/guide/en/logstash/current/netflow-module.html

  8. 阶段1 语言基础+高级_1-3-Java语言高级_08-JDK8新特性_第4节 方法引用_7方法引用_数组的构造器引用

    先创建函数式接口 创建测试类 打印长度是10...... 方法引用优化

  9. jmeter接口测试初体验

    今天初体验了一把jmeter,把操作的一些经历贴出来,督促自己进步.等逐步掌握后再次回首时,希望是有所思的,欣慰的! jmeter: Apache JMeter是Apache组织开发的基于Java的压 ...

  10. WPF与DevExpress之——实现类似于安装程序下一步下一步的样式窗体

    话不多说先上图  点击下一步  跳转到第二页  项目准备: 1.DevExpress 19/18/17(三个版本都可以) 2.Vs2019 3..Net framework>4.0 项目结构: ...