import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 # 输入节点
OUTPUT_NODE = 10 # 输出节点
LAYER1_NODE = 500 # 隐藏层数 BATCH_SIZE = 100 # 每次batch打包的样本个数 # 模型相关的参数
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARAZTION_RATE = 0.0001
TRAINING_STEPS = 5000 def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
# 不使用滑动平均类
if avg_class == None:
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
return tf.matmul(layer1, weights2) + biases2
else:
# 使用滑动平均类
layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2) def train(mnist):
x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
# 生成隐藏层的参数。
weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
# 生成输出层的参数。
weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) # 计算不含滑动平均类的前向传播结果
y = inference(x, None, weights1, biases1, weights2, biases2) # 定义训练轮数及相关的滑动平均类
global_step = tf.Variable(0, trainable=False) # 计算交叉熵及其平均值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy) # 损失函数的计算
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
regularaztion = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularaztion # 设置指数衰减的学习率。
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True) # 优化损失函数
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step) # 反向传播更新参数
with tf.control_dependencies([train_step]):
train_op = tf.no_op(name='train') # 计算正确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 初始化会话,并开始训练过程。
with tf.Session() as sess:
tf.global_variables_initializer().run()
validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
test_feed = {x: mnist.test.images, y_: mnist.test.labels} # 循环的训练神经网络。
for i in range(TRAINING_STEPS):
if i % 1000 == 0:
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
print("After %d training step(s), validation accuracy using average model is %g " % (i, validate_acc))
xs,ys=mnist.train.next_batch(BATCH_SIZE)
sess.run(train_op,feed_dict={x:xs,y_:ys})
test_acc=sess.run(accuracy,feed_dict=test_feed)
print(("After %d training step(s), test accuracy using average model is %g" %(TRAINING_STEPS, test_acc))) def main(argv=None):
mnist = input_data.read_data_sets("E:\\MNIST_data\\", one_hot=True)
train(mnist) if __name__=='__main__':
main()

吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用隐藏层

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  2. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用激活函数

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  3. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用指数衰减的学习率

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  4. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用正则化

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  5. 吴裕雄 python 神经网络——TensorFlow训练神经网络:全模型

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  6. 吴裕雄 python 神经网络——TensorFlow训练神经网络:MNIST最佳实践

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...

  7. 吴裕雄 python 神经网络——TensorFlow训练神经网络:花瓣识别

    import os import glob import os.path import numpy as np import tensorflow as tf from tensorflow.pyth ...

  8. 吴裕雄 python 神经网络——TensorFlow训练神经网络:卷积层、池化层样例

    import numpy as np import tensorflow as tf M = np.array([ [[1],[-1],[0]], [[-1],[2],[1]], [[0],[2],[ ...

  9. 吴裕雄--天生自然 Tensorflow卷积神经网络:花朵图片识别

    import os import numpy as np import matplotlib.pyplot as plt from PIL import Image, ImageChops from ...

随机推荐

  1. 用户 'sa' 登录失败。该用户与可信 SQL Server 连接无关联'。错误代码:18452 解决办法

    原文:https://blog.csdn.net/wuxianwei/article/details/6330270 SQLSERVER 2005采用'SQLSERVER身份验证'去登录, 出错的原因 ...

  2. How To Use These LED Garden Lights

    Are you considering the lighting options for the outdoor garden? Depending on how you use it, LED ga ...

  3. 牛客1080D tokitsukaze and Event (双向最短路)

    题目链接:https://ac.nowcoder.com/acm/contest/1080/D 首先建两个图,一个是权值为a的图,一个是权值为b的图. 从s起点以spfa算法跑权值为ai的最短路到t点 ...

  4. 可以使用的一些API(转存)

    聚合数据 juhe.com 转存的格式不如原文的好看,可以直接访问原文 https://www.jianshu.com/p/9a0acf69b789 api接口应该会越来越火,上个全的,楼主自己找找吧 ...

  5. Flask 教程 第十七章:Linux上的部署

    本文翻译自The Flask Mega-Tutorial Part XVII: Deployment on Linux 这是Flask Mega-Tutorial系列的第十七部分,我将把Microbl ...

  6. C++的四种转换(const_cast、static_cast、dynamic_cast、reinterpreter_cast)

    static_cast 相当于C语言中的强制转换:(类型)表达式或类型(表达式),用于各种隐式转换 非const转const.void*转指针.int和char相互转换 用于基类和子类之间的指针和引用 ...

  7. Spring bean继承

    Bean 定义继承 bean 定义可以包含很多的配置信息,包括构造函数的参数,属性值,容器的具体信息例如初始化方法,静态工厂方法名,等等. 子 bean 的定义继承父定义的配置数据.子定义可以根据需要 ...

  8. jeecg /ant-design-vuepro 前端使用

    1.原生axios使用 <script> import Vue from 'vue'; import axios from 'axios'; axios.defaults.baseURL ...

  9. tomcat6w.exe启动tomcat

    在使用tomcat中,我们可能经常点击startup.bat来启动tomcat,但也不少通过tomcat6w.exe来启动的. 但是当我们点击tomcat6w.exe的时候会报错,信息如下:提示 指定 ...

  10. Selenium元素定位之页面检测技巧

    我们在进行web自动化测试的时候进行XPath或者CSS定位,需要检测页面元素定位是否正确,如果用脚本去检测,那么效率是极低的. 一般网上推选装额外的插件来实现页面元素定位检测 如:firebug. ...