TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了。tf.keras 从 1.x 版本迁移到 2.0 版本,需要注意几个地方。

1. 设置随机种子

import tensorflow as tf

# TF 1.x
tf.set_random_seed(args.seed)
# TF 2.0
tf.random.set_seed(args.seed)

2. 设置并行线程数和动态分配显存

import tensorflow as tf
from tensorflow.python.keras import backend as K import os
# 将程序限定在一块GPU上
os.environ["CUDA_VISIBLE_DEVICES"] = '0' # TF 1.x
config = tf.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True # 不全部占满显存, 按需分配
K.set_session(tf.Session(config=config)) # TF 2.0,由于之前限定了GPU可见范围,这里只能看到0号GPU
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
print(gpus)
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, enable=True)

3. model.compile() 中设置 metrics=['acc'] 或者 ['accuracy'],会影响 model.fit() 生成的 log,callbacks.ModelCheckpoint 需要对应填 val_acc 或者 val_accuracy:

from tensorflow.python.keras import callbacks

# TF 2.0, acc and val_acc
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['acc'])
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
verbose=1, save_best_only=True, save_weights_only=True) # TF 2.0, accuracy and val_accuracy
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
verbose=1, save_best_only=True, save_weights_only=True)

4. 舍弃 model.fit_generator() 函数

model.fit_generator() 函数在 TF 2.x 中合并到 model.fit() 函数中,并且在 TF 2.0 版本,该函数有问题,不能很好利用 GPU,训练速度很慢:

Performance: Training is much slower in TF v2.0.0 VS v1.14.0 when using Tf.Keras and model.fit_generator #33024

TF 2.0 版本的 model.fit() 在传入 generator 时需要手动设置 model.fit(shuffle=False)。

解决办法:直接使用 model.fit() 函数,并且升级到 TF 2.1。

【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化的更多相关文章

  1. 【tf.keras】使用手册

    目录 0. 简介 1. 安装 1.1 安装 CUDA 和 cuDNN 2. 数据集 2.1 使用 tensorflow_datasets 导入公共数据集 2.2 数据集过大导致内存溢出 2.3 加载 ...

  2. tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑

    自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...

  3. 【tf.keras】实现 F1 score、precision、recall 等 metric

    tf.keras.metric 里面竟然没有实现 F1 score.recall.precision 等指标,一开始觉得真不可思议.但这是有原因的,这些指标在 batch-wise 上计算都没有意义, ...

  4. 基于tensorflow2.0 使用tf.keras实现Fashion MNIST

    本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...

  5. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  6. 【tf.keras】tf.keras使用tensorflow中定义的optimizer

    Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...

  7. 【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc

    运行以下类似代码: while True: inputs, outputs = get_AlexNet() model = tf.keras.Model(inputs=inputs, outputs= ...

  8. 一文上手Tensorflow2.0之tf.keras(三)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  9. 【tf.keras】tensorflow datasets,tfds

    一些最常用的数据集如 MNIST.Fashion MNIST.cifar10/100 在 tf.keras.datasets 中就能找到,但对于其它也常用的数据集如 SVHN.Caltech101,t ...

随机推荐

  1. 《一头扎进》系列之Python+Selenium框架设计篇1-什么是自动化测试框架-价值好几K的框架,不看别后悔,过时不候

    1. 什么是自动化测试框架 在了解什么是自动化测试框架之前,先了解一下什么叫框架?框架是整个或部分系统的可重用设计,表现为一组抽象构件及构件实例间交互的方法:另一种定义认为,框架是可被应用开发者定制的 ...

  2. JS-常考算法题解析

    常考算法题解析 这一章节依托于上一章节的内容,毕竟了解了数据结构我们才能写出更好的算法. 对于大部分公司的面试来说,排序的内容已经足以应付了,由此为了更好的符合大众需求,排序的内容是最多的.当然如果你 ...

  3. leetcode-164、最大间距

    export default (arr) => { // 如果数组长度小于2返回0 if (arr.length < 2) { return 0 } // 排序 arr.sort() // ...

  4. HTML真零基础教程

    这是为完全没有接触过的童鞋写的,属于真正的傻瓜式教程,当然由于本人也不是什么大佬,可能有些知识的理解与自己想的不一样,如果有大佬看到,还请帮我指出.以下主要是HTML5的基础标签的使用. 开发前的准备 ...

  5. Java 判断密码是否是大小写字母、数字、特殊字符中的至少三种

    public class CheckPassword { //数字 public static final String REG_NUMBER = ".*\\d+.*"; //小写 ...

  6. Java设计模式整理

    一.创建型模式 1.抽象工厂模式(AbstractFactory): 提供一个接口, 用于创建相关或依赖对象的家族, 而不需要指定具体类. 案例:https://www.cnblogs.com/lfx ...

  7. Java实现微信小程序支付(完整版)

    在开发微信小程序支付的功能前,我们先熟悉下微信小程序支付的业务流程图: 不熟悉流程的建议还是仔细阅读微信官方的开发者文档. 一,准备工作 事先需要申请企业版小程序,并开通“微信支付”(即商户功能).并 ...

  8. NET Framework项目移植到NET Core上遇到的一系列坑(2)

    目录 获取请求的参数 获取完整的请求路径 获取域名 编码 文件上传的保存方法 获取物理路径 返回Json属性大小写问题 webconfig的配置移植到appsettings.json 设置区域块MVC ...

  9. c++-多态小案例

    多态小案例 C面向接口编程和C多态 函数类型语法基础 函数指针做函数参数(回调函数)思想剖析 函数指针做函数参数两种用法(正向调用.反向调用) 纯虚函数 抽象类 抽象类基本概念 抽象类在多继承中的应用 ...

  10. iOS 和 H5 页面交互(WKWebview 和 UIWebview cookie 设置)

    iOS 和 H5 页面交互(WKWebview 和 UIWebview cookie 设置) 主要记录关于cookie相关的坑 1. UIWebview 1. UIWebview 相对比较简单 直接通 ...