在keras中保存模型有几种方式:

(1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型

logdir = './callbacks'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir, "xxxx.h5")
callbacks = [
tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_file = True)
] hist = model.fit_generator(xxxxx, callbacks = callbacks)

(2): 使用model.save(),会把整个模型保存下来,包括网络和参数

(3): 使用model.save_weights(),只保存模型的参数

当使用自定义的层或loss时,只有(3)可以直接使用,1 2会报下面这种错:

NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
ValueError: Unknown loss function:loss
ValueError: Unknown layer: xxxlayer

解决办法:

在自定义网络层时重写get_config函数

我们主要看传入__init__接口时有哪些配置参数,然后在get_config内一一的将它们转为字典键值并且返回使用,以Mylayer为例:

class MyLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs, name="MyLayer", **kwargs):
super(MyLayer, self).__init__(name=name, **kwargs)
self.num_outputs = num_outputs def build(self, input_shape):
self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
super().build(input_shape) def call(self, input):
output = tf.matmul(input, self.kernel)
return output def get_config(self):
config = {"num_outputs":self.num_outputs}
base_config = super(Mylayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

一般来说,父类的config也是需要一并保存的,其中base_config即是父类网络层实现的配置参数,最后把父类及继承类的config组装为字典形式即可解决该问题

然后 在加载模型的时候,建立一个字典,该字典的键是自定义网络层时设定该层的名字,其值为该自定义网络层的类名,该字典将用于加载模型时使用

如果还使用了自定义的loss,则把loss也加到_custom_objects中

_custom_objects = {
"Mylayer" : Mylayer,
"loss" : Myloss
}

最后在load模型的时候把_custom_objects传入

model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)

keras中保存自定义层和loss的更多相关文章

  1. Keras处理已保存模型中的自定义层(或其他自定义对象)

    如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制: from keras.models import load_model # 假设 ...

  2. Keras中使用LSTM层时设置的units参数是什么

    https://www.zhihu.com/question/64470274 http://colah.github.io/posts/2015-08-Understanding-LSTMs/ ht ...

  3. OC中保存自定义类型对象的持久化方法

    OC中如果要将自定义类型的对象保存到文件中,必须进行以下三个条件: 想要把存放自定义类型的数组进行 持久化(就是将内存中的临时数据以文件<数据库等>的形式写到磁盘上)必须满足: 1. 自定 ...

  4. keras 中如何自定义损失函数

    http://lazycoderx.com/2016/10/09/keras%E4%BF%9D%E5%AD%98%E6%A8%A1%E5%9E%8B%E6%97%B6%E4%BD%BF%E7%94%A ...

  5. keras中自定义Layer

    最近在学习SSD的源码,其中有两个自定的层,特此学习一下并记录. import keras.backend as K from keras.engine.topology import InputSp ...

  6. 为何Keras中的CNN是有问题的,如何修复它们?

    在训练了 50 个 epoch 之后,本文作者惊讶地发现模型什么都没学到,于是开始深挖背后的问题,并最终从恺明大神论文中得到的知识解决了问题. 上个星期我做了一些实验,用了在 CIFAR10 数据集上 ...

  7. keras中的loss、optimizer、metrics

    用keras搭好模型架构之后的下一步,就是执行编译操作.在编译时,经常需要指定三个参数 loss optimizer metrics 这三个参数有两类选择: 使用字符串 使用标识符,如keras.lo ...

  8. keras中的模型保存和加载

    tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...

  9. Keras 自定义层

    1.对于简单的定制操作,可以通过使用layers.core.Lambda层来完成.该方法的适用情况:仅对流经该层的数据做个变换,而这个变换本身没有需要学习的参数. # 切片后再分别进行embeddin ...

随机推荐

  1. C# Stream篇(五) -- MemoryStream

    MemoryStream 目录: 1 简单介绍一下MemoryStream 2 MemoryStream和FileStream的区别 3 通过部分源码深入了解下MemoryStream 4 分析Mem ...

  2. P2057 [SHOI2007]善意的投票 / [JLOI2010]冠军调查

    P2057 [SHOI2007]善意的投票 / [JLOI2010]冠军调查 拿来练网络流的qwq 思路:如果i不同意,连边(i,t,1),否则连边(s,i,1).好朋友x,y间连边(x,y,1)(y ...

  3. 自定义 radio 的样式,更改选中样式

      思路: 1. 可以为<label>元素添加生成性内容(伪元素),并基于单选按钮的状态来为其设置样式: 2. 然后把真正的单选按钮隐藏起来: 3. 最后把生成内容美化一下. 解决方法: ...

  4. 用Visual studio11在Windows8上开发驱动实现注册表监控和过滤

    在Windows NT中,80386保护模式的“保护”比Windows 95中更坚固,这个“镀金的笼子”更加结实,更加难以打破.在Windows 95中,至少应用程序I/O操作是不受限制的,而在Win ...

  5. C++编程学习(四)声明/枚举

    一.typedef 声明 typedef 为一个已有的类型取一个新的名字 typedef int num;//feet是int的另一个名字num a;//a是int类型 二.枚举类型 enum col ...

  6. POJ 2226 缩点建图+二分图最大匹配

    这个最小覆盖但不同于 POJ 3041,只有横或者竖方向连通的点能用一块板子覆盖,非连续的,就要用多块 所以用类似并查集方法,分别横向与竖向缩点,有交集的地方就连通,再走一遍最大匹配即可 一开始还有点 ...

  7. Python中的抽象基类

    1.说在前头 "抽象基类"这个词可能听着比较"深奥",其实"基类"就是"父类","抽象"就是&quo ...

  8. AT2000 Leftmost Ball 解题报告

    题面 给你n种颜色的球,每个球有k个,把这n*k个球排成一排,把每一种颜色的最左边出现的球涂成白色(初始球不包含白色),求有多少种不同的颜色序列,答案对1e9+7取模 解法 设\(f(i,\;j)\) ...

  9. 7. react 基础 - React Developer Tools 的安装 及 使用

    1. 安装 react 开发调试工具 React Developer Tools 打开 chrome 浏览器访问 chrome://extensions/ 点击右上角的 拓展程序 -> 打开 c ...

  10. 69.ORM查询条件:isnull和regex的使用

    首先查看数据库中的article表的数据: 定义模型的文件models.py中的示例代码如下: from django.db import models class Category(models.M ...