【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

参考目录:

本文主要讲述TF2.0的模型文件的存储和载入的多种方法。主要分成两类型:模型结构和参数一起载入,模型的结构载入。

1 模型的构建

import tensorflow.keras as keras

class CBR(keras.layers.Layer):
def __init__(self,output_dim):
super(CBR,self).__init__()
self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
self.bn = keras.layers.BatchNormalization(axis=3)
self.ReLU = keras.layers.ReLU() def call(self, inputs):
inputs = self.conv(inputs)
inputs = self.ReLU(self.bn(inputs))
return inputs class MyNet(keras.Model):
def __init__ (self):
super(MyNet,self).__init__()
self.cbr1 = CBR(16)
self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
self.cbr2 = CBR(32)
self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2)) def call(self, inputs):
inputs = self.maxpool1(self.cbr1(inputs))
inputs = self.maxpool2(self.cbr2(inputs))
return inputs model = MyNet()

部分朋友可以发现,上面的代码就是上一次课程所构建的一个自定义的网络。

我们现在需要展示这个模型的框架:

model.build((16,224,224,3))
print(model.summary())

运行结果为:

这里需要对网络执行一个构建.build()函数,之后才能生成model.summary()这样的模型的描述。 这是因为模型的参数量是需要知道输入数据的通道数的,假如我们输入的是单通道的图片,那么就是:

model.build((16,224,224,1))
print(model.summary())

输出结果为:

2 结构参数的存储与载入

model.save('save_model.h5')
new_model = keras.models.load_model('save_model.h5')

这里并不能保存成功,出现这样的错误:

大概的意思就是:因为你的模型不是官方的模型,是自定义的,所以并不能同时保存结构和参数。只有官方的模型可以时候上面的保存的方法,同时保存参数和权重;自定义的模型建议只保存参数

3 参数的存储与载入

model.save_weights('model_weight')
new_model = MyNet()
new_model.load_weights('model_weight')

这样子就可以保存自定义的模型了。在对应的目录下会出现这几个文件:

我们来看一下原来的模型和载入的模型对于同一个样本给出的结果是否相同:

# 看一下原来的模型和载入的模型预测相同的样本的输出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]

结果相同,载入的没有问题~

4 结构的存储与载入

结构的存储有两种方法:

  • model.get_config()
  • model.to_json()

需要注意的是,上面的两个方法和save的问题一样,是不能用在自定义的模型中的,如果你在其中使用了自定义的Layer类,那么只能!只能用save_weights的方式进行保存

下面依然给出这两种方法的代码,对于简单的、已经封装好的一些网络层构成的网络,是可以使用这些的。我个人还是常用save_weights啦

# 第一种方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二种方法
json_config = model.to_json()
# 把json写的文件中
with open('model_config.json', 'w') as json_file:
json_file.write(json_config)
# 读取本地json文件
with open('model_config.json') as json_file:
json_config = json_file.read()
reinitialized_model = keras.models.model_from_json(json_config)

今天的内容就是这么多,虽然提供了四种方法,但是对于自定义程度较高的模型,还是要使用save_weights哦~

【小白学PyTorch】19 TF2模型的存储与载入的更多相关文章

  1. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  2. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  3. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

  4. 【小白学PyTorch】18 TF2构建自定义模型

    [机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...

  5. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  6. 【小白学PyTorch】4 构建模型三要素与权重初始化

    文章目录: 目录 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 1 模型三要素 三要素其实很简单 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Mo ...

  7. 【小白学PyTorch】9 tensor数据结构与存储结构

    文章来自微信公众号[机器学习炼丹术]. 上一节课,讲解了MNIST图像分类的一个小实战,现在我们继续深入学习一下pytorch的一些有的没的的小知识来作为只是储备. 参考目录: @ 目录 1 pyto ...

  8. 【小白学PyTorch】16 TF2读取图片的方法

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...

  9. 小白学PyTorch 动态图与静态图的浅显理解

    文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...

随机推荐

  1. 学习一下 JVM (三) -- 了解一下 垃圾回收

    一.简单了解几个概念 1.什么是垃圾(Garbage)?什么是垃圾回收(Garbage Collection,简称 GC)? (1)什么是垃圾(Garbage)? 这里的垃圾 指的是 在程序运行过程中 ...

  2. Hacker101 CTF 学习记录(一)

    前言 苦力挖洞,靠运气赚点小钱.看着大佬严重,高危,再看看自己手上的低危,无危害默默流下了菜鸡的泪水 思路受局限,之前听学长推荐和同事聊到hacker101,因此通过hacker101拓展下漏洞利用思 ...

  3. android尺寸问题(转)

    android尺寸问题(转) (2013-01-15 16:55:36) 转载▼ 标签: 杂谈 分类: LINUX 最近公司做的项目中涉及到屏幕自适应的问题.由于做的是电视版的项目,因此屏幕自适应问题 ...

  4. amd、cmd、CommonJS以及ES6模块化

    AMD.CMD.CommonJs.ES6的对比 他们都是用于在模块化定义中使用的,AMD.CMD.CommonJs是ES5中提供的模块化编程的方案,import/export是ES6中定义新增的 什么 ...

  5. spring中bean初始化执行顺序

    常用的javabean的初始化方法为,构造方法,@PostConstruct,以及实现InitializingBean接口的afterPropertiesSet方法. note在构造方法执行时候,sp ...

  6. java初探(1)之静态页面化——客户端缓存

    利用服务端缓存技术,将页面和对象缓存在redis中,可以减少时间浪费,内存开销.但在每次请求的过程中,仍然会有大量静态资源的请求和返回. 使用静态页面技术,页面不必要使用页面交互技术,比如thymel ...

  7. Mybatis实例增删改查(二)

    创建实体类: package com.test.mybatis.bean; public class Employee { private Integer id; private String las ...

  8. C#知识点:抽象类和接口浅谈

    首先介绍什么是抽象类? 抽象类用关键字abstract修饰的类就是叫抽象类,抽象类天生的作用就是被继承的,所以不能实例化,只能被继承.而且 abstract 关键字不能和sealed一起使用,因为se ...

  9. 区块链Fabric 交易流程

    1. 提交交易预案 1)应用端首先构建交易的预案,预案的作用是调用通道中的链码来读取或者写入账本的数据.应用端使用 Fabric 的 SDK 打包交易预案,并使用用户的私钥对预案进行签名. 应用打包完 ...

  10. 白嫖党看番神器 AnimeSeacher

    - Anime Searcher - 简介 通过整合第三方网站的视频和弹幕资源, 给白嫖党提供最舒适的看番体验~ 目前有 4 个资源搜索引擎和 2 个弹幕搜索引擎, 资源丰富, 更新超快, 不用下载, ...