TensorFlow最佳实践样例
以下代码摘自《Tensor Flow:实战Google深度学习框架》
本套代码是在 http://www.cnblogs.com/shanlizi/p/9033330.html 基础上进行持久化,分为3部分,分别为infenrence,train,eval.
是将原代码模块化,并且持久化之后可以直接调用训练后的模型。
需要注意的一点是:本人电脑的路径,mnist_inference.py是在Mnist_New文件夹下的,所以代码中加载模块用的是:import Mnist_New.mnist_inference as mnist_inference
import tensorflow as tf # 定义神经网络结构相关的参数
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500 # 通过tf.get_variable函数来获取变量。在训练神经网络时会创建这些变量;在测试时会通
# 过保存的模型加载这些变量的取值。而且更加方便的是,因为可以在变量加载时将滑动平均变
# 量重命名,所以可以直接通过相同的名字在训练时使用变量自身,而在测试时使用变量的滑动
# 平均值。在这个函数中也会将变量的正则化损失加入到损失集合。
def get_weight_variable(shape, regularizer):
weights = tf.get_variable("weights", shape,initializer=tf.truncated_normal_initializer(stddev=0.1))
# 当给出了正则化生成函数时,将当前变量的正则化损失加入名字为losses的集合。在这里
# 使用了add_to_collection函数将一个张量加入一个集合,而这个集合的名称为losses。
# 这是自定义的集合,不在TensorFlow自动管理的集合列表中。
if regularizer != None:
tf.add_to_collection('losses', regularizer(weights))
return weights # 定义神经网络的前向传播过程
def inference(input_tensor, regularizer):
# 声明第一层神经网络的变量并完成前向传播过程。
with tf.variable_scope('layer1'):
# 这里通过tf.get_variable或者tf.Variable没有本质区别,因为在训练或者测试
# 中没有在同一个程序中多次调用这个函数。如果在同一个程序中多次调用,在第一次
# 调用之后需要将reuse参数设置为True。
weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
biases = tf.get_variable("biases", [LAYER1_NODE],initializer=tf.constant_initializer(0.0))
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights)+biases) # 类似的声明第二层神经网络的变量并完成前向传播过程。
with tf.variable_scope('layer2'):
weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
biases = tf.get_variable("biases", [OUTPUT_NODE],initializer=tf.constant_initializer(0.0))
layer2 = tf.matmul(layer1, weights) + biases # 返回最后前向传播的结果
return layer2
import os import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 加载mnist_inference.py中定义的常量和前向传播的函数。
import Mnist_New.mnist_inference as mnist_inference # 配置神经网络的参数。
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99 # 模型保存的路径和文件名
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "model.ckpt" def train(mnist):
# 定义输入输出placeholder。
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
# 直接使用mnist_inference.py中定义的前向传播过程
y = mnist_inference.inference(x, regularizer)
global_step = tf.Variable(0, trainable=False) # 定义损失函数、学习率、滑动平均操作以及训练过程
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variable_averages_op = variable_averages.apply(tf.trainable_variables())
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY
)
train_step = tf.train.GradientDescentOptimizer(learning_rate)\
.minimize(loss, global_step=global_step)
with tf.control_dependencies([train_step, variable_averages_op]):
train_op = tf.no_op(name='train') # 初始化TensorFlow持久化类
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run() # 在训练过程中不再测试模型在验证数据上的表现,验证和测试的过程将会有一个独
# 立的程序来完成。
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step],
feed_dict={x: xs, y_: ys})
# 每1000轮保存一次模型
if i % 1000 == 0:
# 输出当前的训练情况。这里只输出了模型在当前训练batch上的损失
# 函数大小。通过损失函数的大小可以大概了解训练的情况。在验证数
# 据集上正确率的信息会有一个单独的程序来生成
print("After %d training steps, loss on training "
"batch is %g." % (step, loss_value))
# 保存当前的模型。注意这里给出了global_step参数,这样可以让每个
# 被保存的模型的文件名末尾加上训练的轮数,比如“model.ckpt-1000”,
# 表示训练1000轮之后得到的模型。
saver.save(
sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
global_step=global_step
) def main(argv=None):
mnist = input_data.read_data_sets("../path/to/MNIST_data/", one_hot=True)
train(mnist) if __name__ == "__main__":
tf.app.run()
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 加载mnist_inference.py 和mnist_train.py中定义的常量和函数。
import Mnist_New.mnist_inference as mnist_inference
import Mnist_New.mnist_train as mnist_train # 每10秒加载一次最新的模型,并且在测试数据上测试最新模型的正确率
EVAL_INTERVAL_SECS = 10 def evaluate(mnist):
with tf.Graph().as_default() as g:
# 定义输入输出的格式。
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
validate_feed = {x: mnist.validation.images,y_: mnist.validation.labels} # 直接通过调用封装好的函数来计算前向传播的结果。因为测试时不关注ze正则化损失的值
# 所以这里用于计算正则化损失的函数被设置为None。
y = mnist_inference.inference(x, None) # 使用前向传播的结果计算正确率。如果需要对未知的样例进行分类,那么使用
# tf.argmax(y,1)就可以得到输入样例的预测类别了。
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 通过变量重命名的方式来加载模型,这样在前向传播的过程中就不需要调用求滑动平均
# 的函数来获取平均值了。这样就可以完全共用mnist_inference.py中定义的
# 前向传播过程。
variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore) # 每隔EVAL_INTERVAL_SECS秒调用一次计算正确率的过程以检验训练过程中正确率的
# 变化。
while True:
with tf.Session() as sess:
# tf.train.get_checkpoint_state函数会通过checkpoint文件自动
# 找到目录中最新模型的文件名。
ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 加载模型。
saver.restore(sess, ckpt.model_checkpoint_path)
# 通过文件名得到模型保存时迭代的轮数。
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
accuracy_score = sess.run(accuracy,feed_dict=validate_feed)
print("After %s training step(s), validation "
"accuracy = %g" % (global_step, accuracy_score))
else:
print("No checkpoint file found")
return
time.sleep(EVAL_INTERVAL_SECS) def main(argv=None):
mnist = input_data.read_data_sets("../path/to/MNIST_data/", one_hot=True)
evaluate(mnist) if __name__ == "__main__":
tf.app.run()
TensorFlow最佳实践样例的更多相关文章
- 80、tensorflow最佳实践样例程序
''' Created on Apr 21, 2017 @author: P0079482 ''' #-*- coding:utf-8 -*- import tensorflow as tf #定义神 ...
- Cloud TPU Demos(TensorFlow 云 TPU 样例代码)
Cloud TPU Demos 这是一个Python脚本的集合,适合在开源TensorFlow和 Cloud TPU 上运行. 如果您想对模型做出任何修改或改进,请提交一个 PR ! https:// ...
- 吴裕雄 python 神经网络——TensorFlow 完整神经网络样例程序
import tensorflow as tf from numpy.random import RandomState batch_size = 8 w1= tf.Variable(tf.rando ...
- 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践
分布式TensorFlow由高性能gRPC库底层技术支持.Martin Abadi.Ashish Agarwal.Paul Barham论文<TensorFlow:Large-Scale Mac ...
- Tensorflow之MNIST的最佳实践思路总结
Tensorflow之MNIST的最佳实践思路总结 在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...
- Tensorflow的最佳实践
Tensorflow的最佳实践 1.变量管理 Tensorflow提供了变量管理机制,可直接通过变量的名字获取变量,无需通过传参数传递数据.方式如下: #以下为两种创建变量的方法 v=tf.get ...
- EffectiveTensorflow:Tensorflow 教程和最佳实践
Tensorflow和其他数字计算库(如numpy)之间最明显的区别在于Tensorflow中的操作是符号. 这是一个强大的概念,允许Tensorflow进行所有类型的事情(例如自动区分),这些命令式 ...
- [持续交付实践] pipeline使用:项目样例
项目说明 本文将以一个微服务项目的具体pipeline样例进行脚本编写说明.一条完整的pipeline交付流水线通常会包括代码获取.单元测试.静态检查.打包部署.接口层测试.UI层测试.性能专项测试( ...
- TensorFlow入门之MNIST最佳实践
在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...
随机推荐
- AT24C02跨页写数据
AT24C02 EEPROM的写数据分为:字节写数据模式和页写数据模式:字节写就是一个地址一个数据的写,页写是连续写数据,一个地址多个数据的写,但是页写不能自动跨页,如果超出一页长度,超出的数据会覆盖 ...
- Elasticsearch Java Rest Client API 整理总结 (一)——Document API
目录 引言 概述 High REST Client 起步 兼容性 Java Doc 地址 Maven 配置 依赖 初始化 文档 API Index API GET API Exists API Del ...
- Revit二次开发-根据视图阶段(Phase)创建房间
最近开发业务中,有一个自动创建房间的功能,很自然的想到了Document.NewRooms2方法.但是当前功能的特殊之处在于,Revit项目视图是分阶段(Phase)的,不同阶段的房间是互相独立的. ...
- 做游戏的小伙伴们注意了,DDoS还可以这样破!
欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文由腾讯游戏云发表于云+社区专栏 作者:腾讯DDoS安全专家.腾讯云游戏安全专家haroldchen 摘要:在游戏出海的过程中,DDoS攻 ...
- Bitmap 位图 Java实现
一.结构思想 以 bit 作为存储单位进行布尔值存取的数据结构. 表现为:给定第i位,该bit为1则表示true,为0则表示false. 二.使用场景及优点 适用于对布尔或0.1值进行(大量)存取的场 ...
- 1093. Count PAT’s (25)-统计字符串中PAT出现的个数
如题,统计PAT出现的个数,注意PAT不一定要相邻,看题目给的例子就知道了. num1代表目前为止P出现的个数,num12代表目前为止PA出现的个数,num123代表目前为止PAT出现的个数. 遇到P ...
- 第四次Scrum meeting
第四次Scrum meeting 会议内容: 沟通方面:与学霸在线组.学霸手机客户端组进行沟通,了解现阶段各个小组的进度,并针对接口结构方面进行调整 前后端:我们完全可以是不需要界面的,但是为了用户的 ...
- 2-Thirteenth Scrum Meeting-10151213
任务安排 成员 今日完成 明日任务 闫昊 获取视频播放进度 用本地数据库记录课程结构和学习进度 唐彬 阅读IOS代码+阅读上届网络核心代码 请假(编译……) 史烨轩 下载service开发 ...
- 20135327郭皓——Linux内核分析第二周 操作系统是如何工作的
操作系统是如何工作的 上章重点回顾: 计算机是如何工作的?(总结)——三个法宝 存储程序计算机工作模型,计算机系统最最基础性的逻辑结构: 函数调用堆栈,高级语言得以运行的基础,只有机器语言和汇编语言的 ...
- c# dataGridView 表头格式设置不管用
解决办法: EnableHeaderVisualStyles设为false