Using Tensorflow SavedModel Format to Save and Do Predictions
We are now trying to deploy our Deep Learning model onto Google Cloud. It is required to use Google Function to trigger the Deep Learning predictions. However, when pre-trained models are stored on cloud, it is impossible to get the exact directory path and restore the tensorflow session like what we did on local machine.
So we turn to use SavedModel, which is quite like a 'Prediction Mode' of tensorflow. According to official turotial: a SavedModel contains a complete TensorFlow program, including weights and computation. It does not require the original model building code to run, which makes it useful for sharing or deploying.
The Definition of our graph, just here to show the input and output tensors:
'''RNN Model Definition'''
tf.reset_default_graph()
''''''
#define inputs
tf_x = tf.placeholder(tf.float32, [None, window_size,1],name='x')
tf_y = tf.placeholder(tf.int32, [None, 2],name='y') cells = [tf.keras.layers.LSTMCell(units=n) for n in num_units]
stacked_rnn_cell = tf.keras.layers.StackedRNNCells(cells)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn(
stacked_rnn_cell, # cell you have chosen
tf_x, # input
initial_state=None, # the initial hidden state
dtype=tf.float32, # must given if set initial_state = None
time_major=False, # False: (batch, time step, input); True: (time step, batch, input)
)
l1 = tf.layers.dense(outputs[:, -1, :],32,activation=tf.nn.relu,name='l1')
l2 = tf.layers.dense(l1,8,activation=tf.nn.relu,name='l6')
pred = tf.layers.dense(l2,2,activation=tf.nn.relu,name='pred') with tf.name_scope('loss'):
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf_y, logits=pred)
loss = tf.reduce_mean(cross_entropy)
tf.summary.scalar("loss",tensor=loss)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(tf_y, axis=1), tf.argmax(pred, axis=1)), tf.float32)) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
saver = tf.train.Saver()
Train and Save the model, we use simple_save:
sess = tf.Session()
sess.run(init_op) for i in range(0,n):
sess.run(train_op,{tf_x:batch_X , tf_y:batch_y})
...
tf.saved_model.simple_save(sess, 'simple_save/model', \
inputs={"x": tf_x},outputs={"pred": pred})
sess.close()
Restore and Predict:
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["serve"], 'simple_save_test/model')
batch = sess.run('pred/Relu:0',feed_dict={'x:0':dataX.reshape([-1,24,1])})
print(batch)
Reference:
medium post: https://medium.com/@jsflo.dev/saving-and-loading-a-tensorflow-model-using-the-savedmodel-api-17645576527
The official tutorial of Tensorflow: https://www.tensorflow.org/guide/saved_model
Using Tensorflow SavedModel Format to Save and Do Predictions的更多相关文章
- [Tool] Enable Prettier in VSCode as Format on Save and add config files to gitingore
First of all, install Prettier extension: "Pettier - Code formatter". The open the VSCode ...
- vs code的使用(一) Format On Paste/Format On Save/ Format On Type
很多经典的问题可以搜索出来,但是一些很小的问题网上却没有答案 (这是最令人发狂的,这么简单,网上居然连个相关的信息都没有给出) (就比如我想保存后自动格式化,但网上的大部分都是如何取消保存后自动格式化 ...
- 135、TensorFlow SavedModel工具类的使用
# SavedModelBuilder 类提供了保存多个MetaGraphDef的功能 # MetaGraph是一个数据流图,加上它的关联变量,资产和标签 # 一个MetaGraphDef是一个协议缓 ...
- tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑
自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...
- Run Your Tensorflow Deep Learning Models on Google AI
People commonly tend to put much effort on hyperparameter tuning and training while using Tensoflow& ...
- Tensorflow 模型线上部署
获取源码,请移步笔者的github: tensorflow-serving-tutorial 由于python的灵活性和完备的生态库,使得其成为实现.验证ML算法的不二之选.但是工业界要将模型部署到生 ...
- Tensorflow 2.x入门教程
前言 至于为什么写这个教程,首先是为了自己学习做个记录,其次是因为Tensorflow的API写的很好,但是他的教程写的太乱了,不适合新手学习.tensorflow 1 和tensorflow 2 有 ...
- Tensorflow应用之LSTM
学习RNN时原理理解起来不难,但是用TensorFlow去实现时被它各种数据的shape弄得晕头转向.现在就结合一个情感分析的案例来了解一下LSTM的操作流程. 一.深度学习在自然语言处理中的应用 自 ...
- 基于Spark和Tensorflow构建DCN模型进行CTR预测
实验介绍 数据采用Criteo Display Ads.这个数据一共11G,有13个integer features,26个categorical features. Spark 由于数据比较大,且只 ...
随机推荐
- SCAU 2015 GDCPC team_training0
A.题目:http://acm.timus.ru/problem.aspx?space=1&num=2024 题意:求一个包含K个不同字符的集合的最大长度,还有组成这个长度的集合的个数 做法: ...
- 修改url,
第一种场景: 无论url怎么变,表单里面的url始终不变 http://127.0.0.1:8000/CC/indexssssssssssssssssss/ url(r'^indexsssssssss ...
- C# asp.net XML格式的字符串显示不全
前台显示XML字符串显示不全 后台XML字符串使用<xmp></xmp>将XML格式字符串括起来
- 在navcat中清空数据后,设置id归零方法
写后台完成后,需要清空Mysql数据库中的测试数据,但是后面新增的数据,一直是以原来所删除数据的最大id为增量基本,比如,对于一些id敏感的项,十分不便,如图 原有10条数据,清空后,新增一两条,手动 ...
- tomcat闪屏是jdk JAVA_HOEM环境变量配置问题
JAVA_HOME=D:\jdk.18
- Web学习之CSS总结
银角大王武Sir的博客地址 1.positoin属性固定元素的定位类型 说明:这个属性定义建立元素布局所用的定位机制.任何元素都可以定位,不过绝对或固定元素会生成一个块级框,而无论该元素是什么类型.相 ...
- ASE Code Search
重现基线模型 Hamel's model 基线模型原理 如何实现semantic search?在已有数据库的基础上,衡量一个句子和每段代码的相关性再进行排序,选出最优代码片段即可实现一个通用的cod ...
- express 获取post 请求参数
在 Express 中没有内置获取表单 POST 请求体的 API , 我们需要添加第三方插件库 安装: npm install --save body-parser 配置: var bodyPars ...
- 基于 Ansible 的 ELK 部署说明
ELK-Ansible使用手册 ELK-Ansible 是基于 Ansible 的 Playbooks 研发的 ELK集群部署工具.本文将介绍如何使用 ELK-Ansible 快速部署 ELK 集群. ...
- 专家告诉你!如何避免黑客BGP劫持?
BGP前缀劫持是针对Internet组织的持久威胁,原因是域间路由系统缺乏授权和身份验证机制. 仅在2017年,数千起路由事件导致代价高昂的中断和信息拦截,而问题的确切程度未知.尽管在过去20年中已经 ...