本文以TensorFlow源码中自带的手写数字识别Example为例,引出TensorFlow中的几个主要概念。并结合Example源码一步步分析该模型的实现过程。

一、什么是TensorFlow

  在这里,引入TensorFlow中文社区首页中的两段描述。

关于 TensorFlow

TensorFlow™ 是一个采用数据流图(data flow

graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即张量(tensor)。它灵活的架构让你可以在多种平台上展开计算,例如台式计算机中的一个或多个CPU(或GPU),服务器,移动设备等等。TensorFlow

最初由Google大脑小组(隶属于Google机器智能研究机构)的研究员和工程师们开发出来,用于机器学习和深度神经网络方面的研究,但这个系统的通用性使其也可广泛用于其他计算领域。

什么是数据流图(Data Flow Graph)?

数据流图用“结点”(nodes)和“线”(edges)的有向图来描述数学计算。“节点”

一般用来表示施加的数学操作,但也可以表示数据输入(feed in)的起点/输出(push

out)的终点,或者是读取/写入持久变量(persistent

variable)的终点。“线”表示“节点”之间的输入/输出关系。这些数据“线”可以输运“size可动态调整”的多维数据数组,即“张量”(tensor)。张量从图中流过的直观图像是这个工具取名为“Tensorflow”的原因。一旦输入端的所有张量准备好,节点将被分配到各种计算设备完成异步并行地执行运算。

二、示例

  接下来的示例中,主要使用到以下两个文件。

mnist.py
fully_connected_feed.py

  该示例的目的是建立一个手写图像识别模型,通过该模型,可以准确识别输入的28 * 28像素的手写图片是0~9这十个数字中的哪一个。

1、运行文件准备

  需要下载好tensorflow源代码,注意这里的源代码版本需要与安装的TensorFlow版本保持一致。

  在/home/mlusr/files/tensorflow/下解压缩该文件。进入示例文件路径中,运行

cd ~/files/tensorflow/tensorflow-r0.11/tensorflow/examples/tutorials/mnist
python fully_connected_feed.py

  运行过程中,需要联网下载训练数据,数据文件保存到~/files/tensorflow/tensorflow-r0.11/tensorflow/examples/tutorials/mnist/data路径下,如果不能联网的话,可以手动到http://yann.lecun.com/exdb/mnist/,下载好以下四个文件,放入data目录。

t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz

2、运行

  直接运行fully_connected_feed.py文件。

python fully_connected_feed.py

  输出信息如下:

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
Step 0: loss = 2.30 (0.007 sec)
Step 100: loss = 2.13 (0.005 sec)
Step 200: loss = 1.87 (0.004 sec)
Step 300: loss = 1.55 (0.004 sec)
Step 400: loss = 1.26 (0.004 sec)
Step 500: loss = 0.87 (0.004 sec)
Step 600: loss = 0.87 (0.004 sec)
Step 700: loss = 0.65 (0.005 sec)
Step 800: loss = 0.43 (0.004 sec)
Step 900: loss = 0.65 (0.005 sec)
Training Data Eval:
Num examples: 55000 Num correct: 47184 Precision @ 1: 0.8579
Validation Data Eval:
Num examples: 5000 Num correct: 4349 Precision @ 1: 0.8698
Test Data Eval:
Num examples: 10000 Num correct: 8663 Precision @ 1: 0.8663
Step 1000: loss = 0.47 (0.006 sec)
Step 1100: loss = 0.40 (0.051 sec)
Step 1200: loss = 0.55 (0.005 sec)
Step 1300: loss = 0.43 (0.004 sec)
Step 1400: loss = 0.39 (0.004 sec)
Step 1500: loss = 0.57 (0.005 sec)
Step 1600: loss = 0.50 (0.004 sec)
Step 1700: loss = 0.37 (0.005 sec)
Step 1800: loss = 0.38 (0.006 sec)
Step 1900: loss = 0.35 (0.004 sec)
Training Data Eval:
Num examples: 55000 Num correct: 49292 Precision @ 1: 0.8962
Validation Data Eval:
Num examples: 5000 Num correct: 4525 Precision @ 1: 0.9050
Test Data Eval:
Num examples: 10000 Num correct: 9027 Precision @ 1: 0.9027

3、启动TensorBoard并查看

  在启动TensorBoard时注意指定输出log文件路径,在本例中启动命令如下

tensorboard --logdir /home/mlusr/files/tensorflow/tensorflow-r0.11/tensorflow/examples/tutorials/mnist/data

  启动输出信息如下所示:

Starting TensorBoard 29 on port 6006
(You can navigate to http://192.168.1.100:6006)

  浏览器访问页面指定ip和端口:

  

  在TensorBoard中还可以查看该模型的更多信息。

  

  本文接下来的部分,将以mnist.pyfully_connected_feed.py两个文件中的内容

二、数据下载和输入

  MNIST的数据主要分成以下三个部分,

  

数据集 作用
data_sets.train 55000条image和label数据,主要用于训练模型
data_sets.validation 5000条image和label数据,用于在迭代过程中确定模型准确率
data_sets.test 10000条image和label数据,用于最终评估模型的准确率

1、概念一:Placeholder

  Placeholder的更多描述,请看这里。使用Placeholder的地方,在构造Graph时并不包含实际的数据,只是在应用运行时才会动态的用数据来替代。

  在fully_connected_feed.py文件中的placeholder_inputs方法中,通过调用tf.placeholder方法分别生成了代表imageslabelsplaceholder

IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))

  在生成placeholder时,只需要指定其中的数据类型,以及维度。上面images_placeholder中的元素为float类型,维度为batch_size * IMAGE_PIXELSlagels_placeholder中的元素为int类型,维度为batch_sizebatch_size参数在程序调用placeholder_inputs时指定。

  看到这里可以发现images_placeholder和labels_placeholder仅仅只是指定了其中元素的类型和shape,具体数值是在后续程序运行时才会填充进来的。所以叫做Placeholder。在这里这两个Placeholder代表了输入的两个数据源。

三、概念二:Graph

  

  Graph是TensorFlow中又一个重要概念。Graph可以理解成TensorFlow中的一个调整好参数的执行计划。构建好这个Graph之后,所有输入数据,中间转换过程,以及输出数据的流程和格式便固定下来,数据进入Graph后按照特定的结构和参数,就能得到对应的输出结果。如下图所示:

  

  构建一个Graph主要分成以下三步。

1、确定Graph结构

  inference方法,以images_placeholder作为输入,连接到维度为(28 * 28, 128)的隐层1,隐层1连接到维度为(128, 32)的隐层2,最后的输出层logits为10个节点。各层之间的激活函数为Relu

  

  下面代码中使用到的常量

IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
NUM_CLASSES = 10

  构建隐层1,

with tf.name_scope('hidden1'):
 weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
  stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
  name='weights')
 biases = tf.Variable(tf.zeros([hidden1_units]),
  name='biases')

  上面定义了两个Variableweights表示连接权重,biases表示偏置量。

  biases比较简单,定义了一个名为biases的元素全为0的变量,其长度为hiden1_units,默认为128

  weights的维度为IMAGE_PIXELS * hidden1_units,其中的初始值为标准差为1 / math.sqrt(float(IMAGE_PIXELS)的截断正态分布值。

  构建隐层2,

with tf.name_scope('hidden2'):
weights = tf.Variable(tf.truncated_normal([hidden1_units, hidden2_units],
stddev =1.0 / math.sqrt(float(hidden1_units))),
name = 'weights')
biases = tf.Variable(tf.zeros([hidden2_units]),
name ='biases')

  构建输出层,

with tf.name_scope('softmax_linear'):
weights = tf.Variable(tf.truncated_normal([hidden2_units, NUM_CLASSES] ,
stddev =1.0 / math.sqrt(float(hidden2_units))) ,
name = 'weights')
biases = tf.Variable(tf.zeros([NUM_CLASSES]),
name ='biases')

  基于上面的权重和偏置量值,使用relu激活函数连接各层,

hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
logits = tf.matmul(hidden2, weights) + biases

  前面的三组weightsbiases变量名都相同,区分的办法是前面的with tf.name_scope('hidden1')。在hidden1命名空间下的wiehts参数的完整表示为"hidden1/weights"

2、确定损失函数

  上一步确定好模型各层结构和参数后,接下来需要定义一个损失函数的计算逻辑。

  在mnist.py文件中有一个loss()方法,输入两个参数,第一个为上面模型的输出结果logits,第二个为images对应的实际labels,在调用该方法时,传入的是前面定义的labels_placeholder

  

def loss(logits, labels):
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, labels , name = 'xentropy')
loss = tf.reduce_mean(cross_entropy, name ='xentropy_mean')
return loss

  上面的tf.nn.sparse_softmax_cross_entropy_with_logits会根据labels的内容自动生成1-hot编码,并且计算与输出logits的1-hot编码的交叉熵[cross entropy][http://blog.csdn.net/rtygbwwwerr/article/details/50778098])

  最后,调用reduce_mean方法,计算交叉熵的平均值。

3、参数训练

  

  调用training方法的调用形式为,传入上面的损失值和学习率。

train_op = mnist.training(loss, FLAGS.learning_rate)

  接下来,mnist.py文件中的training方法,将使用梯度下降法来计算使得损失值最小的模型参数。首先将损失值loss传入tf.scala_summary中,这个操作主要是用于在后面使用SummaryWriter时向events file中生成求和值,将每一次得到的损失值写出到事件文tf.scalar_summary(loss.op.name, loss)后,调用tf.train.GradientDesecentOptimizer按指定的学习率实现梯度下降算法。

  

# Create the gradient descent optimizer with the given learning rate.
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

  最后,使用一个名为global_stepvariable来记录每一次训练的步长。optimizer.minimize操作用于更新系统的权重,同时增加步长。

# Create a variable to track the global step.
global_step = tf.Variable(0, name = 'global_step', trainable =False)
# Use the optimizer to apply the gradients that minimize the loss
# (and also increment the global step counter) as a single training step.
train_op = optimizer.minimize(loss, global_step=global_step)

四、训练模型

  当第三步中的Graph构造完成之后,就可以迭代的训练和评估模型了。

1、Graph

  在run_training()方法的最前面,使用一个with命令表明所有的操作都要与tf.Graph的默认全局graph相关联。

with tf.Graph().as_default():

  tf.Graph表示需要在一起运行的操作集合。在大多数情况下,TensorFlow使用一个默认的graph就已经够用了。

2、Session

  接下来就需要为应用运行准备环境了。在TensorFlow中使用的是Session

sess = tf.Session()

  另外,除了按上面这行代码生成sess对象外,还可以使用with命令生成,如下所示,

with tf.Session() as sess:

  在获得sess对象后,首先可以将之前定义的variable进行初始化,

init = tf.initialize_all_variables()
sess.run(init)

3、循环训练

  初始化之后就可以开始循环训练模型了。

  可以通过如下代码实现一个最简单的训练循环,在这个循环中可以控制每次循环的步长。

for step in xrange(FLAGS.max_steps):
sess.run(train_op)

  但是在本教程中的例子比较复杂。这是因为必须把输入的数据根据每一步的情况进行切分,替换到之前的placeholder处。具体可以继续看以下部分。

4、向Graph输入数据

  TensorFlow的feed机制可以在应用运行时向Graph输入数据。在每一步训练过程中,首先会根据训练数据生成一个feed dictionary,这里面会包含本次循环中使用到的训练数据集。

feed_dict = fill_feed_dict(data_sets.train,
images_placeholder,
labels_placeholder)

  fill_feed_dict方法如下,每次从训练数据集中根据batch_size取出指定数量的images_feedlabels_feed,然后以images_pllabels_plkey存入字典中。

def fill_feed_dict (data_set, images_pl, labels_pl):
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
FLAGS.fake_data)
feed_dict = {
images_pl: images_feed,
labels_pl: labels_feed,
}
return feed_dict

5、检查状态

  接下来以上面获取到的每个batch的数据开始执行训练过程。

for step in xrange(FLAGS.max_steps):
feed_dict = fill_feed_dict(data_sets.train,
images_placeholder,
labels_placeholder)
_, loss_value = sess.run([train_op, loss],
feed_dict=feed_dict)

  在这里传入train_oploss后,sess.run方法返回一个包含两个Tensortuple对象。由于train_op并没有返回值,所以只记录loss的返回值loss_value

  假设训练过程很正常,那么每过100次训练将会打印一次当前的loss值,

if step % 100 == 0 :
print ('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))

6、状态可视化

  在上面每隔100次打印一次loss值之外,还有两个操作将当前的loss值写入到事件文件中,供TensorBoard作展示用。

  summary_str = sess.run(summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, step)
summary_writer.flush()

7、设置检查点

  在TensorFlow中使用tf.train.Saver将训练好的模型进行保存。

saver = tf.train.Saver()

  在循环训练过程中,saver.save()方法会定期执行,用于将模型当前状态写入到检查点文件中。

checkpoint_file = os.path.join(FLAGS.log_dir , 'model.ckpt')
saver.save(sess, checkpoint_file, global_step =step)

  如果需要使用到该检查点文件中保存的模型时,可以使用saver.restore()方法进行加载,

saver.restore(sess, FLAGS.train_dir)

五、评估模型

  在每次保存检查点文件时,会同时计算此时模型在训练数据集,检验数据集和测试数据集上的误差。

print('Training Data Eval:')
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.train)
# Evaluate against the validation set.
print ('Validation Data Eval:')
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.validation)
# Evaluate against the test set.
print ('Test Data Eval:')
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.test)

1、建立评估Graph

  注意上面代码中的do_eval方法,以及该方法的eval_correct参数。eval_correct操作是在循环开始前就已经定义好了的。

eval_correct = mnist.evaluation(logits, labels_placeholder)

  这个evaluation从参数上看是用于比较预测值与真实值直接的差异。

def evaluation (logits, labels):
correct = tf.nn.in_top_k(logits, labels, 1)
return tf.reduce_sum(tf.cast(correct, tf.int32))

  返回一个长度为batch_sizetensor,如果预测值与真实值相同则为true,否则为false

2、评估模型输出

  最后,在do_eval方法中,处理该误差并输出。类似于模型训练过程中,这里也创建一个feed_dict对象,在给定的数据集上调用sess.run方法,计算预测值中有多少与实际值相一致。

for step in xrange(steps_per_epoch):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
true_count += sess.run(eval_correct, feed_dict =feed_dict)

  最后,将预测正确的记录数与当前的总数据数进行比较,得到本次的预测精度。

precision = float(true_count) / num_examples
print ('Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
(num_examples, true_count, precision))

TensorFlow入门和示例分析的更多相关文章

  1. TensorFlow入门,基本介绍,基本概念,计算图,pip安装,helloworld示例,实现简单的神经网络

    TensorFlow入门,基本介绍,基本概念,计算图,pip安装,helloworld示例,实现简单的神经网络

  2. TensorFlow入门(四) name / variable_scope 的使

    name/variable_scope 的作用 欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-08 refer to: ...

  3. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  4. #tensorflow入门(1)

    tensorflow入门(1) 关于 TensorFlow TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操 ...

  5. TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...

  6. 转:TensorFlow入门(六) 双端 LSTM 实现序列标注(分词)

    http://blog.csdn.net/Jerr__y/article/details/70471066 欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @cr ...

  7. TensorFlow入门之MNIST最佳实践

    在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...

  8. TensorFlow入门之MNIST最佳实践-深度学习

    在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...

  9. 利用 TensorFlow 入门 Word2Vec

    利用 TensorFlow 入门 Word2Vec 原创 2017-10-14 chen_h coderpai 博客地址:http://www.jianshu.com/p/4e16ae0aad25 或 ...

随机推荐

  1. python中的多线程

    一个程序可以理解为一个进程,这个进程有其代号,可以依据这个代号将其杀死. 一个进程肯定有且只有一个主线程,他可以有很多子线程. 运行一个任务如果可以有许多子线程同时去做,当然会提高效率. 但是,在py ...

  2. javascript中的事件类型

    表单事件 submit reset click change focus blur input window事件 load DomContentLoaded readyStatechange unlo ...

  3. 减小APK大小

    本篇文章翻译自Reduce APK Size 用户通常不会去下载体积过大的应用程序,特别是当自己的设备连接的是 2G/3G 或者按字节付费的网络.这篇文章描述了如何缩减 APK 的体积大小,以使得更多 ...

  4. 1018关于MySQL复制搭建[异步复制和半同步复制]

    转自:http://www.cnblogs.com/ivictor/p/5735580.html 搭建MySQL数据库的主从架构,还是蛮简单的.重要的几个命令整理一下. 主从服务器上: SHOW VA ...

  5. 在windows系统之中查看目前已安装的更新

    方法1:使用PowerShell get-hotfix 方法2:使用cmd systeminfo.exe 参考链接

  6. MyBatis(1)——快速入门

    MyBatis 简介 MyBatis 本是apache的一个开源项目iBatis, 2010年这个项目由apache software foundation 迁移到了google code,并且改名为 ...

  7. [LeetCode] Delete and Earn 删除与赚取

    Given an array nums of integers, you can perform operations on the array. In each operation, you pic ...

  8. [LeetCode] Student Attendance Record II 学生出勤记录之二

    Given a positive integer n, return the number of all possible attendance records with length n, whic ...

  9. 微信小程序开发•模块化

    微信小程序的MINA框架,其实是许多前端开发技术的组合.这篇文章中,我们来简单地讨论一下模块化. 1.模块化标准 玩前端的同学大部分都知道模块化的几个标准,CommonJs / AMD / CMD.这 ...

  10. jq图片展示插件highslide.js简单dom

    今天用用了一款图片展示插件highslide.js,感觉用起来很是舒畅,几乎不用怎么写代码,只需要知道如何写参数就行了. 那么这么牛叉的插件我们该如何用哪,下面我就跟大家讲解一下. 一.引入   首先 ...