04Dropout
不加Dropout,训练数据的准确率高,基本上可以接近100%,但是,对于测试集来说,效果并不好;
加上Dropout,训练数据的准确率可能变低,但是,对于测试集来说,效果更好了,所以说Dropout可以防止过拟合。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) # 创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob) W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob) W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob) W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) #argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(31):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

04Dropout的更多相关文章
随机推荐
- bzoj 4298 [ONTAK2015]Bajtocja——哈希+启发式合并
题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4298 题面: 给定d张无向图,每张图都有n个点.一开始,在任何一张图中都没有任何边.接下来有 ...
- SQL JOIN INNER LEFT RIGHT FULL
1.引用2个表(效果同INNER JOIN) SELECT Persons.LastName, Persons.FirstName, Orders.OrderNo FROM Persons, Ord ...
- 原生javascript代码懒加载
1.先定义需要懒加载的样式: class="lazyload" 2.设置初始透明度为0.1: .lazyload{ filter: Alpha(opacity=10); -moz- ...
- CSS札记(一):CSS选择器
一.语法规则 选择器{ 属性1:属性值1; 属性2:属性值2; ...... } /*注释*/ 二.如何在html中应用CSS 1. 外部引用css文件 css文件:css/layout.css(cs ...
- 《SQL Server 2012 T-SQL基础》读书笔记 - 1.背景
几个缩写的全称:Data Definition Language (DDL), Data Manipulation Language (DML), and Data Control Language ...
- 修改docker默认网段
一. 修改普通docker run启动的容器的网段 https://blog.51cto.com/13670314/2345518?source=dra https://blog.csdn.net/w ...
- 1>/dev/null 2>&1的含义
shell中可能经常能看到:>/dev/null 2>&1 分解这个组合:“>/dev/null 2>&1” 为五部分. 1:> 代表重定向到哪 ...
- JavaScript-Templates
https://github.com/blueimp/JavaScript-Templates https://blueimp.github.io/JavaScript-Templates/ http ...
- c# access oledb helper class
连接Access数据库 using System; using System.Collections.Generic; using System.Linq; using System.Text; us ...
- 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_06 Properties集合_2_Properties集合中的方法store
第一行是注释,第二行是时间,时间是自动加的 使用FileOutputStream. 写入中文会乱码