[Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import tensorflow as tf
tf.enable_eager_execution()
# create data
X = np.linspace(-1, 1, 5000)
np.random.shuffle(X)
y = 0.5 * X + 2 + np.random.normal(0, 0.05, (5000,))
# plot data
plt.scatter(X, y)
plt.show()
# split data
X_train, y_train = X[:4000], y[:4000]
X_test, y_test = X[4000:], y[4000:]
# tf.data
BATCH_SIZE = 32
BUFFER_SIZE = 512
dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(BATCH_SIZE).shuffle(BUFFER_SIZE)
# subclassed model
UNITS = 1
class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.fc = tf.keras.layers.Dense(units=UNITS)
def call(self, inputs):
return self.fc(inputs)
model = Model()
optimizer = tf.train.AdamOptimizer()
# loss function
def loss_function(real, pred):
return tf.losses.mean_squared_error(labels=real, predictions=pred)
EPOCHS = 30
checkpoint_dir = './save_subclassed_keras_model_training_checkpoints'
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
# training loop
for epoch in range(EPOCHS):
start = time.time()
epoch_loss = 0
for (batch, (x, y)) in enumerate(dataset):
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)
x = tf.expand_dims(x, axis=1)
y = tf.expand_dims(y, axis=1)
# print(x) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
# print(y) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
with tf.GradientTape() as tape:
predictions = model(x)
# print(predictions) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
batch_loss = loss_function(real=y, pred=predictions)
grads = tape.gradient(batch_loss, model.variables)
optimizer.apply_gradients(zip(grads, model.variables),
global_step=tf.train.get_or_create_global_step())
epoch_loss += batch_loss
if (batch + 1) % 10 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch + 1,
batch_loss/int(x.shape[0])))
print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss/len(X_train)))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
# save checkpoint
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
if (epoch + 1) % 10 == 0:
model.save_weights(checkpoint_prefix.format(epoch=epoch), overwrite=True)
_model = Model()
_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
_model.build(input_shape=tf.TensorShape([BATCH_SIZE, 1]))
_model.summary()
test_dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(1)
for (batch, x) in enumerate(test_dataset):
x = tf.cast(x, tf.float32)
x = tf.expand_dims(x, axis=1)
print(x)
predictions = _model(x)
print(predictions)
exit()
[Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model的更多相关文章
- [Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model
在 parameters.py 中,定义了各类参数. # training data directory TRAINING_DATA_DIR = './data/' # checkpoint dire ...
- [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model
在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...
- keras系列︱Sequential与Model模型、keras基本结构功能(一)
引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...
- Keras(一)Sequential与Model模型、Keras基本结构功能
keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...
- [Model] LeNet-5 by Keras
典型的卷积神经网络. 数据的预处理 Keras傻瓜式读取数据:自动下载,自动解压,自动加载. # X_train: array([[[[ 0., 0., 0., ..., 0., 0., 0.], [ ...
- xen 保存快照的实现之 —— device model 状态保存
xen 保存快照的实现之 —— device model 状态保存 实现要点: 设备状态保存在 /var/lib/xen/qemu-save.x 文件这个文件由 qemu-dm 产生,也由 qemu- ...
- AI - TensorFlow - 示例05:保存和恢复模型
保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...
- 如何保存Keras模型
我们不推荐使用pickle或cPickle来保存Keras模型 你可以使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含: 模型的结构,以便重构该 ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
随机推荐
- vuex资料
vuex最简单.最详细的入门文档 链接:https://segmentfault.com/a/1190000009404727 https://www.jb51.net/article/138239. ...
- 工作总结 default Console.WriteLine(default(Guid));
泛型代码中的默认关键字 在泛型类和泛型方法中产生的一个问题是,在预先未知以下情况时,如何将默认值分配给参数化类型 T: T 是引用类型还是值类型. 如果 T 为值类型,则它是数值还是结构. 给定参数化 ...
- Rational 最新软件试用下载地址
看到非常多 TX 都在问老版本号 Raitonal 软件相关的问题,可是因为产品升级的时候有非常多名字都发生了更改(比方说 Rational Rose 最新的版本号变成了 Rational Softw ...
- 第一章:Android系统介绍android虚拟机
学习android,我们是要了解他的历史的,这里我也就不在累述什么大家都知道的东东了,简单的介绍下内部的相关内容: 1:android虚拟机 我们学习java知道java用的是JVM虚拟机,而开发An ...
- Iteye已经沦陷
watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvZHl5YXJpZXM=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA ...
- ArcGIS中生成蜂窝多边形算法解析
近来有不少同学.都有问我关于蜂窝多边形的问题.也就是正六边形,也就是以下这个东东: 一般的问答模式例如以下: 亲们问:ArcGIS里面那个工具能够做这个东东? 虾神答:额,没有原生的工具. 亲们问:那 ...
- CSP 201612-3 权限查询 【模拟+STL】
201612-3 试题名称: 权限查询 时间限制: 1.0s 内存限制: 256.0MB 问题描述: 问题描述 授权 (authorization) 是各类业务系统不可缺少的组成部分,系统用户通过授权 ...
- Spring 之AOP 面向切面编程
AOP相关术语: Joinpoint (连接点):所谓连接点是指那些被拦截到的点,在spring中,这些点指的是方法,因为spring 只支持方法类型的连接点. Pointcut(切入点):所谓切入点 ...
- Ubuntu 16.04下安装MacBuntu 16.04 TP 变身Mac OS X主题风格
Ubuntu 16.04下安装MacBuntu 16.04 TP 变身Mac OS X主题风格 sudo add-apt-repository ppa:noobslab/macbuntu sudo a ...
- ORACLE获取某个时间段之间的月份列表和日期列表
ORACLE获取某个时间段之间的月份列表获取某个时间段之间的月份列表(示例返回2009-03到2010-03之间的月份列表) SELECT TO_CHAR(ADD_MONTHS(TO_DATE('20 ...