我自己对mnist官方例程进行了部分注解,希望分享出来有助于入门选手更好理解tensorflow的运行机制,可以拷贝到IDE再调试看看,看看具体数据流向还有一部分tensorflow里面用到的库。
我用的是pip安装的tensorflow-GPU-1.13,这段源码原始位置在https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py

代码:

 from __future__ import absolute_import
from __future__ import division
from __future__ import print_function #absl是python标准库内的
from absl import app as absl_app
from absl import flags import tensorflow as tf # pylint: disable=g-bad-import-order from official.mnist import dataset
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers LEARNING_RATE = 1e-4 #参数默认data_format = 'channels_first'
def create_model(data_format):
"""Model to recognize digits in the MNIST dataset. Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py But uses the tf.keras API. Args:
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats Returns:
A tf.keras.Model.
""" #data_format:一个字符串,可以是channels_last(默认)或channels_first,\
# 表示输入中维度的顺序,channels_last对应于具有形状(batch, height, width, channels)\
# 的输入,而channels_first对应于具有形状(batch, channels, height, width)的输入.
#这里感觉输入只有三个维度,默认是单通道图?
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1] #将tf.keras.layers.MaxPooling2D传递给max_pool
l = tf.keras.layers
max_pool = l.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
# The model consists of a sequential chain of layers, so tf.keras.Sequential
# (a subclass of tf.keras.Model) makes for a compact description.
return tf.keras.Sequential(
[
#输入层确保输入的大小符合网络需要[28, 28]->[1, 28, 28]
l.Reshape(
target_shape=input_shape,
input_shape=(28 * 28,)),
#卷积
l.Conv2D(
32,#filters:整数, 输出空间的维数(即卷积中的滤波器数),就是卷积核个数
5,#卷积核大小,这里是5x5
padding='same',
data_format=data_format,
activation=tf.nn.relu),
#最大pooling
max_pool,
#卷积
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu),
# 最大pooling
max_pool,
#在保留第0轴的情况下对输入的张量进行Flatten(扁平化),拉直?
l.Flatten(),
#fc 1024 -> units: 该层的神经单元结点数。
l.Dense(1024, activation=tf.nn.relu),
l.Dropout(0.4),
#fc输出
l.Dense(10)
]) #添加了很多参数,指定了一部分的值,数据url,模型url,batch_size等等
def define_mnist_flags():
flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
#自定义项参数都在这里设置了
flags_core.set_defaults(data_dir='./tmp/mnist_data',
model_dir='./tmp/mnist_model',
batch_size=100,
train_epochs=40,
stop_threshold=0.998) def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
# 翻译成中文,注释的意思就是添加一个data_format的参数,下面的Estimator类需要用到
model = create_model(params['data_format'])
image = features
# 来判断一个对象是否是一个已知的类型。
if isinstance(image, dict):
image = features['image'] #测试模式
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
#如果只是测试到这里就返回了
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
}) #训练模式
if mode == tf.estimator.ModeKeys.TRAIN:
#设置LEARNING_RATE
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)) # Name tensors to be logged with LoggingTensorHook.
tf.identity(LEARNING_RATE, 'learning_rate')
tf.identity(loss, 'cross_entropy')
tf.identity(accuracy[1], name='train_accuracy') # Save accuracy scalar to Tensorboard output.
tf.summary.scalar('train_accuracy', accuracy[1]) return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)),
}) def run_mnist(flags_obj):
"""Run MNIST training and eval loop. Args:
flags_obj: An object containing parsed flag values.
""" #apply_clean是官方例程里面提供的用来清理现存model的方法,\
# 取决于flags_obj.clean(True则清理flags_obj.model_dir内的文件)
model_helpers.apply_clean(flags_obj) #把自定义的实现传给tf.estimator.Estimator
model_function = model_fn #tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算
session_config = tf.ConfigProto(
#设置线程一个操作内部并行运算的线程数,比如矩阵乘法,如果设置为0,则表示以最优的线程数处理
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
#设置多个操作并行运算的线程数,比如 c = a + b,d = e + f . 可以并行运算
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
#有时候,不同的设备,它的cpu和gpu是不同的,如果将这个选项设置成True,\
# 那么当运行设备不满足要求时,会自动分配GPU或者CPU
allow_soft_placement=True) #获取gpu数目,优化算法等,用于优化
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg) #所有输出(检查点,事件文件等)都被写入model_dir或其子目录.如果model_dir未设置,则使用临时目录.
#可以通过RunConfig对象(包含了有关执行环境的信息)传递config参数.它被传递给model_fn,\
# 如果model_fn有一个名为“config”的参数(和输入函数以相同的方式).如果该config参数未被传递,\
# 则由Estimator进行实例化.不传递配置意味着使用对本地执行有用的默认值.Estimator使配置对模型\
# 可用(例如,允许根据可用的工作人员数量进行专业化),并且还使用其一些字段来控制内部,特别是关于检查点
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config) data_format = flags_obj.data_format
#channels_first,即(3,128,128,128)通道数在最前面
#channels_last,即(128,128,128,3)通道数在最后面
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')#判断安装的TF是否支持GPU #estimator类对TensorFlow模型进行训练和计算.
#Estimator对象包装由model_fn指定的模型,其中,给定输入和其他一些参数,返回需要进行训练、计算,或预测的操作.
mnist_classifier = tf.estimator.Estimator(
#这个model_fn是参数名而已
model_fn=model_function,#模型对象
model_dir=flags_obj.model_dir,#模型目录,如果为空会创建一个临时目录
#猜测会去model_dir中寻找数据
config=run_config,#运行的一些参数
params={
'data_format': data_format,#数据类型
}) #这里定义了两个内部函数,只能被这个语句块的内部调用
# Set up training and evaluation input functions.
def train_input_fn():
"""Prepare data for training.""" # When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
ds = dataset.train(flags_obj.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size) # Iterate through the dataset a set number (`epochs_between_evals`) of times
# during each training session.
ds = ds.repeat(flags_obj.epochs_between_evals)
return ds def eval_input_fn():
return dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size).make_one_shot_iterator().get_next() # Set up hook that outputs training logs every 100 steps.
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size) # Train and evaluate model.
for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
#训练一次,验证一次
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results) #如果eval_results['accuracy'] >= flags_obj.stop_threshold 说明模型训练好了
if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
eval_results['accuracy']):
break # Export the model
if flags_obj.export_dir is not None:
#预分配内存,等待数据进入
image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image,
})
#输出模型
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn) def main(_):
run_mnist(flags.FLAGS) if __name__ == '__main__':
#日志
tf.logging.set_verbosity(tf.logging.INFO)
#给flags.FLAGS添加了很多参数项目
define_mnist_flags()
#带参数的启动
absl_app.run(main)

tensorflow--mnist注解的更多相关文章

  1. TensorFlow MNIST(手写识别 softmax)实例运行

    TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...

  2. 学习笔记TF056:TensorFlow MNIST,数据集、分类、可视化

    MNIST(Mixed National Institute of Standards and Technology)http://yann.lecun.com/exdb/mnist/ ,入门级计算机 ...

  3. TensorFlow MNIST 问题解决

    TensorFlow MNIST 问题解决 一.数据集下载错误 错误:IOError: [Errno socket error] [Errno 101] Network is unreachable ...

  4. Mac tensorflow mnist实例

    Mac tensorflow mnist实例 前期主要需要安装好tensorflow的环境,Mac 如果只涉及到CPU的版本,推荐使用pip3,傻瓜式安装,一行命令!代码使用python3. 在此附上 ...

  5. tensorflow MNIST Convolutional Neural Network

    tensorflow MNIST Convolutional Neural Network MNIST CNN 包含的几个部分: Weight Initialization Convolution a ...

  6. tensorflow MNIST新手教程

    官方教程代码如下: import gzip import os import tempfile import numpy from six.moves import urllib from six.m ...

  7. TensorFlow MNIST初级学习

    MNIST MNIST 是一个入门级计算机视觉数据集,包含了很多手写数字图片,如图所示: 数据集中包含了图片和对应的标注,在 TensorFlow 中提供了这个数据集,我们可以用如下方法进行导入: f ...

  8. 学习笔记TF057:TensorFlow MNIST,卷积神经网络、循环神经网络、无监督学习

    MNIST 卷积神经网络.https://github.com/nlintz/TensorFlow-Tutorials/blob/master/05_convolutional_net.py .Ten ...

  9. AI tensorflow MNIST

    MNIST 数据 train-images-idx3-ubyte.gz:训练集图片 train-labels-idx1-ubyte.gz:训练集图片类别 t10k-images-idx3-ubyte. ...

  10. tensorflow——MNIST机器学习入门

    将这里的代码在项目中执行下载并安装数据集. 执行下面代码,训练.并评估模型: # _*_coding:utf-8_*_ import inputdata mnist = inputdata.read_ ...

随机推荐

  1. 内存溢出OOM

    如何避免OOM 异常? 想要避免OOM 异常首先我们要知道什么情况下会导致OOM 异常. 1.图片过大导致OOM Android 中用bitmap 时很容易内存溢出,比如报如下错误:Java.lang ...

  2. 支持“XXX”上下文的模型已在数据库创建后发生更改。请考虑使用 Code First 迁移更新数据库(http://go.microsoft.com/fwlink/?LinkId=238269)。

    在Global.asax文件中的Application_Start()方法中加入以下代码 Database.SetInitializer<XXX>(null);

  3. 《JAVA与模式》之工厂方法模式

    在阎宏博士的<JAVA与模式>一书中开头是这样描述工厂方法模式的: 工厂方法模式是类的创建模式,又叫做虚拟构造子(Virtual Constructor)模式或者多态性工厂(Polymor ...

  4. django xadmin(2) 在xadmin基础上完成自定义页面

    1.在xadmin.py,GlobalSettings中自定义菜单 2.自定义视图函数,并获取原来的菜单等一下信息(主要是为了用xadmin的模板),具体的自己看xadmin源码 3.在adminx. ...

  5. 上板子在线抓波发现app_rdy一直为低

    现象 使用Xilinx的MIG IP测试外挂DDR3的读写发现一段很短的时间后app_rdy恒为低,并且最后一个读出的数据全是F. (1)不读写数据,app_rdy正常为高,MIG IP初始化信号为高 ...

  6. DRF 商城项目 - 购物( 购物车, 订单, 支付 )逻辑梳理

    购物车 购物车模型 购物车中的数据不应该重复. 即对相同商品的增加应该是对购买数量的处理而不是增加一条记录 因此对此进行联合唯一索引, 但是也因此存在一些问题 class ShoppingCart(m ...

  7. Linux内核模块编程——Hello World模块

    Linux内核模块编程 编程环境 Ubuntu 16.04 LTS 什么是模块 内核模块的全称是动态可加载内核模块(Loadable Kernel Modul,KLM),可以动态载入内核,让它成为内核 ...

  8. Java代码的编译与反编译那些事儿

    原文:Java代码的编译与反编译那些事儿 编程语言 在介绍编译和反编译之前,我们先来简单介绍下编程语言(Programming Language).编程语言(Programming Language) ...

  9. easyui Datagrid 表格高度计算及自适应页面的实现

    因为页面上既要计算表格的高度,又要自适应浏览器大小,之前都都采用固定表格高度,这样就会导致不同的分辨率电脑上看起来表格高矮不一, 所以采用了计算网页高度和其他div 的高度之差作为表格的初始高度: H ...

  10. 第九周博客作业<西北师范大学|李晓婷>

    1.助教博客链接:https://home.cnblogs.com/u/lxt-/ 2.作业要求博客链接:https://www.cnblogs.com/nwnu-daizh/p/10726884.h ...