[Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model
import numpy as np import matplotlib.pyplot as plt import os import time import tensorflow as tf tf.enable_eager_execution() # create data X = np.linspace(-1, 1, 5000) np.random.shuffle(X) y = 0.5 * X + 2 + np.random.normal(0, 0.05, (5000,)) # plot data plt.scatter(X, y) plt.show() # split data X_train, y_train = X[:4000], y[:4000] X_test, y_test = X[4000:], y[4000:] # tf.data BATCH_SIZE = 32 BUFFER_SIZE = 512 dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(BATCH_SIZE).shuffle(BUFFER_SIZE) # subclassed model UNITS = 1 class Model(tf.keras.Model): def __init__(self): super(Model, self).__init__() self.fc = tf.keras.layers.Dense(units=UNITS) def call(self, inputs): return self.fc(inputs) model = Model() optimizer = tf.train.AdamOptimizer() # loss function def loss_function(real, pred): return tf.losses.mean_squared_error(labels=real, predictions=pred) EPOCHS = 30 checkpoint_dir = './save_subclassed_keras_model_training_checkpoints' if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) # training loop for epoch in range(EPOCHS): start = time.time() epoch_loss = 0 for (batch, (x, y)) in enumerate(dataset): x = tf.cast(x, tf.float32) y = tf.cast(y, tf.float32) x = tf.expand_dims(x, axis=1) y = tf.expand_dims(y, axis=1) # print(x) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32) # print(y) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32) with tf.GradientTape() as tape: predictions = model(x) # print(predictions) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32) batch_loss = loss_function(real=y, pred=predictions) grads = tape.gradient(batch_loss, model.variables) optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step()) epoch_loss += batch_loss if (batch + 1) % 10 == 0: print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch + 1, batch_loss/int(x.shape[0]))) print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss/len(X_train))) print('Time taken for 1 epoch {} sec\n'.format(time.time() - start)) # save checkpoint checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') if (epoch + 1) % 10 == 0: model.save_weights(checkpoint_prefix.format(epoch=epoch), overwrite=True) _model = Model() _model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)) _model.build(input_shape=tf.TensorShape([BATCH_SIZE, 1])) _model.summary() test_dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(1) for (batch, x) in enumerate(test_dataset): x = tf.cast(x, tf.float32) x = tf.expand_dims(x, axis=1) print(x) predictions = _model(x) print(predictions) exit()
[Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model的更多相关文章
- [Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model
在 parameters.py 中,定义了各类参数. # training data directory TRAINING_DATA_DIR = './data/' # checkpoint dire ...
- [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model
在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...
- keras系列︱Sequential与Model模型、keras基本结构功能(一)
引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...
- Keras(一)Sequential与Model模型、Keras基本结构功能
keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...
- [Model] LeNet-5 by Keras
典型的卷积神经网络. 数据的预处理 Keras傻瓜式读取数据:自动下载,自动解压,自动加载. # X_train: array([[[[ 0., 0., 0., ..., 0., 0., 0.], [ ...
- xen 保存快照的实现之 —— device model 状态保存
xen 保存快照的实现之 —— device model 状态保存 实现要点: 设备状态保存在 /var/lib/xen/qemu-save.x 文件这个文件由 qemu-dm 产生,也由 qemu- ...
- AI - TensorFlow - 示例05:保存和恢复模型
保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...
- 如何保存Keras模型
我们不推荐使用pickle或cPickle来保存Keras模型 你可以使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含: 模型的结构,以便重构该 ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
随机推荐
- JDBC基本应用
首先我们来看一下JDBC操作数据的核心: Connection 应用程序与数据库之间的桥梁 数据库驱动程序是构建桥梁的基石和材料 DriverManager类是基石和材料的管理员 Statement ...
- Skype for Business七大新功能
Lync Server 2013的下一版本号.Skype for Business将于2015年4月正式公布,下面是七大新功能. "呼叫监听"(Call Monitor)--假设你 ...
- HDU 5754Life Winner Bo
Life Winner Bo Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others) ...
- HDU 1379:DNA Sorting
DNA Sorting Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others) Tota ...
- [CF348B]Apple Tree
https://www.zybuluo.com/ysner/note/1300249 题面 给一棵大小为\(n\)的树,树的每个叶子节点上有权值. 定义一颗树平衡:对于每一个结点\(u\)的子树都拥有 ...
- 基于CentOS7.5的 Rsync 服务详解
第1章 Rsync概述 1.1 Rsync基本概述 rsync是一款开源的备份工具,可以在不同服务器(主机)之间进行同步备份, 可实现完全备份与增量备份,因此非常适合用于架构集中式备份或异地备份等应用 ...
- E20171230-hm
refine vt. 提炼; 改善; 使高雅; revert vi. 恢复; 重提; 回到…上; <律>归还; vt. 使恢复原状; n ...
- bzoj 1710: [Usaco2007 Open]Cheappal 廉价回文【区间dp】
只要发现添加一个字符和删除一个字符是等价的,就是挺裸的区间dp了 因为在当前位置加上一个字符x就相当于在他的对称位置删掉字符x,所以只要考虑删除即可,删除费用是添加和删除取min 设f[i][j]为从 ...
- Java并发编程系列之CyclicBarrier详解
简介 jdk原文 A synchronization aid that allows a set of threads to all wait for each other to reach a co ...
- 【js】再谈移动端的模态框实现
移动端模态框的机制因为与PC的模态框机制一直有所区别,一直是许多新人很容易踩坑的地方,最近笔者作为一条老咸鱼也踩进了一个新坑中,真是平日里代码读得太粗略,故而写上几笔,以儆效尤. 故事的起因是这样的, ...