分为三个文件:mnist_inference.py:定义前向传播的过程以及神经网络中的参数,抽象成为一个独立的库函数;mnist_train.py:定义神经网络的训练过程,在此过程中,每个一段时间保存一次模型训练的中间结果;mnist_eval.py:定义测试过程。

  1. mnist_inference.py
  1. #coding=utf8
  2. import tensorflow as tf
  3.  
  4. #1. 定义神经网络结构相关的参数。
  5.  
  6. INPUT_NODE = 784
  7. OUTPUT_NODE = 10
  8. LAYER1_NODE = 500
  9.  
  10. #2. 通过tf.get_variable函数来获取变量。
  11. def get_weight_variable(shape, regularizer):
  12. weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
  13. if regularizer != None: tf.add_to_collection('losses', regularizer(weights))
  14. return weights
  15.  
  16. #3. 定义神经网络的前向传播过程。使用命名空间方式,不需要把所有的变量都作为变量传递到不同的函数中提高程序的可读性
  17.  
  18. def inference(input_tensor, regularizer):
  19. with tf.variable_scope('layer1'):
  20.  
  21. weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
  22. biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
  23. layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
  24.  
  25. with tf.variable_scope('layer2'):
  26. weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
  27. biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
  28. layer2 = tf.matmul(layer1, weights) + biases
  29.  
  30. return layer2
  31.  
  1. mnist_train.py

#coding=utf8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference

import os

  1. #1. 定义神经网络结构相关的参数。
  2. BATCH_SIZE = 100
  3. LEARNING_RATE_BASE = 0.8
  4. LEARNING_RATE_DECAY = 0.99
  5. REGULARIZATION_RATE = 0.0001
  6. TRAINING_STEPS = 30000
  7. MOVING_AVERAGE_DECAY = 0.99
  8. MODEL_SAVE_PATH="MNIST_model/"
  9. MODEL_NAME="mnist_model"
  10. #2. 定义训练过程。
  11. def train(mnist):
  12. # 定义输入输出placeholder。
  13. x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
  14. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
  15. regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
  16. y = mnist_inference.inference(x, regularizer)
  17. global_step = tf.Variable(0, trainable=False)
  18. # 定义损失函数、学习率、滑动平均操作以及训练过程。
  19. variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  20. variables_averages_op = variable_averages.apply(tf.trainable_variables())
  21. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
  22. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  23. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  24. learning_rate = tf.train.exponential_decay(
  25. LEARNING_RATE_BASE,
  26. global_step,
  27. mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
  28. staircase=True)
  29. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  30. with tf.control_dependencies([train_step, variables_averages_op]):
  31. train_op = tf.no_op(name='train')
  32. # 初始化TensorFlow持久化类。
  33. saver = tf.train.Saver()
  34. with tf.Session() as sess:
  35. tf.global_variables_initializer().run()
  36. for i in range(TRAINING_STEPS):
  37. xs, ys = mnist.train.next_batch(BATCH_SIZE)
  38. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
  39. if i % 1000 == 0:
  40. print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
  41. saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
  42. def main(argv=None):
  43. mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
  44. train(mnist)
  45. if __name__ == '__main__':
  46. main()
  47. 结果如下:

mnist_eval.py:

  1.  
  2. import time
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import mnist_inference
  6. #coding=utf8
  7. import mnist_train
  8.  
  9. #1. 每10秒加载一次最新的模型
  10.  
  11. # 加载的时间间隔。
  12. EVAL_INTERVAL_SECS = 10
  13.  
  14. def evaluate(mnist):
  15. with tf.Graph().as_default() as g:
  16. x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
  17. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
  18. validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
  19.  
  20. y = mnist_inference.inference(x, None)
  21. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  22. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  23.  
  24. variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
  25. variables_to_restore = variable_averages.variables_to_restore()
  26. saver = tf.train.Saver(variables_to_restore)
  27.  
  28. while True:
  29. with tf.Session() as sess:
  30. ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
  31. if ckpt and ckpt.model_checkpoint_path:
  32. saver.restore(sess, ckpt.model_checkpoint_path)
  33. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  34. accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
  35. print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
  36. else:
  37. print('No checkpoint file found')
  38. return
  39. time.sleep(EVAL_INTERVAL_SECS)
  40.  
  41. def main(argv=None):
  42. mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
  43. evaluate(mnist)
  44.  
  45. if __name__ == '__main__':
  46. main()

结果如下:

Tensorflow 解决MNIST问题的重构程序的更多相关文章

  1. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  3. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  4. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  5. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  6. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  7. win10下通过Anaconda安装TensorFlow-GPU1.3版本,并配置pycharm运行Mnist手写识别程序

    折腾了一天半终于装好了win10下的TensorFlow-GPU版,在这里做个记录. 准备安装包: visual studio 2015: Anaconda3-4.2.0-Windows-x86_64 ...

  8. Tensorflow之MNIST的最佳实践思路总结

    Tensorflow之MNIST的最佳实践思路总结   在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...

  9. tensorflow处理mnist(二)

    用卷积神经网络解决mnist的分类问题. 简单的例子 一行一行解释这个代码. 这个不是google官方的例子,但是很简洁,便于入门.tensorflow是先定义模型,最后赋值,计算.为了讨论问题方便, ...

随机推荐

  1. Qt5.3.2_CentOS6.4_单步调试环境__20160306【勿删,繁琐】

    20160306 全程没有f/q ZC:使用的虚拟机环境是:博客园VMwareSkill 的 “CentOS6.4_x86_120g__20160306.rar” 需要调试器 gdb ,从“http: ...

  2. 雷林鹏分享:Ruby 类案例

    Ruby 类案例 下面将创建一个名为 Customer 的 Ruby 类,您将声明两个方法: display_details:该方法用于显示客户的详细信息. total_no_of_customers ...

  3. [.NET开发] C#实现发送手机验证码功能

    之前不怎么了解这个,一直以为做起来很复杂. 直到前两天公司要求要做这个功能. 做了之后才发现 这不过就是一个POST请求就能实现的东西.现在给大家分享一下,有不足之处还请多多指教. 废话不多说 直接上 ...

  4. android--------ExpandableListView的使用多级列表

    多级列表ExpandableListView 扩展列表能够显示一个指示在每项显示项的当前状态(状态通常是一个扩展的组,组的孩子,或倒塌,最后一个孩子).使用setchildindicator(draw ...

  5. 『PyTorch』第七弹_nn.Module扩展层

    有下面代码可以看出torch层函数(nn.Module)用法,使用超参数实例化层函数类(常位于网络class的__init__中),而网络class实际上就是一个高级的递归的nn.Module的cla ...

  6. 『科学计算』图像检测微型demo

    这里是课上老师给出的一个示例程序,演示图像检测的过程,本来以为是传统的滑窗检测,但实际上引入了selectivesearch来选择候选窗,所以看思路应该是RCNN的范畴,蛮有意思的,由于老师的注释写的 ...

  7. oracle图形界面配置tns

    oracle图形界面配置tns       启动orcl服务        

  8. sql server2008 如何获取上月、上周、昨天、今天、本周、本月的查询周期(通过存储过程)

    我这边有一个需求要统计订单数据,需要统计订单的上传日期,统计的模块大概是 那么上月.上周.昨天.今天.本周.本月应该是怎样呢? 1.数据分析 因为今天是动态数据,我要查月份(上月.本月),应该是一个日 ...

  9. 快速切题 sgu113 Nearly prime numbers 难度:0

    113. Nearly prime numbers time limit per test: 0.25 sec. memory limit per test: 4096 KB Nearly prime ...

  10. 快速切题 usaco ariprog

    题目:给定3<=n<=25,m<250,求m及以内的两两平方和能否构成为n的等差数列 1 WA 没有注意到应该按照公差-首项的顺序排序 2 MLE 尝试使用桶,但是实际上那可能是分散 ...