自定义tf.keras.Model需要注意的点

model.save()

  • subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_format="tf"
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.

model.trainable_variables

  • __init__若没有注册该layers,那么在后面应用梯度时会找不到model.trainable_variables。

    像下面这样是不行的:
class Map_model(tf.keras.Model):
def __init__(self, is_train=False):
super(Map_model, self).__init__()
def call(self, x):
x = tf.keras.layers.Dense(10, activation='relu')
return x

model.summary()

  • 需要先指定input_shape,或者你直接fit一遍它也能自动确定
    model.build(input_shape=(None, 448, 448, 3))
print(model.summary())
class Map_model(tf.keras.Model):
def __init__(self, is_train=False):
super(Map_model, self).__init__()
self.map_f1 = tf.keras.layers.Dense(10, activation='relu', trainable=is_train)
# self.map_f2 = tf.keras.layers.Dense(6, activation='relu')
self.map_f3 = tf.keras.layers.Dense(3, activation='softmax', trainable=is_train) def call(self, x):
x = self.map_f1(x)
# x = self.map_f2(x)
return self.map_f3(x) @tf.function
def train_step(mmodel, label, L_label, loss_object, train_loss, train_accuracy, optimizer):
with tf.GradientTape() as tape:
L_label_pred = mmodel(label)
loss = loss_object(L_label, L_label_pred)
gradient_l = tape.gradient(loss, mmodel.trainable_variables)
train_loss(loss)
train_accuracy(L_label, L_label_pred)
optimizer.apply_gradients(zip(gradient_l, mmodel.trainable_variables)) def train():
mmodel = Map_model(is_train=True)
optimizer = tf.keras.optimizers.Adam(0.01)
loss_object = tf.keras.losses.CategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy') EPOCHS = 0
labels = range(1, 30) # labels = truth_label -1
L_labels = [int(prpcs.map2Lclass(l)) for l in labels]
labels = [l - 1 for l in labels]
labels_onehot = tf.one_hot(labels, depth=29)
L_labels_onehot = tf.one_hot(L_labels, depth=3)
EPS = 1e-6
loss_e = 0x7f7f7f
while loss_e > EPS:
EPOCHS += 1
train_loss.reset_states()
train_accuracy.reset_states()
train_step(mmodel, labels_onehot, L_labels_onehot, loss_object, train_loss, train_accuracy, optimizer) template = 'Epoch {}, Loss: {}, Accuracy: {}'
print(template.format(EPOCHS,
train_loss.result(),
train_accuracy.result() * 100))
loss_e = train_loss.result()
print("labels_onehot shape:", labels_onehot.shape)
model_path = r'./models/'
if not os.path.exists(model_path):
os.makedirs(model_path)
mmodel.save(os.path.join(model_path, 'map_model_{}'.format(EPS)))
mmodel.save_weights(os.path.join(model_path, 'map_model_weights_{}'.format(EPS)))
print("Save model!")

tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑的更多相关文章

  1. tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`

    经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...

  2. [TensorFlow 2.0] Keras 简介

    Keras 是一个用于构建和训练深度学习模型的高阶 API.它可用于快速设计原型.高级研究和生产. keras的3个优点: 方便用户使用.模块化和可组合.易于扩展 简单点说就是,简单.好用.快(构建) ...

  3. 三分钟快速上手TensorFlow 2.0 (上)——前置基础、模型建立与可视化

    本文学习笔记参照来源:https://tf.wiki/zh/basic/basic.html 学习笔记类似提纲,具体细节参照上文链接 一些前置的基础 随机数 tf.random uniform(sha ...

  4. python 3.7 安装 sklearn keras(tf.keras)

    # 1   sklearn  一般方法 网上有很多教程,不再赘述. 注意顺序是 numpy+mkl     ,然后 scipy的环境,scipy,然后 sklearn # 2 anoconda ana ...

  5. 【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化

    TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了.tf.keras 从 1.x 版本迁移到 2.0 版本,需要修改几个地方. ...

  6. 一文上手Tensorflow2.0之tf.keras(三)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  7. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  8. 【tf.keras】tf.keras使用tensorflow中定义的optimizer

    Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...

  9. 【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc

    运行以下类似代码: while True: inputs, outputs = get_AlexNet() model = tf.keras.Model(inputs=inputs, outputs= ...

随机推荐

  1. Apache工作流程

    一个经典的Apache处理php页面的流程 需要连接mysql数据库并处理的流程 网站是一系列网页的组合 从用户角度看就是访问诸如 hhtp://www.baidu.com -----url 这是互联 ...

  2. Python网络数据采集(1):博客访问量统计

    前言 Python中能够爬虫的包还有很多,但requests号称是“让HTTP服务人类”...口气不小,但的确也很好用. 本文是博客里爬虫的第一篇,实现一个很简单的功能:获取自己博客主页里的访问量. ...

  3. 当心JavaScript奇葩的逗号表达式

    看看下面的代码输出什么? let a = 2; switch (a) { case (3, 2, 5): console.log(1); break case (2, 3, 4): console.l ...

  4. 2019春Python程序设计练习7(0430--0506)

    1-1 对文件进行读写操作之后必须显式关闭文件以确保所有内容都得到保存. (2分) T         F 1-2 以追加模式打开文件时,文件指针指向文件尾.(2分) T         F 1-3 ...

  5. HTML中的表格和图像总结

    ㈠表格 ⑴表格的基本结构 ①表格的基本标签有:table标签(表格),tr标签(行),td标签(单元格).<tr>标签和<td>标签都要在表格的开始标签<table> ...

  6. 大视频上传T级别解决方案

    核心原理: 该项目核心就是文件分块上传.前后端要高度配合,需要双方约定好一些数据,才能完成大文件分块,我们在项目中要重点解决的以下问题. * 如何分片: * 如何合成一个文件: * 中断了从哪个分片开 ...

  7. BZOJ 3940 Censoring ( Trie 图 )

    题目链接 题意 : 中文题.点链接 分析 : 直接建 Trie 图.在每一个串的末尾节点记录其整串长度.方便删串操作 然后对于问询串.由于可能有删串操作 所以在跑 Trie 图的过程当中需要拿个栈记录 ...

  8. MongoDB下载以及安装

    一.下载与安装 1.安装Mongo MongoDB下载地址:https://www.mongodb.com/download-center?jmp=tutorials#community 运行安装程序 ...

  9. PX4学习之-uORB msg 自动生成模板解读

    最后更新日期 2019-06-22 一.前言 在 PX4学习之-uORB简单体验 中指出, 使用 uORB 进行通信的第一步是新建 msg.在实际编译过程中,新建的 msg 会转换成对应的 .h..c ...

  10. JS框架_(JQuery.js)纯css3进度条动画

    百度云盘 传送门 密码:wirc 进度条动画效果: <!DOCTYPE html> <html lang="zh"> <head> <me ...