Effective TensorFlow 2.0

为使TensorFLow用户更高效,TensorFlow 2.0中进行了多出更改。TensorFlow 2.0删除了篇冗余API,使API更加一致(统一RNNs, 统一优化器),并通过Eager execution更好地与Python集成。

许多RFCs已经解释了TensorFlow 2.0带来的变化。本指南介绍了TensorFlow 2.0应该怎么进行开发。这假设您已对TensorFlow 1.x有一定了解。

A brief summary of major changes

API Cleanup

许多API在TF 2.0中进行了移动或删除。一些主要的变化包括删除tf.apptf.flags,使tf.logging支持现在开源的absl-py,重新生成项目的tf.contribe,通过清理tf.*中那些较少使用的命名空间,例如tf.math。一些API已替换为自己的2.0版本-tf.summary,tf.keras.metrics, 和tf.keras.optimizers。最快升级应用这些重命名带来的变化可使用v2升级脚本

Eager execution

TensorFlow 1.x要求用户通过tf.*API手动的将抽象语法树(图)拼接在一起。然后它要求用户通过一组输入、输出张量传递给session.run()从而手动编译调用这个图。TensorFlow 2.0 Eager execution可以像Python那样执行,在2.0中,graph 和 session会像实现细节一样。

值得注意的是tf.control_dependencies()不再需要了,因为所有代码都是行顺序执行的(用tf.function声明)。

No more globals

TensorFlow 1.x严重依赖隐式全局命名空间。当你调用tf.Variable(),它会被放入默认图中,即使你忘了指向它的Python变量,它也会被保留在那里。然后你可以恢复它,但前提是你得知道它创建时的名称。如果你无法控制变量的创建,这很难做到。其结果是,各种各样的机制,试图帮助用户再次找到他们的变量,以及为框架找到用户创建的变量:Variable scopes, global collections。例如tf.get_global_step()tf.global_variables_initializer(),还有优化器隐式计算所有可训练变量的梯度等等。

TensorFlow 2.0消除了这些机制(Variable 2.0 RFC)默认支持的机制:跟踪你的变量!如果你忘记了一个tf.Variable,它就会当作垃圾被回收。

Functions, not sessions

session.run()几乎可以像函数一样调用:指定输入和被调用的函数,你可以得到一组输出。在TensorFlow 2.0中,您可以使用Python函数tf.function()来标记它以进行JIT编译,以便TensorFlow将其作为单个图运行(Function 2.0 RFC)。这种机制允许TensorFlow 2.0获得图模型所有的好处:

  • 性能:函数可以被优化(node pruning, kernel fusion, etc.)
  • 可移植性:该功能可以被导出/重新导入(SavedModel 2.0 RFC),允许用户重用和共享模块化TensorFlow功能。
# TensorFlow 1.X
outputs = session.run(f(placeholder), feed_dict={placeholder: input})
# TensorFlow 2.0
outputs = f(input)

凭借穿插Python 和TensorFlow代码的能力,我们希望用户能够充分利用Python的表现力。除了在没有Python解释器的情况下执行TensorFlow,如mobile, C++, 和 JS。为了帮助用户避免在添加时重写代码@tf.functionAutoGraph会将Python构造的一个子集转换为他们的TensorFlow等价物:

  • for/while -> tf.while_loop (支持break 和 continue)
  • if->tf.cond
  • for _ in dataset -> dataset.reduce

AutoGraph支持控制流的任意嵌套,这使得可以有较好性能并且简洁地实现许多复杂的ML程序,如序列模型,强化学习,自定义训练循环等。

Recommendations for idiomatic TensorFlow 2.0

Refactor your code into smaller functions

TensorFlow 1.x中常见使用模式是“kitchen sink”策略,其中所有可能的计算的联合被预先布置,然后选择被评估的张量,通过session.run()运行。在TensorFlow 2.0中,用户应该将代码重构为较小的函数,这些函数根据需要被调用。通常,没有必要用tf.function去装饰那些比较小的函数;仅用tf.function去装饰高等级的计算,例如,训练的一个步骤,或模型的前向传递。

Use Keras layers and models to manage variables

Keras模型和图层提供了方便variables和 trainable_variables属性,它以递归方式收集所有因变量。这使得在本地管理变量非常容易。

对比:

def dense(x, W, b):
return tf.nn.sigmoid(tf.matmul(x, W) + b) @tf.function
def multilayer_perceptron(x, w0, b0, w1, b1, w2, b2 ...):
x = dense(x, w0, b0)
x = dense(x, w1, b1)
x = dense(x, w2, b2)
... # 你仍然需要管理w_i和b_i,它们的形状远离代码定义。

Keras版本:

# 可以调用每个图层,其签名等效于 linear(x)
layers = [tf.keras.layers.Dense(hidden_size, activation=tf.nn.sigmoid) for _ in range(n)]
perceptron = tf.keras.Sequential(layers) # layers[3].trainable_variables => returns [w3, b3]
# perceptron.trainable_variables => returns [w0, b0, ...]

Keras layers/models继承自tf.train.Checkpointable并集成了@tf.function,这使得直接从Keras对象导出SavedModels或checkpoint成为可能。您不一定要使用Keras的.fitAPI来利用这些集成。

这是一个迁移学习的例子,演示了Keras如何轻松收集相关变量的子集。假设你正在训练一个带有共享主干的多头模型:

trunk = tf.keras.Sequential([...])
head1 = tf.keras.Sequential([...])
head2 = tf.keras.Sequential([...]) path1 = tf.keras.Sequential([trunk, head1])
path2 = tf.keras.Sequential([trunk, head2]) # Train on primary dataset
for x, y in main_dataset:
with tf.GradientTape() as tape:
prediction = path1(x)
loss = loss_fn_head1(prediction, y)
# Simultaneously optimize trunk and head1 weights.
gradients = tape.gradients(loss, path1.trainable_variables)
optimizer.apply_gradients(gradients, path1.trainable_variables) # Fine-tune second head, reusing the trunk
for x, y in small_dataset:
with tf.GradientTape() as tape:
prediction = path2(x)
loss = loss_fn_head2(prediction, y)
# Only optimize head2 weights, not trunk weights
gradients = tape.gradients(loss, head2.trainable_variables)
optimizer.apply_gradients(gradients, head2.trainable_variables) # You can publish just the trunk computation for other people to reuse.
tf.saved_model.save(trunk, output_path)

Combine tf.data.Datasets and @tf.function

在内存中迭代拟合训练数据时,可以随意使用常规的Python迭代。或者,tf.data.Dataset是从硬盘读取训练数据流的最好方法。Datasets是可迭代的(不是迭代器),它可以像在Eager模式下的其他Python迭代一样工作。您可以通过用tf.function()包装代码来充分利用数据集异步预取/流功能,这将使用AutoGraph等效的图操作替换Python的迭代。

@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
prediction = model(x)
loss = loss_fn(prediction, y)
gradients = tape.gradients(loss, model.trainable_variables)
optimizer.apply_gradients(gradients, model.trainable_variables)

如果您使用Keras.fit()API,则无需担心数据集迭代。

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Take advantage of AutoGraph with Python control flow

AutoGraph提供了一种将依赖于数据的控制流转换为等效图形模式的方法,如tf.condtf.while_loop

数据相关控制流出现的一个常见位置是序列模型。tf.keras.layers.RNN包装了一个RNN cell,允许您既可以静态也可以动态的循环展开。为了演示,您可以重新实现动态展开,如下所示:

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
outputs = tf.TensorArray(tf.float32, input_data.shape[0])
state = self.cell.zero_state(input_data.shape[1], dtype=tf.float32)
for i in tf.range(input_data.shape[0]):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state

有关AutoGraph功能的更详细概述,请参阅指南

Use tf.metrics to aggregate data and tf.summary to log it

要记录摘要,请使用tf.summary.(scalar|histogram|...)上下文管理器将其重定向到编写器。(如果省略上下文管理器,则不会发生任何事情。)与TF 1.x不同,摘要直接发送给编写器; 没有单独的“合并”操作,也没有单独的add_summary()调用,这意味着step必须在调用点提供该值。

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)

要在将数据记录为摘要之前聚合数据,请使用tf.metrics。Metrics是有状态的;它们积累值并在您调用.result()时返回结果。清除积累值,请使用.reset_states()

def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states() def test(model, test_x, test_y, step_num):
loss = loss_fn(model(test_x), test_y)
tf.summary.scalar('loss', loss, step=step_num) train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test') with train_summary_writer.as_default():
train(model, optimizer, dataset) with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)

通过将TensorBoard指向摘要日志目录来可视化生成的摘要:tensorboard --logdir /tmp/summaries

阅读原文

  • 欢迎关注我的公众号,一起学习!

TensorFlow 2.0高效开发指南的更多相关文章

  1. TensorFlow 2.0 快速入门指南 | iBooker·ApacheCN

    原文:TensorFlow 2.0 Quick Start Guide 协议:CC BY-NC-SA 4.0 自豪地采用谷歌翻译 不要担心自己的形象,只关心如何实现目标.--<原则>,生活 ...

  2. Cognos 11.0快速开发指南 Ⅰ

    1. 概述 Cognos Analysics 11,是IBM在Cognos BI 10的版本基础上,吸取业界流行的敏捷BI理念,强化了自助式分析的一款强大BI开发平台工具.其官方文档内容丰富,但是较为 ...

  3. Cognos 11.0快速开发指南 Ⅱ

    1.    创建报表 在创建好数据源之后,我们就可以创建报表了,报表的开发是浏览器中完成的,这里我选用了chrome浏览器,在地址栏输入:http://localhost:80/ibmcognos ( ...

  4. Odoo 8.0 实施开发指南 第一版 试读

    试读地址: http://share.weiyun.com/4f83964db87e022c7c210abe6b5e782f 如有错误,欢迎指正.

  5. OAuth2.0开发指南

    OAuth2.0开发指南 1.认证与登录 来往开放平台支持3种不同的OAuth 2.0验证与授权流程: 服务端流程(协议中Authorization Code Flow): 此流程适用于在Web服务端 ...

  6. 腾讯云安全:开发者必看|Android 8.0 新特性及开发指南

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 背景介绍 谷歌2017 I/O开发者大会今年将于5月17-19日在美国加州举办.大会将跟往年一样发布最新的 A ...

  7. OAuth2.0学习(2-1)Spring Security OAuth2.0 开发指南

    开发指南:http://www.cnblogs.com/xingxueliao/p/5911292.html Spring OAuth2.0 提供者实现原理: Spring OAuth2.0提供者实际 ...

  8. Android开发指南--0 总览

    无意间发现一个网站,主打IOS方面的教程,然而作为一个Android开发者,我就找了下网站里有没有Android的教程,还真有,这里就翻译一下. 翻译目标教程:https://www.raywende ...

  9. 开发者必看|Android 8.0 新特性及开发指南

    背景介绍 谷歌2017 I/O开发者大会今年将于5月17-19日在美国加州举办.大会将跟往年一样发布最新的 Android 系统,今年为 Android 8.0.谷歌在今年3 月21日发布 Andro ...

随机推荐

  1. 关于在Spring项目中使用thymeleaf报Exception parsing document错误

    今天在使用SpringBoot的过程中,SpringBoot版本为1.5.18.RELEASE,访问thymeleaf引擎的html页面报错Exception parsing document: 这是 ...

  2. 【01】HTML_day01_03-HTML常用标签

    typora-copy-images-to: media 第01阶段.前端基础.HTML常用标签 学习目标 理解: 相对路径三种形式 应用 排版标签 文本格式化标签 图像标签 链接 相对路径,绝对路径 ...

  3. 后端跨域的N种方法

    简单来说,CORS是一种访问机制,英文全称是Cross-Origin Resource Sharing,即我们常说的跨域资源共享,通过在服务器端设置响应头,把发起跨域的原始域名添加到Access-Co ...

  4. 【python基础语法】第5天作业练习题

    import random """ 1.一家商场在降价促销.如果购买金额50-100元(包含50元和100元)之间,会给10%的折扣(打九折), 如果购买金额大于100元 ...

  5. MySQL 8 InnoDB 集群生产部署

    生产部署InnoDB集群 1.先决条件 InnoDB集群使用组复制技术,所以InnoDB中的实例需要满足组复制要求.可以查看MySQL文档中组复制相关的部分,也可以通过AdminAPI提供的dba.c ...

  6. 本地服务器热更新 插件 live-server

    本地服务器热更新 插件 live-server 超级好用 强烈种草一波 无需安装到项目中 使用方法如下: 1.先全局安装live-server: npm i http-server -g 2.在需要热 ...

  7. Django复制记录的方法

    最近的Django项目中有复制记录的需求.数据库里有一张名为Party的表,记录用户创建的party,现在要让用户能够复制一个新的party.本身非常简单的一个功能,但运行的时候出错了.我以为是复制过 ...

  8. 并发编程之J.U.C的第一篇

    并发编程之J.U.C AQS 原理 ReentrantLock 原理 1. 非公平锁实现原理 2)可重入原理 3. 可打断原理 5) 条件变量实现原理 3. 读写锁 3.1 ReentrantRead ...

  9. 在写论文的参考文献时,有的段落空格很大,有的段落则正常,原因及解决方法(wps)

    下图是一段原始的参考文献,可以看出第一行的空格很大: 原因: 当一个词占不下时,自动将单词移动到下一行,但是这一行又有很多字符,因此这时,软件会将空闲的位置用空白字符填满.第一行有两个空白字符,因此将 ...

  10. 使用centos6.5整理出来的常用命令

    1.Vi 基本操作1) 进入vi 在系统提示符号输入vi及文件名称后,就进入vi全屏幕编辑画面: $ vi myfile 进入vi之后,是处于「命令行模式(command mode)」,您要切换到「插 ...