【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

参考目录:

本文主要讲述TF2.0的模型文件的存储和载入的多种方法。主要分成两类型:模型结构和参数一起载入,模型的结构载入。

1 模型的构建

import tensorflow.keras as keras

class CBR(keras.layers.Layer):
def __init__(self,output_dim):
super(CBR,self).__init__()
self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
self.bn = keras.layers.BatchNormalization(axis=3)
self.ReLU = keras.layers.ReLU() def call(self, inputs):
inputs = self.conv(inputs)
inputs = self.ReLU(self.bn(inputs))
return inputs class MyNet(keras.Model):
def __init__ (self):
super(MyNet,self).__init__()
self.cbr1 = CBR(16)
self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
self.cbr2 = CBR(32)
self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2)) def call(self, inputs):
inputs = self.maxpool1(self.cbr1(inputs))
inputs = self.maxpool2(self.cbr2(inputs))
return inputs model = MyNet()

部分朋友可以发现,上面的代码就是上一次课程所构建的一个自定义的网络。

我们现在需要展示这个模型的框架:

model.build((16,224,224,3))
print(model.summary())

运行结果为:

这里需要对网络执行一个构建.build()函数,之后才能生成model.summary()这样的模型的描述。 这是因为模型的参数量是需要知道输入数据的通道数的,假如我们输入的是单通道的图片,那么就是:

model.build((16,224,224,1))
print(model.summary())

输出结果为:

2 结构参数的存储与载入

model.save('save_model.h5')
new_model = keras.models.load_model('save_model.h5')

这里并不能保存成功,出现这样的错误:

大概的意思就是:因为你的模型不是官方的模型,是自定义的,所以并不能同时保存结构和参数。只有官方的模型可以时候上面的保存的方法,同时保存参数和权重;自定义的模型建议只保存参数

3 参数的存储与载入

model.save_weights('model_weight')
new_model = MyNet()
new_model.load_weights('model_weight')

这样子就可以保存自定义的模型了。在对应的目录下会出现这几个文件:

我们来看一下原来的模型和载入的模型对于同一个样本给出的结果是否相同:

# 看一下原来的模型和载入的模型预测相同的样本的输出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]

结果相同,载入的没有问题~

4 结构的存储与载入

结构的存储有两种方法:

  • model.get_config()
  • model.to_json()

需要注意的是,上面的两个方法和save的问题一样,是不能用在自定义的模型中的,如果你在其中使用了自定义的Layer类,那么只能!只能用save_weights的方式进行保存

下面依然给出这两种方法的代码,对于简单的、已经封装好的一些网络层构成的网络,是可以使用这些的。我个人还是常用save_weights啦

# 第一种方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二种方法
json_config = model.to_json()
# 把json写的文件中
with open('model_config.json', 'w') as json_file:
json_file.write(json_config)
# 读取本地json文件
with open('model_config.json') as json_file:
json_config = json_file.read()
reinitialized_model = keras.models.model_from_json(json_config)

今天的内容就是这么多,虽然提供了四种方法,但是对于自定义程度较高的模型,还是要使用save_weights哦~

【小白学PyTorch】19 TF2模型的存储与载入的更多相关文章

  1. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  2. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  3. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

  4. 【小白学PyTorch】18 TF2构建自定义模型

    [机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...

  5. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  6. 【小白学PyTorch】4 构建模型三要素与权重初始化

    文章目录: 目录 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 1 模型三要素 三要素其实很简单 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Mo ...

  7. 【小白学PyTorch】9 tensor数据结构与存储结构

    文章来自微信公众号[机器学习炼丹术]. 上一节课,讲解了MNIST图像分类的一个小实战,现在我们继续深入学习一下pytorch的一些有的没的的小知识来作为只是储备. 参考目录: @ 目录 1 pyto ...

  8. 【小白学PyTorch】16 TF2读取图片的方法

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...

  9. 小白学PyTorch 动态图与静态图的浅显理解

    文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...

随机推荐

  1. 注册github时总卡在第一步无法验证的解决办法

    从github官网可以看出问题所在,所以造成这一问题的极大可能就是浏览器的问题. 最简单的方法就是换手机浏览器进行注册

  2. [PyTorch 学习笔记] 3.3 池化层、线性层和激活函数层

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/nn_layers_others.py 这篇文章主要介绍 ...

  3. 如何使用zabbix监控公网环境的云服务器(从小白到高级技术顾问!!!)

    问题:当我们在本地部署了一台Zabbix服务器后,想要对云上的服务器做监控.但是zabbix一个在内网,云服务器一个在公网,网络环境不同该如何解决?能否检测到云服务器数据? 思路:使用NAT技术,将本 ...

  4. PAT甲级1151(由前序和中序确定LCA)

    The lowest common ancestor (LCA) of two nodes U and V in a tree is the deepest node that has both U ...

  5. 分布式文件存储:FastDFS简单使用与原理分析

    引言 FastDFS 属于分布式存储范畴,分布式文件系统 FastDFS 非常适合中小型项目,在我接手维护公司图片服务的时候开始接触到它,本篇文章目的是总结一下 FastDFS 的知识点. 用了 2 ...

  6. Go语言使用swagger生成接口文档

    swagger介绍 Swagger本质上是一种用于描述使用JSON表示的RESTful API的接口描述语言.Swagger与一组开源软件工具一起使用,以设计.构建.记录和使用RESTful Web服 ...

  7. 干Java这一行,应该怎样提升自己?

    前段时间,字节跳动在阿里巴巴的大本营杭州悄悄的建立一个研发中心,最近在疯狂招人. 相信最近一段时间,杭州的很多的互联网公司的开发人员都接到过猎头的电话.据了解,字节跳动杭州研发中心主要负责字节跳动新增 ...

  8. Mybatis-使用注解开发

    使用注解开发 [toc] 1. 面向接口编程 大家之前都学过面向对象编程,也学习过接口,但在真正的开发中,很多时候我们会选择面向接口编程 根本原因 : 解耦 , 可拓展 , 提高复用 , 分层开发中 ...

  9. oracle之dblink

    当用户要跨本地Oracle数据库,访问另外一个数据库表中的数据时,本地数据库中必须创建了远程数据库的dblink,通过dblink本地数据库可以像访问本地数据库一样访问远程数据库表中的数据.下面讲介绍 ...

  10. canvas学习作业,模仿做一个祖玛的小游戏

    这个游戏的原理我分为11个步骤,依次如下: 1.布局, 2.画曲线(曲线由两个半径不同的圆构成) 3.画曲线起点起始圆和曲线终点终止圆 4.起始的圆动起来, 5.起始的圆沿曲线走起来 6.起始的圆沿曲 ...