[Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model
在 parameters.py 中,定义了各类参数。
# training data directory
TRAINING_DATA_DIR = './data/' # checkpoint directory
CHECKPOINT_DIR = './training_checkpoints/' # training details
BATCH_SIZE = 16
BUFFER_SIZE = 128
EPOCHS = 15
在 numpy_dataset.py 中,创建了 5000 组训练数据集,模拟 y = x^3 + 1,并二进制格式写入文件。
from parameters import TRAINING_DATA_DIR import numpy as np
import matplotlib.pyplot as plt
import os # create training data
X = np.linspace(-1, 1, 5000)
np.random.shuffle(X)
y = X ** 3 + 1 + np.random.normal(0, 0.01, (5000,)) # plot training data
plt.scatter(X, y)
plt.show() # save data
if not os.path.exists(TRAINING_DATA_DIR):
os.makedirs(TRAINING_DATA_DIR) X.tofile(os.path.join(TRAINING_DATA_DIR + 'training_data_X.bin'))
y.tofile(os.path.join(TRAINING_DATA_DIR + 'training_data_y.bin'))
在 subclassed_model.py 中,通过对 tf.keras.models.Model 进行子类化,设计了两个自定义模型。
import tensorflow as tf
tf.enable_eager_execution() # model definition
class Encoder(tf.keras.models.Model):
def __init__(self):
super(Encoder, self).__init__()
self.fc1 = tf.keras.layers.Dense(units=16, activation='relu')
self.fc2 = tf.keras.layers.Dense(units=8, activation='relu') def call(self, inputs):
r = self.fc1(inputs)
return self.fc2(r) class Decoder(tf.keras.models.Model):
def __init__(self):
super(Decoder, self).__init__()
self.fc = tf.keras.layers.Dense(units=1, activation=None) def call(self, inputs):
return self.fc(inputs)
在 loss_function.py 中,定义了损失函数。
import tensorflow as tf
tf.enable_eager_execution() def loss(real, pred):
return tf.losses.mean_squared_error(labels=real, predictions=pred)
在 training.py 中,使用在 numpy_dataset.py 中创建的数据集训练模型,之后使用 model.save_weights() 保存 Keras Subclassed Model 模型,并创建验证集验证模型。
from parameters import TRAINING_DATA_DIR, CHECKPOINT_DIR, BATCH_SIZE, BUFFER_SIZE, EPOCHS
from subclassed_model import *
from loss_function import loss import os
import numpy as np
import matplotlib.pyplot as plt # load training data
training_X = np.fromfile(os.path.join(TRAINING_DATA_DIR, 'training_data_X.bin'), dtype=np.float64)
training_y = np.fromfile(os.path.join(TRAINING_DATA_DIR, 'training_data_y.bin'), dtype=np.float64) # plot training data
plt.scatter(training_X, training_y)
plt.show() # training dataset
training_dataset = tf.data.Dataset.from_tensor_slices((training_X, training_y)).batch(BATCH_SIZE).shuffle(BUFFER_SIZE) # model instance
encoder = Encoder()
decoder = Decoder() # optimizer
optimizer = tf.train.AdamOptimizer() # checkpoint
checkpoint_prefix_encoder = os.path.join(CHECKPOINT_DIR, 'encoder/', 'ckpt')
checkpoint_prefix_decoder = os.path.join(CHECKPOINT_DIR, 'decoder/', 'ckpt') if not os.path.exists(os.path.dirname(checkpoint_prefix_encoder)):
os.makedirs(os.path.dirname(checkpoint_prefix_encoder))
if not os.path.exists(os.path.dirname(checkpoint_prefix_decoder)):
os.makedirs(os.path.dirname(checkpoint_prefix_decoder)) # training step
for epoch in range(EPOCHS):
epoch_loss = 0 for (batch, (tx, ty)) in enumerate(training_dataset):
x = tf.cast(tx, tf.float32)
y = tf.cast(ty, tf.float32)
x = tf.expand_dims(x, axis=1) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
y = tf.expand_dims(y, axis=1) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32) with tf.GradientTape() as tape:
y_ = encoder(x) # tf.Tensor([...], shape=(BATCH_SIZE, 8), dtype=float32)
prediction = decoder(y_) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
batch_loss = loss(real=y, pred=prediction) variables = encoder.variables + decoder.variables
grads = tape.gradient(batch_loss, variables)
optimizer.apply_gradients(zip(grads, variables), global_step=tf.train.get_or_create_global_step()) epoch_loss += batch_loss if (batch + 1) % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch + 1,
batch_loss.numpy())) print('Epoch {} Loss {:.4f}'.format(epoch + 1,
epoch_loss / len(training_X))) if (epoch + 1) % 5 == 0:
encoder.save_weights(checkpoint_prefix_encoder)
decoder.save_weights(checkpoint_prefix_decoder) # create evaluation data
X = np.linspace(-1, 1, 3000)
np.random.shuffle(X) evaluation_X = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
ey = [] for (batch, ex) in enumerate(evaluation_X):
x = tf.cast(ex, tf.float32)
x = tf.expand_dims(x, axis=1)
prediction = decoder(encoder(x))
for i in range(len(prediction.numpy())):
ey.append(prediction.numpy()[i]) plt.scatter(X, ey)
plt.show() # evaluate
eval_x = [[0.5]]
tensor_x = tf.convert_to_tensor(eval_x)
print(decoder(encoder(tensor_x)))
验证集评价结果如下图所示。
使用测试样例 eval_x 进行测试,测试结果如下。
tf.Tensor([[1.122567]], shape=(1, 1), dtype=float32)
在 evaluate.py 中,使用 model.load_weights() 恢复 Keras Subclassed Model 模型,并在验证集上进行验证,验证结果如下图所示。
from parameters import CHECKPOINT_DIR, BATCH_SIZE
from subclassed_model import * import os
import numpy as np
import matplotlib.pyplot as plt # load model
enc = Encoder()
dec = Decoder() enc.load_weights(tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, 'encoder/')))
dec.load_weights(tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, 'decoder/'))) # create evaluation data
X = np.linspace(-1, 1, 3000)
np.random.shuffle(X) evaluation_X = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
ey = [] for (batch, ex) in enumerate(evaluation_X):
x = tf.cast(ex, tf.float32)
x = tf.expand_dims(x, axis=1)
prediction = dec(enc(x))
for i in range(len(prediction.numpy())):
ey.append(prediction.numpy()[i]) plt.scatter(X, ey)
plt.show() # evaluate
eval_x = [[0.5]]
tensor_x = tf.convert_to_tensor(eval_x)
print(dec(enc(tensor_x))) # model summary
enc.summary()
dec.summary()
使用测试样例 eval_x 进行测试,测试结果如下。
tf.Tensor([[1.122567]], shape=(1, 1), dtype=float32)
恢复模型的测试结果,与训练后模型的测试结果一致,且无需 build 模型。
版权声明:本文为博主原创文章,欢迎转载,转载请注明作者及原文出处!
[Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model的更多相关文章
- [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model
在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...
- Tensorflow 模型持久化saver及加载图结构
主要内容: 1. 直接保存,加载模型; (可以指定加载,保存的var_list) 2. 加载,保存指定变量的模型 3. slim加载模型使用 4. 加载模型图结构和参数等 tensorflow 恢复部 ...
- docker 保存 加载(导入 导出镜像
tensorflow 的docker镜像很大,pull一次由于墙经常失败.其实docker 可以将镜像导出再导入. 保存加载(tensorflow)镜像 1) 查看镜像 docker images 如 ...
- gensim Word2Vec 训练和使用(Model一定要加载到内存中,节省时间!!!)
训练模型利用gensim.models.Word2Vec(sentences)建立词向量模型该构造函数执行了三个步骤:建立一个空的模型对象,遍历一次语料库建立词典,第二次遍历语料库建立神经网络模型可以 ...
- 优化tableView加载cell与model的过程
优化tableView加载cell与model的过程 效果图 说明 1. 用多态的特性来优化tableView加载cell与model的过程 2. swift写起来果然要比Objective-C简洁了 ...
- [Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model
import numpy as np import matplotlib.pyplot as plt import os import time import tensorflow as tf tf. ...
- Tensorflow 2.0 datasets数据加载
导入包 import tensorflow as tf from tensorflow import keras 加载数据 tensorflow可以调用keras自带的datasets,很方便,就是有 ...
- C# DataGridVie利用model特性动态加载列
今天闲来无事看到ORm的特性映射sql语句.我就想到datagridview也可以用这个来动态添加列.这样就不用每次都去界面上点开界面填列了. 代码简漏希望有人看到了能指点一二. 先定义好Datagr ...
- docker 保存,加载,导入,导出 命令
持久化docker的镜像或容器的方法 docker的镜像和容器可以有两种方式来导出 docker save #ID or #Name docker export #ID or #Name docker ...
随机推荐
- linux实操_rpm包和yum包
rpm包的简单查询指令: 查询已安装的rpm列表 rpm -qa | grep xxx 查询火狐浏览器 查询安装的rpm包软件的信息 查询rpm软件包的文件安装在哪里 查询文件属于哪个软件包 卸载rp ...
- WCF Windows基础通信
概述 WCF,Windows Communication Foundation ,Windows通信基础, 面向服务的架构,Service Orientation Architechture=SOP ...
- 将 Python 程序打包成 .exe 文件
1.简介 做了一个excel的风控模板,里面含有宏,我用python的第三方xlwings部署到linux后发现,linux环境并不支持xlwings. Python 程序都是脚本的方式,一般是在解析 ...
- 洛谷P1339 热浪【最短路】
题目:https://www.luogu.org/problemnew/show/P1339 题意:给定一张图,问起点到终点的最短路. 思路:dijkstra板子题. 很久没有写最短路了.总结一下di ...
- 配置IIS使其支持APK文件的下载
在管理工具里打开Internet 信息服务(IIS)管理器.然后选择需要配置的网站. 右侧的界面中会显示该网站的所有功能配置,我们选择并点击进入“MIME类型” 在左侧的操作区选择点击“添加”MI ...
- java 下拉控件 转自 http://www.cnblogs.com/lhb25/p/form-enhanced-with-javascript-three.html
表单元素让人爱恨交加.作为网页最重要的组成部分,表单几乎无处不在,从简单的邮件订阅.登陆注册到复杂的需要多页填写的信息提交功能,表单都让开发者花费了大量的时间和精力去处理,以期实现好用又漂亮的表单功能 ...
- Java进阶知识19 Struts2和Spring整合在一起
1.概述 1.Spring负责对象创建 2.Struts2负责用Action处理请求 3.整合的关键点:让Struts2框架Action对象的创建交给Spring完成. 2.整合实例 需要用到的 ...
- LibreOJ #110. 乘法逆元
二次联通门 : LibreOJ #110. 乘法逆元 /* LibreOJ #110. 乘法逆元 求一个数在模意义下的所有逆元 */ #include <cstdio> void read ...
- Ubuntu14.04 gzip failed file too large
使用gzip解压一个oracle rman备份集时报错:File too large.gizp -d cosp_db_full.tar.gzgzip: cosp_db_full.tar:File to ...
- Spring框架IOC解说
控制反转——Spring通过一种称作控制反转(IoC)的技术促进了松耦合.当应用了IoC,一个对象依赖的其它对象会通过被动的方式传递进来,而不是这个对象自己创建或者查找依赖对象.你可以认为IoC与JN ...