[阿里DIN] 模型保存,加载和使用
[阿里DIN] 模型保存,加载和使用
0x00 摘要
Deep Interest Network(DIN)是阿里妈妈精准定向检索及基础算法团队在2017年6月提出的。其针对电子商务领域(e-commerce industry)的CTR预估,重点在于充分利用/挖掘用户历史行为数据中的信息。
本系列文章会解读论文以及源码,顺便梳理一些深度学习相关概念和TensorFlow的实现。
本文是系列第 12 篇 :介绍DIN模型的保存,加载和使用。
0x01 TensorFlow模型
1.1 模型文件
TensorFlow模型会保存在checkpoint相关文件中。因为TensorFlow会将计算图的结构和图上参数取值分开保存,所以保存后在相关文件夹中会出现3个文件。
下面就是DIN,DIEN相关生成的文件,可以通过名称来判别。
checkpoint
ckpt_noshuffDIN3.data-00000-of-00001
ckpt_noshuffDIN3.meta
ckpt_noshuffDIN3.index
ckpt_noshuffDIEN3.data-00000-of-00001
ckpt_noshuffDIEN3.index
ckpt_noshuffDIEN3.meta
所以我们可以认为和保存的模型直接相关的是以下这四个文件:
checkpoint
文件保存了一个目录下所有的模型文件列表,这个文件是TensorFlow自动生成且自动维护的。在checkpoint
文件中维护了由一个TensorFlow持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint
文件中删除。checkpoint
中内容的格式为CheckpointState Protocol Buffer..meta
文件 保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构。
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名。- .index文件保存了当前参数名。
model.ckpt
文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt
文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader
类来查看model.ckpt
文件中保存的变量信息。
1.2 freeze_graph
正如前文所述,tensorflow在训练过程中,通常不会将权重数据保存的格式文件里,反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。
freeze_graph.py是怎么做的呢?
- 它先加载模型文件;
- 提供checkpoint文件地址后,它从checkpoint文件读取权重数据初始化到模型里的权重变量;
- 将权重变量转换成权重常量 (因为常量能随模型一起保存在同一个文件里);
- 再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉;
- 使用tf.train.writegraph保存图,这个图会提供给freeze_graph使用;
- 再使用freeze_graph重新保存到指定的文件里;
0x02 DIN代码
因为 DIN 源码中没有实现此部分,所以我们需要自行添加。
2.1 输出结点
首先,在model.py中,需要声明输出结点。
def build_fcn_net(self, inp, use_dice = False):
.....
# 此处需要给 y_hat 添加一个name
self.y_hat = tf.nn.softmax(dnn3, name='final_output') + 0.00000001
2.2 保存函数
其次,需要添加一个保存函数,调用 freeze_graph 来进行保存。
需要注意几点:
- write_graph 的 as_text 参数默认是 True,我们这里设置为 False。有的环境如果设置为 True 会有问题;
- 因为write_graph 的 as_text 参数做了设置,所以freeze_graph的参数也做相应设置:
input_binary=True
; - input_checkpoint 参数需要针对DIN或者DIEN做相应调整;
具体代码如下:
def din_freeze_graph(sess):
# 模型持久化,将变量值固定
output_graph_def = convert_variables_to_constants(
sess=sess,
input_graph_def=sess.graph_def, # 等于:sess.graph_def
output_node_names=['final_output']) # 如果有多个输出节点,以逗号隔开
tf.train.write_graph(output_graph_def, 'dnn_best_model', 'model.pb', False)
freeze_graph.freeze_graph(
input_graph='./dnn_best_model/model.pb',
input_saver='',
input_binary=True,
input_checkpoint='./dnn_best_model/ckpt_noshuffDIN3',
output_node_names='final_output', # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
output_graph='./dnn_best_model/frozen_model.pb',
clear_devices=False,
initializer_nodes=''
)
2.2 调用保存
我们在train函数中,存储模型之后,进行调用。
def train(...):
if (iter % save_iter) == 0:
print('save model iter: %d' %(iter))
model.save(sess, model_path+"--"+str(iter))
freeze_graph(sess) # 此处调用
0x03 验证
3.1 加载
加载函数如下:
def load_graph(fz_gh_fn):
with tf.gfile.GFile(fz_gh_fn, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="prefix" # 此处可以自己修改
)
return graph
调用加载函数如下,我们在加载之后,打印出图中对应节点:
graph = load_graph('./dnn_best_model/frozen_model.pb')
for op in graph.get_operations():
print(op.name, op.values())
从打印结果我们可以看出来,有些op是Inputs相关,final_output节点则是我们之前设定的。
(u'prefix/Inputs/mid_his_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
(u'prefix/Inputs/cat_his_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
(u'prefix/Inputs/uid_batch_ph', (<tf.Tensor 'prefix/Inputs/uid_batch_ph:0' shape=(?,) dtype=int32>,))
(u'prefix/Inputs/mid_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_batch_ph:0' shape=(?,) dtype=int32>,))
(u'prefix/Inputs/cat_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_batch_ph:0' shape=(?,) dtype=int32>,))
(u'prefix/Inputs/mask', (<tf.Tensor 'prefix/Inputs/mask:0' shape=(?, ?) dtype=float32>,))
(u'prefix/Inputs/seq_len_ph', (<tf.Tensor 'prefix/Inputs/seq_len_ph:0' shape=(?,)
......
(u'prefix/final_output', (<tf.Tensor 'prefix/final_output:0' shape=(?, 2) dtype=float32>,))
3.2 验证
验证数据可以自己炮制,或者就是从测试数据中取出两条即可,我们的验证文件名字为 local_predict_splitByUser
。
0 A3BI7R43VUZ1TY B00JNHU0T2 Literature & Fiction 0989464105B00B01691C14778097321608442845 BooksLiterature & FictionBooksBooks
1 A3BI7R43VUZ1TY 0989464121 Books 0989464105B00B01691C14778097321608442845 BooksLiterature & FictionBooksBooks
验证代码如下,其中feed_dict如何填充,需要根据上节的输出结果来进行相关配置。
def predict(
graph,
predict_file = "local_predict_splitByUser",
uid_voc = "uid_voc.pkl",
mid_voc = "mid_voc.pkl",
cat_voc = "cat_voc.pkl",
batch_size = 128,
maxlen = 100):
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph = graph) as sess:
predict_data = DataIterator(predict_file, uid_voc, mid_voc, cat_voc, batch_size, maxlen)
for src, tgt in predict_data:
uids, mids, cats, mid_his, cat_his, mid_mask, target, sl, noclk_mids, noclk_cats = prepare_data(src, tgt, maxlen, return_neg=True)
final_output = "prefix/final_output:0"
feed_dict = {
'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
'prefix/Inputs/cat_his_batch_ph:0':cat_his,
'prefix/Inputs/uid_batch_ph:0':uids,
'prefix/Inputs/mid_batch_ph:0':mids,
'prefix/Inputs/cat_batch_ph:0':cats,
'prefix/Inputs/mask:0':mid_mask,
'prefix/Inputs/seq_len_ph:0':sl
}
y_hat = sess.run(final_output, feed_dict = feed_dict)
print(y_hat)
预测结果如下:
[[0.95820646 0.04179354]
[0.09431148 0.9056886 ]]
3.3 为什么要在tensor后面加:0
在上节中,我们可以看到在feed_dict之中,给定的tensor名字后面都带了 :0
。
feed_dict = {
'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
'prefix/Inputs/cat_his_batch_ph:0':cat_his,
'prefix/Inputs/uid_batch_ph:0':uids,
'prefix/Inputs/mid_batch_ph:0':mids,
'prefix/Inputs/cat_batch_ph:0':cats,
'prefix/Inputs/mask:0':mid_mask,
'prefix/Inputs/seq_len_ph:0':sl
}
这里需要注意,TensorFlow的运算结果不是一个数,而是一个张量结构。张量的命名形式:“node : src_output”
,node为节点的名称,src_output 表示当前张量来自来自节点的第几个输出。
在我们这里,prefix/Inputs/mid_batch_ph
是操作节点,prefix/Inputs/mid_batch_ph:0
才是变量的名字。冒号后面的数字编号表示这个张量是计算节点上的第几个结果。
0xFF 参考
[深度学习] TensorFlow中模型的freeze_graph
TensorFlow模型冷冻以及为什么tensor名字要加:0
tensorflow实战笔记(19)----使用freeze_graph.py将ckpt转为pb文件
Tensorflow-GraphDef、MetaGraph、CheckPoint
[阿里DIN] 模型保存,加载和使用的更多相关文章
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 学习笔记TF049:TensorFlow 模型存储加载、队列线程、加载数据、自定义操作
生成检查点文件(chekpoint file),扩展名.ckpt,tf.train.Saver对象调用Saver.save()生成.包含权重和其他程序定义变量,不包含图结构.另一程序使用,需要重新创建 ...
- TensorFlow模型保存和加载方法
TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...
- keras中的模型保存和加载
tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...
- 从头学pytorch(十二):模型保存和加载
模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...
- docker 保存 加载(导入 导出镜像
tensorflow 的docker镜像很大,pull一次由于墙经常失败.其实docker 可以将镜像导出再导入. 保存加载(tensorflow)镜像 1) 查看镜像 docker images 如 ...
- xBIM 实战03 使用WPF技术实现IFC模型的加载与浏览
系列目录 [已更新最新开发文章,点击查看详细] WPF应用程序在底层使用 DirectX ,无论设计复杂的3D图形(这是 DirectX 的特长所在)还是绘制简单的按钮与文本,所有绘图工作都是 ...
- caffe 模型的加载
在caffe中模型的加载是通过这个函数加载的: void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename)
随机推荐
- Python之解析配置文件
[.env] 1) 使用python-dotenv 安装: pip install python-dotenv 示例配置文件: ADMIN_HOST = https://uat-rm-gwaaa.cn ...
- Blazor 组件库 BootstrapBlazor 中Tag组件介绍
Tag组件的样子 Tag组件的介绍 Tag组件是一个非常简单的组件. <Tag Icon="fa fa-fw fa-check-circle" Color="Col ...
- 设计模式【3.1】-- 浅谈代理模式之静态、动态、cglib代理
代理模式:为其他对象提供一种代理以控制对这个对象的访问,在某种情况下,一个对象不适合或者不能够直接引用另一个对象,而代理对象可以在客户类和目标对象之间起到中介的作用. 可以这么理解:使用代理对象,是为 ...
- ZCMU-1179
我的错误: 明知道是大数问题但不是不想写数组或者字符串的结构. 思路 网上查阅后发现可以使用JAVA的大数类型做. 若不使用JAVA则就是整型数组或者字符串的情况. 将a^b结果放在数组当中,实时更新 ...
- COSBrowser 文件夹分享——多端文件实时共享
您还在为临时分享某个文件夹而烦恼吗? 您是否对授权的复杂度感到震惊? 关于存储桶 Policy 权限,您是否很迷茫,不知如何设置? 不用着急,用 COSBrowser 开始文件夹分享,一键简化分享 ...
- 用触摸屏辅助3D建模
现在在触摸屏上进行3D建模的软件很多,这里说的是另一个概念. 我的设想是将触摸屏当做一个带有 ViewPort 的输入设备. 比如 Blender 在建模时,我们可以通过一个外接的触摸屏从另一个角度观 ...
- d2js 中实现 memcached 共享 session 的过程
https://github.com/inshua/d2js/blob/master/WebContent/guide/memcached-session.md 基于 https://github.c ...
- 拦截烂SQL,解读GaussDB(DWS)查询过滤器过滤规则原理
本文分享自华为云社区<GaussDB(DWS)查询过滤器过滤规则原理与使用介绍>,作者: 清道夫. 1. 前言 适用版本:[9.1.0.100(及以上)] 查询过滤器在9.1.0.100之 ...
- JVM简介—1.Java内存区域
大纲 1.运行时数据区的介绍 2.运行时数据区各区域的作用 3.各个版本内存区域的变化 4.直接内存的使用和作用 5.站在线程的角度看Java内存区域 6.深入分析堆和栈的区别 7.方法的出入栈和栈上 ...
- Qt音视频开发37-识别鼠标按下像素坐标
一.前言 在和视频交互过程中,用户一般需要在显示视频的通道上点击对应的区域,弹出对应的操作按钮,将当前点击的区域或者绘制的多边形区域坐标或者坐标点集合,发送出去,通知其他设备进行处理.比如识别到很多人 ...