在 parameters.py 中,定义了各类参数。

 # training data directory
TRAINING_DATA_DIR = './data/' # checkpoint directory
CHECKPOINT_DIR = './training_checkpoints/' # training details
BATCH_SIZE = 16
BUFFER_SIZE = 128
EPOCHS = 15

在 numpy_dataset.py 中,创建了 5000 组训练数据集,模拟 y = x^3 + 1,并二进制格式写入文件。

 from parameters import TRAINING_DATA_DIR

 import numpy as np
import matplotlib.pyplot as plt
import os # create training data
X = np.linspace(-1, 1, 5000)
np.random.shuffle(X)
y = X ** 3 + 1 + np.random.normal(0, 0.01, (5000,)) # plot training data
plt.scatter(X, y)
plt.show() # save data
if not os.path.exists(TRAINING_DATA_DIR):
os.makedirs(TRAINING_DATA_DIR) X.tofile(os.path.join(TRAINING_DATA_DIR + 'training_data_X.bin'))
y.tofile(os.path.join(TRAINING_DATA_DIR + 'training_data_y.bin'))


在 subclassed_model.py 中,通过对 tf.keras.models.Model 进行子类化,设计了两个自定义模型。

 import tensorflow as tf
tf.enable_eager_execution() # model definition
class Encoder(tf.keras.models.Model):
def __init__(self):
super(Encoder, self).__init__()
self.fc1 = tf.keras.layers.Dense(units=16, activation='relu')
self.fc2 = tf.keras.layers.Dense(units=8, activation='relu') def call(self, inputs):
r = self.fc1(inputs)
return self.fc2(r) class Decoder(tf.keras.models.Model):
def __init__(self):
super(Decoder, self).__init__()
self.fc = tf.keras.layers.Dense(units=1, activation=None) def call(self, inputs):
return self.fc(inputs)

在 loss_function.py 中,定义了损失函数。

 import tensorflow as tf
tf.enable_eager_execution() def loss(real, pred):
return tf.losses.mean_squared_error(labels=real, predictions=pred)

在 training.py 中,使用在 numpy_dataset.py 中创建的数据集训练模型,之后使用 model.save_weights() 保存 Keras Subclassed Model 模型,并创建验证集验证模型。

 from parameters import TRAINING_DATA_DIR, CHECKPOINT_DIR, BATCH_SIZE, BUFFER_SIZE, EPOCHS
from subclassed_model import *
from loss_function import loss import os
import numpy as np
import matplotlib.pyplot as plt # load training data
training_X = np.fromfile(os.path.join(TRAINING_DATA_DIR, 'training_data_X.bin'), dtype=np.float64)
training_y = np.fromfile(os.path.join(TRAINING_DATA_DIR, 'training_data_y.bin'), dtype=np.float64) # plot training data
plt.scatter(training_X, training_y)
plt.show() # training dataset
training_dataset = tf.data.Dataset.from_tensor_slices((training_X, training_y)).batch(BATCH_SIZE).shuffle(BUFFER_SIZE) # model instance
encoder = Encoder()
decoder = Decoder() # optimizer
optimizer = tf.train.AdamOptimizer() # checkpoint
checkpoint_prefix_encoder = os.path.join(CHECKPOINT_DIR, 'encoder/', 'ckpt')
checkpoint_prefix_decoder = os.path.join(CHECKPOINT_DIR, 'decoder/', 'ckpt') if not os.path.exists(os.path.dirname(checkpoint_prefix_encoder)):
os.makedirs(os.path.dirname(checkpoint_prefix_encoder))
if not os.path.exists(os.path.dirname(checkpoint_prefix_decoder)):
os.makedirs(os.path.dirname(checkpoint_prefix_decoder)) # training step
for epoch in range(EPOCHS):
epoch_loss = 0 for (batch, (tx, ty)) in enumerate(training_dataset):
x = tf.cast(tx, tf.float32)
y = tf.cast(ty, tf.float32)
x = tf.expand_dims(x, axis=1) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
y = tf.expand_dims(y, axis=1) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32) with tf.GradientTape() as tape:
y_ = encoder(x) # tf.Tensor([...], shape=(BATCH_SIZE, 8), dtype=float32)
prediction = decoder(y_) # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
batch_loss = loss(real=y, pred=prediction) variables = encoder.variables + decoder.variables
grads = tape.gradient(batch_loss, variables)
optimizer.apply_gradients(zip(grads, variables), global_step=tf.train.get_or_create_global_step()) epoch_loss += batch_loss if (batch + 1) % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch + 1,
batch_loss.numpy())) print('Epoch {} Loss {:.4f}'.format(epoch + 1,
epoch_loss / len(training_X))) if (epoch + 1) % 5 == 0:
encoder.save_weights(checkpoint_prefix_encoder)
decoder.save_weights(checkpoint_prefix_decoder) # create evaluation data
X = np.linspace(-1, 1, 3000)
np.random.shuffle(X) evaluation_X = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
ey = [] for (batch, ex) in enumerate(evaluation_X):
x = tf.cast(ex, tf.float32)
x = tf.expand_dims(x, axis=1)
prediction = decoder(encoder(x))
for i in range(len(prediction.numpy())):
ey.append(prediction.numpy()[i]) plt.scatter(X, ey)
plt.show() # evaluate
eval_x = [[0.5]]
tensor_x = tf.convert_to_tensor(eval_x)
print(decoder(encoder(tensor_x)))

验证集评价结果如下图所示。

使用测试样例 eval_x 进行测试,测试结果如下。

tf.Tensor([[1.122567]], shape=(1, 1), dtype=float32)

在 evaluate.py 中,使用 model.load_weights() 恢复 Keras Subclassed Model 模型,并在验证集上进行验证,验证结果如下图所示。

 from parameters import CHECKPOINT_DIR, BATCH_SIZE
from subclassed_model import * import os
import numpy as np
import matplotlib.pyplot as plt # load model
enc = Encoder()
dec = Decoder() enc.load_weights(tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, 'encoder/')))
dec.load_weights(tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, 'decoder/'))) # create evaluation data
X = np.linspace(-1, 1, 3000)
np.random.shuffle(X) evaluation_X = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE)
ey = [] for (batch, ex) in enumerate(evaluation_X):
x = tf.cast(ex, tf.float32)
x = tf.expand_dims(x, axis=1)
prediction = dec(enc(x))
for i in range(len(prediction.numpy())):
ey.append(prediction.numpy()[i]) plt.scatter(X, ey)
plt.show() # evaluate
eval_x = [[0.5]]
tensor_x = tf.convert_to_tensor(eval_x)
print(dec(enc(tensor_x))) # model summary
enc.summary()
dec.summary()

使用测试样例 eval_x 进行测试,测试结果如下。

tf.Tensor([[1.122567]], shape=(1, 1), dtype=float32)

恢复模型的测试结果,与训练后模型的测试结果一致,且无需 build 模型。


版权声明:本文为博主原创文章,欢迎转载,转载请注明作者及原文出处!

[Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model的更多相关文章

  1. [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model

    在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...

  2. Tensorflow 模型持久化saver及加载图结构

    主要内容: 1. 直接保存,加载模型; (可以指定加载,保存的var_list) 2. 加载,保存指定变量的模型 3. slim加载模型使用 4. 加载模型图结构和参数等 tensorflow 恢复部 ...

  3. docker 保存 加载(导入 导出镜像

    tensorflow 的docker镜像很大,pull一次由于墙经常失败.其实docker 可以将镜像导出再导入. 保存加载(tensorflow)镜像 1) 查看镜像 docker images 如 ...

  4. gensim Word2Vec 训练和使用(Model一定要加载到内存中,节省时间!!!)

    训练模型利用gensim.models.Word2Vec(sentences)建立词向量模型该构造函数执行了三个步骤:建立一个空的模型对象,遍历一次语料库建立词典,第二次遍历语料库建立神经网络模型可以 ...

  5. 优化tableView加载cell与model的过程

    优化tableView加载cell与model的过程 效果图 说明 1. 用多态的特性来优化tableView加载cell与model的过程 2. swift写起来果然要比Objective-C简洁了 ...

  6. [Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model

    import numpy as np import matplotlib.pyplot as plt import os import time import tensorflow as tf tf. ...

  7. Tensorflow 2.0 datasets数据加载

    导入包 import tensorflow as tf from tensorflow import keras 加载数据 tensorflow可以调用keras自带的datasets,很方便,就是有 ...

  8. C# DataGridVie利用model特性动态加载列

    今天闲来无事看到ORm的特性映射sql语句.我就想到datagridview也可以用这个来动态添加列.这样就不用每次都去界面上点开界面填列了. 代码简漏希望有人看到了能指点一二. 先定义好Datagr ...

  9. docker 保存,加载,导入,导出 命令

    持久化docker的镜像或容器的方法 docker的镜像和容器可以有两种方式来导出 docker save #ID or #Name docker export #ID or #Name docker ...

随机推荐

  1. Spring-使用JAVA的方式配置Spring-代理模式

    9.使用Java的方式配置Spring 我们现在要完全不使用Spring的xml配置了,全权交给Java来做! JavaConfig是Spring的一个子项目,在Spring4之后,它成为了一个核心功 ...

  2. Lua 学习之基础篇五<Lua OS 库>

    lua os库提供了简单的跟操作系统有关的功能 1.os.clock() 返回程序所运行使用的时间 local nowTime = os.clock() print("now time is ...

  3. 第69题:x的平方根

    一. 问题描述 实现 int sqrt(int x) 函数. 计算并返回 x 的平方根,其中 x 是非负整数. 由于返回类型是整数,结果只保留整数的部分,小数部分将被舍去. 示例 1: 输入: 4 输 ...

  4. Jquery的toggle()与trigger()方法

    我一直分不清楚toggle()与trigger()两个各自的作用,所以今天抽时间记录一些,以加深印象. 1.toggle() 定义和用法: toggle() 方法切换元素的可见状态.如果被选元素可见, ...

  5. python自动华 (十)

    Python自动化 [第十篇]:Python进阶-多进程/协程/事件驱动与Select\Poll\Epoll异步IO 本节内容: 多进程 协程 事件驱动与Select\Poll\Epoll异步IO   ...

  6. github 管理代码、笔记

    1.先注册github.com的账号官方网站: https://github.com/ 2.登录 3.创建仓库 二. 1.安装git 2.刚才我们已经在github上面创建了一个仓库,那么我们现在就在 ...

  7. [Luogu] 产生数

    题面:https://www.luogu.org/problemnew/show/P1037 题解:https://www.zybuluo.com/wsndy-xx/note/1145473

  8. mac 安装软件 显示信任任何来源

    “通用”里有时没有“任何来源”这个选项: 显示"任何来源"选项在控制台中执行: sudo spctl --master-disable 不显示"任何来源"选项( ...

  9. loadrunner11安装

    今天虚拟机里面装了下lr11,虚拟机版本是vm9.0,先在虚拟机里面装了windows2003,当然lr也是可以装在自己电脑上面的,但是最好是纯净的环境,由于我电脑东西比较多,所以我就装在虚拟机里面了 ...

  10. JQuery 行内编辑(即点即改)

    行内编辑 下面是详细的代码: <style> .dian { cursor: pointer; } </style> //这个让鼠标 移动到 span上 的时候 是一个小手 & ...