我第一次开始接触到TensorFlow大概是去年五月份,大三下,如果一年多已过,我却还在写启程。。这进度,实在汗颜。。

一个完整的tensorflow程序可以分为以下几部分:

  • Inputs and Placeholders
  • Build the Graph
    • Inference
    • Loss
  • Training
  • Train the Model
  • Visualize the Status
  • Save a Checkpoint
  • Evaluate the Model
  • Build the Eval Graph
  • Eval Output

Inputs and Placeholders

对于一个完整的网络来说,必定有输入还有输出,而Placeholders就是针对网络输入来的,相当于预先给输入变量占个坑,拿mnist来说,占坑代码可以如下面的例子:

 images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))

上述代码相当于为mnist图片和标签分别占坑,而tf.placeholder参数可以如下面所示:

tf.placeholder(dtype, shape=None, name=None)

即需要提供占坑数据类型dtype,占坑数据shape,当然也可以给它提供一个唯一的name

Build the Graph

因为tf是通过构建图模型来进行网络搭建的,因此搭建网络也就是’Build the Graph’。

Inference

首先就是构建图,利用一系列符号将要表达的操作表达清楚,以用于后续模型的训练。如下面代码:

 with tf.name_scope('hidden1'):
weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units],\
stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),name='weights') biases = tf.Variable(tf.zeros([hidden1_units]),\
name='biases')

如上述代码,对于一个图的搭建,需要一些变量来支持我们的运算,比如矩阵相乘等,需要通过tf.Variable来声明变量,其参数格式如下:

 tf.Variable(self, initial_value=None, trainable=True, collections=None, validate_shape=True,\
caching_device=None, name=None, variable_def=None, dtype=None)

需要提供变量初始值initial_value, 是否接受训练trainable,对于validate_shape表示该变量是否可以改变,如果形状可以改变,那么应该为False。对于每个变量,可以赋予不同的名字tf.name_scope

Loss

在定义完图结构之后,我们需要有个目标函数,用作更新图结构中的各个变量。

 labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='xentropy')

如上,通过给定的labels占坑变量,完成手写数字识别的最后交叉熵函数。

Training

在得到目标函数之后,我们就可以对模型进行训练,这里常用梯度下降法。在训练阶段,我们可以通过tf.scalar_summary来实现变量的记录,用作后续的tensorboard的可视化,如:

 tf.scalar_summary(loss.op.name, loss)

然后通过tf.SummaryWriter()来得到对应的提交值。

而对于模型的最优化,这里 tf 提供了很多optimazer,通常在tf.train里面,这里常用的是GradientDenscentOptimizer(lr),然后通过调用:

 train_op = optimizer.minimize(loss, global_step=global_step)

Train the Model

在模型训练时,我们需要打开一个默认的图环境,用作训练,如:

 with tf.Graph().as_default():

以此来打开一个图结构,然后我们需要声明一个会话在所有操作都定义完毕之后,这样我们就可以利用这个session来运行Graph.可以通过如下方法声明:

 with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)

每次我们可以通过sess.run来运行一些操作,进而获取其输出值,

 sess.run(fetches, feed_dict=None, options=None, run_metadata=None)

可以看到,run需要fetches,即操作,feed_dictfetches的输入,即占坑变量与其对应值构成的字典。

Visualize the Status

当然,在运行过程,我们可以通过可视化的操作来看网络运行情况。
在之前的tf.scalar_summary, 我们可以通过:

 summary = tf.merge_all_summaries()

将在图构建阶段的变量收集起来,然后在session创建之后运行如下命令生成可视化的值。

 summary_str = sess.run(summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, step)

其中summary_writer由如下得到:

 summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

然后用tensorboard打开对应文件即可。

Save a Chenckpoint

对于模型的保存,可以通过如下代码实现:

 saver = tf.train.Saver()
saver.save(sess, FLAGS.train_dir, global_step=step)

而载入模型可以通过如下的代码来实现:

saver.restore(sess, FLAGS.train_dir)

当然了,模型的估计就类似上述了。

这样简单的模型搭建到运行就完成了。本文主要用到这些函数:

    • tf.placeholder
    • tf.Variable
    • tf.train
      • tf.train.GradientDenscentOptimizer
      • tf.train.SummaryWriter
      • tf.train.Saver
    • tf.session
    • tf.Graph
    • tf.add_summary
    • tf.merge_all_summaries

其实构建一个模型基本就用这些函数,然后就是一些数理计算方法。详情参看tensorflow

Reference:Link

TF启程的更多相关文章

  1. TFS命令tf:undo(强制签入签出文件)

    由于修改计算机名称或不同电脑上操作忘记签入,则需要强制签入文件 具体步骤如下: 1.在命令行中输入"cd  C:\Program Files\Microsoft Visual Studio ...

  2. 制作HP MicroServer Gen8可用的ESXi 5.x SD/TF卡启动盘

    前些日子看到HP公司和京东在搞服务器促销活动,于是就入了一个 ProLiant MicroServer Gen8 的低配版 相比上一代产品,新一代 MicroServer系列微型服务器可更换处理器,还 ...

  3. TF Boys (TensorFlow Boys ) 养成记(六)

    圣诞节玩的有点嗨,差点忘记更新.祝大家昨天圣诞节快乐,再过几天元旦节快乐. 来继续学习,在/home/your_name/TensorFlow/cifar10/ 下新建文件夹cifar10_train ...

  4. TF Boys (TensorFlow Boys ) 养成记(五)

    有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输 ...

  5. TF Boys (TensorFlow Boys ) 养成记(四)

    前面基本上把 TensorFlow 的在图像处理上的基础知识介绍完了,下面我们就用 TensorFlow 来搭建一个分类 cifar10 的神经网络. 首先准备数据: cifar10 的数据集共有 6 ...

  6. TF Boys (TensorFlow Boys ) 养成记(三)

    上次说到了 TensorFlow 从文件读取数据,这次我们来谈一谈变量共享的问题. 为什么要共享变量?我举个简单的例子:例如,当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生 ...

  7. TF Boys (TensorFlow Boys ) 养成记(二)

    TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...

  8. TF Boys (TensorFlow Boys ) 养成记(一)

    本资料是在Ubuntu14.0.4版本下进行,用来进行图像处理,所以只介绍关于图像处理部分的内容,并且默认TensorFlow已经配置好,如果没有配置好,请参考官方文档配置安装,推荐用pip安装.关于 ...

  9. ROS TF——learning tf

    在机器人的控制中,坐标系统是非常重要的,在ROS使用tf软件库进行坐标转换. 相关链接:http://www.ros.org/wiki/tf/Tutorials#Learning_tf 一.tf简介 ...

随机推荐

  1. Python3 Selenium自动化web测试 ==>FAQ:隐式等待和sleep区别

    FAQ: 情景1: 设置等待时间 A方法:sleep 线程休眠,但只单次有效,其他操作需要加载等待时间,需要再次添加time.sleep() B方法:implicitly_wait() from se ...

  2. Laravel从模型中图片的相对路径获取绝对路径

    在模型product.php中增加以下方法.数据库图片字段为image.存储的图片相对路径 public function getImageUrlAttribute() { // 如果 image 字 ...

  3. [bzoj1892][bzoj2384][bzoj1461][Ceoi2011]Match/字符串的匹配_KMP_树状数组

    2384: [Ceoi2011]Match 1892: Match 1461: 字符串的匹配 题目大意: 数据范围: 题解: 很巧妙的一道题呀. 需要对$KMP$算法有很深的理解才行. 首先我们需要发 ...

  4. Java判断指定日期是否为工作日

    Java判断指定日期是否为工作日 转自:https://www.jianshu.com/p/966659492f2f 转:https://www.jianshu.com/p/05ccb5783f65转 ...

  5. 输出单项链表中倒数第k个结点——牛客刷题

    题目描述: 输入一个单向链表,输出该链表中倒数第k个结点 输入.输出描述: 输入说明:1.链表结点个数 2.链表结点的值3.输入k的值 输出说明:第k个结点指针 题目分析: 假设链表长度为n,倒数第k ...

  6. 【sublime Text】sublime Text3安装可以使xml格式化的插件

    应该有机会 ,会碰到需要格式化xml文件的情况. 例如,修改word转化的xml文件之后再将修改之后的xml文件转化为word文件. 但是,word另存的xml文件是没有格式的一片: 那怎么格式化 这 ...

  7. Java实现的基础数据结构

    Java实现的基础数据结构 0,常用的基础数据结构 图1 基础数据结构&相关特性 图2 Java自带的类集框架&继承关系图 1,数组[Array] 特点:长度固定.查找方便[直接使用i ...

  8. 基于C#开发的扩展按钮控件

    最近在准备一套自定义控件开发的课程,下面将第一个做的按钮控件分享给大家. 其实这个控件属于自定义控件中的扩展控件,与组合控件和GDI+开发的控件不同,这个控件是继承原生的Button, 这个控件的目的 ...

  9. 服务端相关知识学习(六)Zookeeper client

    Zookeeper的client是通过Zookeeper类提供的.前面曾经说过,Zookeeper给使用者提供的是一个类似操作系统的文件结构,只不过这个结构是分布式的.可以理解为一个分布式的文件系统. ...

  10. luogu P4688 [Ynoi2016]掉进兔子洞

    luogu 我们要求的答案应该是三个区间长度\(-3*\)在三个区间中都出现过的数个数 先考虑数列中没有相同的数怎么做,那就是对三个区间求交,然后交集大小就是要求的那个个数.现在有相同的数,考虑给区间 ...