keras中保存自定义层和loss
在keras中保存模型有几种方式:
(1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型
logdir = './callbacks'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir, "xxxx.h5")
callbacks = [
tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_file = True)
]
hist = model.fit_generator(xxxxx, callbacks = callbacks)
(2): 使用model.save(),会把整个模型保存下来,包括网络和参数
(3): 使用model.save_weights(),只保存模型的参数
当使用自定义的层或loss时,只有(3)可以直接使用,1 2会报下面这种错:
NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
ValueError: Unknown loss function:loss
ValueError: Unknown layer: xxxlayer
解决办法:
在自定义网络层时重写get_config函数
我们主要看传入__init__接口时有哪些配置参数,然后在get_config内一一的将它们转为字典键值并且返回使用,以Mylayer为例:
class MyLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs, name="MyLayer", **kwargs):
super(MyLayer, self).__init__(name=name, **kwargs)
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
super().build(input_shape)
def call(self, input):
output = tf.matmul(input, self.kernel)
return output
def get_config(self):
config = {"num_outputs":self.num_outputs}
base_config = super(Mylayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
一般来说,父类的config也是需要一并保存的,其中base_config即是父类网络层实现的配置参数,最后把父类及继承类的config组装为字典形式即可解决该问题
然后 在加载模型的时候,建立一个字典,该字典的键是自定义网络层时设定该层的名字,其值为该自定义网络层的类名,该字典将用于加载模型时使用
如果还使用了自定义的loss,则把loss也加到_custom_objects中
_custom_objects = {
"Mylayer" : Mylayer,
"loss" : Myloss
}
最后在load模型的时候把_custom_objects传入
model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)
keras中保存自定义层和loss的更多相关文章
- Keras处理已保存模型中的自定义层(或其他自定义对象)
如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制: from keras.models import load_model # 假设 ...
- Keras中使用LSTM层时设置的units参数是什么
https://www.zhihu.com/question/64470274 http://colah.github.io/posts/2015-08-Understanding-LSTMs/ ht ...
- OC中保存自定义类型对象的持久化方法
OC中如果要将自定义类型的对象保存到文件中,必须进行以下三个条件: 想要把存放自定义类型的数组进行 持久化(就是将内存中的临时数据以文件<数据库等>的形式写到磁盘上)必须满足: 1. 自定 ...
- keras 中如何自定义损失函数
http://lazycoderx.com/2016/10/09/keras%E4%BF%9D%E5%AD%98%E6%A8%A1%E5%9E%8B%E6%97%B6%E4%BD%BF%E7%94%A ...
- keras中自定义Layer
最近在学习SSD的源码,其中有两个自定的层,特此学习一下并记录. import keras.backend as K from keras.engine.topology import InputSp ...
- 为何Keras中的CNN是有问题的,如何修复它们?
在训练了 50 个 epoch 之后,本文作者惊讶地发现模型什么都没学到,于是开始深挖背后的问题,并最终从恺明大神论文中得到的知识解决了问题. 上个星期我做了一些实验,用了在 CIFAR10 数据集上 ...
- keras中的loss、optimizer、metrics
用keras搭好模型架构之后的下一步,就是执行编译操作.在编译时,经常需要指定三个参数 loss optimizer metrics 这三个参数有两类选择: 使用字符串 使用标识符,如keras.lo ...
- keras中的模型保存和加载
tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...
- Keras 自定义层
1.对于简单的定制操作,可以通过使用layers.core.Lambda层来完成.该方法的适用情况:仅对流经该层的数据做个变换,而这个变换本身没有需要学习的参数. # 切片后再分别进行embeddin ...
随机推荐
- 12 —— node 获取文件属性 —— 利用 自调用 闭包函数 解决 i 丢失的问题
闭包的作用 : 保存变量 一,i 丢失的案例 var arr = ['node','vue','mysql'] for(var i=0;i<arr.length;i++){ setTimeout ...
- tableau-参数
tableau参数可用在计算字段.筛选器和参考线中替换常量值得动态值. 三种方式:1.在计算字段中使用筛选器 案例动态替换计算字段中设定的目标值. 创建参数 以参数值创建计算字段 2.筛选器中使用参数 ...
- 【LeetCode】两个数相加
[问题]给定两个非空链表来表示两个非负整数.位数按照逆序方式存储,它们的每个节点只存储单个数字.将两数相加返回一个新的链表. 你可以假设除了数字 0 之外,这两个数字都不会以零开头. [实例] 输入: ...
- bzoj 3732Network
先搞个最小生成树,然后lca(和之前的一个cf题差不多2333, 纯属颓废了..) 顺便思考了一下正确性. 因为所求的是所有路径中最大边的最小值.而kruskal每次往里添加的就是最小边.所以在生成树 ...
- 通过整合遥感数据和社交媒体数据来进行城市土地利用的分类( Classifying urban land use by integrating remote sensing and social media data)DOI: 10.1080/13658816.2017.1324976 20.0204
Classifying urban land use by integrating remote sensing and social media data Xiaoping Liu, Jialv ...
- Java enum应用小结
用法一:常量 在JDK1.5 之前,我们定义常量都是: public static fianl.... .现在好了,有了枚举,可以把相关的常量分组到一个枚举类型里,而且枚举提供了比常量更多的方法. p ...
- UVA - 10285 Longest Run on a Snowboard(最长的滑雪路径)(dp---记忆化搜索)
题意:在一个R*C(R, C<=100)的整数矩阵上找一条高度严格递减的最长路.起点任意,但每次只能沿着上下左右4个方向之一走一格,并且不能走出矩阵外.矩阵中的数均为0~100. 分析:dp[x ...
- php 常用编译参数
安装依赖 yum install -y gcc gcc-c++ make zlib zlib-devel pcre pcre-devel libjpeg libjpeg-devel libpng li ...
- (转)绝对路径${pageContext.request.contextPath}用法及其与web.xml中Servlet的url-pattern匹配过程
以系统的一个“添加商品”的功能为例加以说明,系统页面为add.jsp,如图一所示: 图一 添加商品界面 系统的代码目录结构及add.jsp代码如图二所示: 图二 系统的代码目录结构及add.js ...
- 201609-2 火车购票 Java
思路待补充 import java.util.Scanner; class Main{ public static void main(String[] args) { //100个座位 int[] ...