04Dropout
不加Dropout,训练数据的准确率高,基本上可以接近100%,但是,对于测试集来说,效果并不好;
加上Dropout,训练数据的准确率可能变低,但是,对于测试集来说,效果更好了,所以说Dropout可以防止过拟合。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) # 创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob) W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob) W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob) W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) #argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(31):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7}) test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

04Dropout的更多相关文章
随机推荐
- 【转】结构化日志类库 ---- Serilog库
源地址:https://www.cnblogs.com/mq0036/p/8479956.html 解决异常: Invalid cast from 'System.String' to 'Serilo ...
- 解决axios异步问题
- Mac sublime安装package controller
https://packagecontrol.io/installation#st2 链接被墙了这个. 我拿来放在这里. The simplest method of installation is ...
- 1px渲染成2px的场景及解决方案
1.场景一: IE6 下默认div最小高度为2px,如何创建高为1px的容器? .minContainer{font-size:0px;overflow:hidden} 2.场景二: 移动端高分辨 ...
- java将url里面的中文改成ASCII字符集 和 SCII字符集 改成 中文
package com.example.demo; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; / ...
- 012-elasticsearch5.4.3【五】-搜索API【一】搜索匹配所有matchAllQuery、全文查询[matchQuery、multiMatchQuery、commonTermsQuery、queryStringQuery、simpleQueryStringQuery]
一.概述 查询所使用的 QueryBuilders来源于以下 import static org.elasticsearch.index.query.QueryBuilders.*; 请注意,您可以使 ...
- delphi 函数isiconic 函数 判断窗口是否最小化
http://blog.sina.com.cn/s/blog_66357ab901012t2h.html delphi 函数isiconic 函数 判断窗口是否最小化 (2012-05-26 22:0 ...
- 字符串中的TOUPPER函数
std::string& str_toupper(std::string& s) { std::transform(s.begin(), s.end(), s.begin(), []( ...
- G2 基本使用 折线图 柱状图 饼图 基本配置
G2的基本使用 1.浏览器引入 <!-- 引入在线资源 --> <script src="https://gw.alipayobjects.com/os/lib/antv ...
- Python3之异常处理
写自动化脚本时经常会用到异常处理,下面将python中的异常处理做一整理: 注意:以下所有事列中的111.txt文件不存在,所以会引起异常 用法一:try...except...else..类型 1. ...