前一篇讲过环境的部署篇,这一次就讲讲从代码角度如何导出pb模型,如何进行服务调用。

1 hello world篇

部署完docker后,如果是cpu环境,可以直接拉取tensorflow/serving,如果是GPU环境则麻烦点,具体参考前一篇,这里就不再赘述了。

cpu版本的可以直接拉取tensorflow/serving,docker会自动拉取latest版本:

  1. docker pull tensorflow/serving

如果想要指定tensorflow的版本,可以去这里查看:https://hub.docker.com/r/tensorflow/serving/tags/

比如我需要的是1.12.0版本的tf,那么也可以拉取指定的版本:

  1. docker pull tensorflow/serving:1.12.0

拉取完镜像,需要下载一个hello world的程序代码。

  1. mkdir -p /tmp/tfserving
  2. cd /tmp/tfserving
  3. git clone https://github.com/tensorflow/serving

tensorflow/serving的github中有对应的测试模型,模型其实就是 y = 0.5 * x + 2。即输入一个数,输出是对应的y。

运行下面的命令,在docker中部署服务:

  1. docker run -p 8501:8501 --mount type=bind,source=/tmp/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu,target=/models/half_plus_two -e MODEL_NAME=half_plus_two -t tensorflow/serving &

上面的命令中,把/tmp/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu路径挂载到/models/half_plus_two,这样tensorflow_serving就可以加载models下的模型了,然后开放内部8501的http接口。

执行docker ps查看服务列表:

  1. ~ docker ps
  2. CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
  3. 7decb4286057 tensorflow/serving "/usr/bin/tf_serving…" 7 seconds ago Up 6 seconds 8500/tcp, 0.0.0.0:8501->8501/tcp eager_dewdney

发送一个http请求测试一下:

  1. curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_two:predict
  2. {
  3. "predictions": [2.5, 3.0, 4.5
  4. ]
  5. }%

2 mnist篇

由于前面的例子,serving工程下只有pb模型,没有模型的训练和导出,因此看不出其中的门道。这一部分就直接基于手写体识别的例子,展示一下如何从tensorflow训练代码导出模型,又如何通过grpc服务进行模型的调用。

训练和导出:

  1. #! /usr/bin/env python
  2. """
  3. 训练并导出Softmax回归模型,使用SaveModel导出训练模型并添加签名。
  4. """
  5. from __future__ import print_function
  6. import os
  7. import sys
  8. # This is a placeholder for a Google-internal import.
  9. import tensorflow as tf
  10. import ssl
  11. ssl._create_default_https_context = ssl._create_unverified_context
  12. import basic.mnist_input_data as mnist_input_data
  13. # 定义模型参数
  14. tf.app.flags.DEFINE_integer('training_iteration', 10, 'number of training iterations.')
  15. tf.app.flags.DEFINE_integer('model_version', 2, 'version number of the model.')
  16. tf.app.flags.DEFINE_string('work_dir', './tmp', 'Working directory.')
  17. FLAGS = tf.app.flags.FLAGS
  18. def main(_):
  19. # 参数校验
  20. # if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
  21. # print('Usage: mnist_saved_model.py [--training_iteration=x] '
  22. # '[--model_version=y] export_dir')
  23. # sys.exit(-1)
  24. # if FLAGS.training_iteration <= 0:
  25. # print('Please specify a positive value for training iteration.')
  26. # sys.exit(-1)
  27. # if FLAGS.model_version <= 0:
  28. # print('Please specify a positive value for version number.')
  29. # sys.exit(-1)
  30. # Train model
  31. print('Training model...')
  32. mnist = mnist_input_data.read_data_sets(FLAGS.work_dir, one_hot=True)
  33. sess = tf.InteractiveSession()
  34. serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
  35. feature_configs = {'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32), }
  36. tf_example = tf.parse_example(serialized_tf_example, feature_configs)
  37. x = tf.identity(tf_example['x'], name='x') # use tf.identity() to assign name
  38. y_ = tf.placeholder('float', shape=[None, 10])
  39. w = tf.Variable(tf.zeros([784, 10]))
  40. b = tf.Variable(tf.zeros([10]))
  41. sess.run(tf.global_variables_initializer())
  42. y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')
  43. cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
  44. train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
  45. values, indices = tf.nn.top_k(y, 10)
  46. table = tf.contrib.lookup.index_to_string_table_from_tensor(
  47. tf.constant([str(i) for i in range(10)]))
  48. prediction_classes = table.lookup(tf.to_int64(indices))
  49. for _ in range(FLAGS.training_iteration):
  50. batch = mnist.train.next_batch(50)
  51. train_step.run(feed_dict={x: batch[0], y_: batch[1]})
  52. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  53. accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
  54. print('training accuracy %g' % sess.run(
  55. accuracy, feed_dict={
  56. x: mnist.test.images,
  57. y_: mnist.test.labels
  58. }))
  59. print('Done training!')
  60. # Export model
  61. # WARNING(break-tutorial-inline-code): The following code snippet is
  62. # in-lined in tutorials, please update tutorial documents accordingly
  63. # whenever code changes.
  64. # export_path_base = sys.argv[-1]
  65. export_path_base = "/Users/xingoo/PycharmProjects/ml-in-action/实践-tensorflow/01-官方文档-学习和使用ML/save_model"
  66. export_path = os.path.join(tf.compat.as_bytes(export_path_base), tf.compat.as_bytes(str(FLAGS.model_version)))
  67. print('Exporting trained model to', export_path)
  68. # 配置导出地址,创建SaveModel
  69. builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  70. # Build the signature_def_map.
  71. # 创建TensorInfo,包含type,shape,name
  72. classification_inputs = tf.saved_model.utils.build_tensor_info(serialized_tf_example)
  73. classification_outputs_classes = tf.saved_model.utils.build_tensor_info(prediction_classes)
  74. classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values)
  75. # 分类签名:算法类型+输入+输出(概率和名字)
  76. classification_signature = (
  77. tf.saved_model.signature_def_utils.build_signature_def(
  78. inputs={
  79. tf.saved_model.signature_constants.CLASSIFY_INPUTS:
  80. classification_inputs
  81. },
  82. outputs={
  83. tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
  84. classification_outputs_classes,
  85. tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
  86. classification_outputs_scores
  87. },
  88. method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))
  89. tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  90. tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
  91. # 预测签名:输入的x和输出的y
  92. prediction_signature = (
  93. tf.saved_model.signature_def_utils.build_signature_def(
  94. inputs={'images': tensor_info_x},
  95. outputs={'scores': tensor_info_y},
  96. method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
  97. # 构建图和变量的信息:
  98. """
  99. sess 会话
  100. tags 标签,默认提供serving、train、eval、gpu、tpu
  101. signature_def_map 签名
  102. main_op 初始化?
  103. strip_default_attrs strip?
  104. """
  105. # predict_images就是服务调用的方法
  106. # serving_default是没有输入签名时,使用的方法
  107. builder.add_meta_graph_and_variables(
  108. sess, [tf.saved_model.tag_constants.SERVING],
  109. signature_def_map={
  110. 'predict_images':
  111. prediction_signature,
  112. tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
  113. classification_signature,
  114. },
  115. main_op=tf.tables_initializer(),
  116. strip_default_attrs=True)
  117. # 保存
  118. builder.save()
  119. print('Done exporting!')
  120. if __name__ == '__main__':
  121. tf.app.run()

执行后,在当前目录中就有一个save_model文件,保存了各个版本的pb模型文件。

然后基于grpc部署服务:

  1. docker run -p 8500:8500 --mount type=bind,source=/Users/xingoo/PycharmProjects/ml-in-action/01-实践-tensorflow/01-官方文档-学习和使用ML/save_model,target=/models/mnist -e MODEL_NAME=mnist -t tensorflow/serving &

服务部署成功,查看一下docker列表:

  1. ~ docker ps
  2. CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
  3. 39a06cc35961 tensorflow/serving "/usr/bin/tf_serving…" 4 seconds ago Up 3 seconds 0.0.0.0:8500->8500/tcp, 8501/tcp hardcore_galileo

然后编写对应的client代码:

  1. import tensorflow as tf
  2. import basic.mnist_input_data as mnist_input_data
  3. import grpc
  4. import numpy as np
  5. import sys
  6. import threading
  7. from tensorflow_serving.apis import predict_pb2
  8. from tensorflow_serving.apis import prediction_service_pb2_grpc
  9. tf.app.flags.DEFINE_integer('concurrency', 1, 'maximum number of concurrent inference requests')
  10. tf.app.flags.DEFINE_integer('num_tests', 100, 'Number of test images')
  11. tf.app.flags.DEFINE_string('server', 'localhost:8500', 'PredictionService host:port')
  12. tf.app.flags.DEFINE_string('work_dir', './tmp', 'Working directory. ')
  13. FLAGS = tf.app.flags.FLAGS
  14. test_data_set = mnist_input_data.read_data_sets(FLAGS.work_dir).test
  15. channel = grpc.insecure_channel(FLAGS.server)
  16. stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
  17. class _ResultCounter(object):
  18. """Counter for the prediction results."""
  19. def __init__(self, num_tests, concurrency):
  20. self._num_tests = num_tests
  21. self._concurrency = concurrency
  22. self._error = 0
  23. self._done = 0
  24. self._active = 0
  25. self._condition = threading.Condition()
  26. def inc_error(self):
  27. with self._condition:
  28. self._error += 1
  29. def inc_done(self):
  30. with self._condition:
  31. self._done += 1
  32. self._condition.notify()
  33. def dec_active(self):
  34. with self._condition:
  35. self._active -= 1
  36. self._condition.notify()
  37. def get_error_rate(self):
  38. with self._condition:
  39. while self._done != self._num_tests:
  40. self._condition.wait()
  41. return self._error / float(self._num_tests)
  42. def throttle(self):
  43. with self._condition:
  44. while self._active == self._concurrency:
  45. self._condition.wait()
  46. self._active += 1
  47. def _create_rpc_callback(label, result_counter):
  48. def _callback(result_future):
  49. exception = result_future.exception()
  50. if exception:
  51. result_counter.inc_error()
  52. print(exception)
  53. else:
  54. response = np.array(result_future.result().outputs['scores'].float_val)
  55. prediction = np.argmax(response)
  56. sys.stdout.write("%s - %s\n" % (label, prediction))
  57. sys.stdout.flush()
  58. result_counter.inc_done()
  59. result_counter.dec_active()
  60. return _callback
  61. result_counter = _ResultCounter(FLAGS.num_tests, FLAGS.concurrency)
  62. for i in range(FLAGS.num_tests):
  63. request = predict_pb2.PredictRequest()
  64. request.model_spec.name = 'mnist'
  65. request.model_spec.signature_name = 'predict_images'
  66. image, label = test_data_set.next_batch(1)
  67. request.inputs['images'].CopyFrom(tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
  68. result_counter.throttle()
  69. result_future = stub.Predict.future(request, 5.0) # 5 seconds
  70. result_future.add_done_callback(_create_rpc_callback(label[0], result_counter))
  71. print(result_counter.get_error_rate())

得到对应的输出:

  1. 3 - 3
  2. 6 - 6
  3. 9 - 9
  4. 3 - 3
  5. 1 - 1
  6. 4 - 9
  7. 1 - 5
  8. 7 - 9
  9. 6 - 6
  10. 9 - 9
  11. 0.0

深度学习Tensorflow生产环境部署(下·模型部署篇)的更多相关文章

  1. 深度学习Tensorflow生产环境部署(上·环境准备篇)

    最近在研究Tensorflow Serving生产环境部署,尤其是在做服务器GPU环境部署时,遇到了不少坑.特意总结一下,当做前车之鉴. 1 系统背景 系统是ubuntu16.04 ubuntu@ub ...

  2. linux服务器上配置进行kaggle比赛的深度学习tensorflow keras环境详细教程

    本文首发于个人博客https://kezunlin.me/post/6b505d27/,欢迎阅读最新内容! full guide tutorial to install and configure d ...

  3. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  4. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  5. 深度学习之Attention Model(注意力模型)

    1.Attention Model 概述 深度学习里的Attention model其实模拟的是人脑的注意力模型,举个例子来说,当我们观赏一幅画时,虽然我们可以看到整幅画的全貌,但是在我们深入仔细地观 ...

  6. [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型

    [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型 目录 [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型 0x00 摘要 0x01 前言 1.1 改 ...

  7. 深度学习Tensorflow相关书籍推荐和PDF下载

    深度学习Tensorflow相关书籍推荐和PDF下载 baihualinxin关注 32018.03.28 10:46:16字数 481阅读 22,673 1.机器学习入门经典<统计学习方法&g ...

  8. 深度学习入门者的Python快速教程 - 基础篇

      5.1 Python简介 本章将介绍Python的最基本语法,以及一些和深度学习还有计算机视觉最相关的基本使用. 5.1.1 Python简史 Python是一门解释型的高级编程语言,特点是简单明 ...

  9. 在linux ubuntu下搭建深度学习/机器学习开发环境

    一.安装Anaconda 1.下载 下载地址为:https://www.anaconda.com/download/#linux 2.安装anaconda,执行命令: bash ~/Downloads ...

随机推荐

  1. How to change system keyboard keymap layout on CentOS 7 Linux

    The easiest way to swap between keymaps and thus temporarily set keys to different language by use o ...

  2. linux 学习之路:mkdir命令使用

    linux mkdir 命令 在当前目录下创建文件夹,当前账号需要保证目录下有写到权限. 1.命令格式 mkdir[选项]文件名 mkdir  创建目录文件 语法:mkdir [ -m Mode ] ...

  3. css -html-文档流

    首先先考虑一下什么是普通流?普通流就是正常的文档流,在HTML里面的写法就是从上到下,从左到右的排版布局. 例: <div id="01"></div>&l ...

  4. 常用jquery

    水果:<input type="checkbox" name="shuiGuo" value="2">苹果<input t ...

  5. React Native不同设备分辨率适配和设计稿尺寸单位px的适配

    React Native中使用的尺寸单位是dp(一种基于屏幕密度的抽象单位.在每英寸160点的显示器上,1dp = 1px),而设计师使用的是px, 这两种尺寸如何换算呢? 官方提供了PixelRat ...

  6. mysql 1055

    在 /etc/my.cnf 文件里加上如下: sql_mode=NO_ENGINE_SUBSTITUTION

  7. Vue+Webpack构建去哪儿APP_一.开发前准备

    一.开发前准备 1.node环境搭建 去node.js官网下载长期支持版本的node,采用全局安装,安装方式自行百度 网址:https://nodejs.org/zh-cn/ 安装后在cmd命令行运行 ...

  8. js电子表

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...

  9. Mac 系统搭建ThinkPHP3.2

    PHP3.2完整包目录 拷贝两个文件 index.php 和ThinkPHP目录到服务器目录中,我已经设置服务器目录与eclipse工作空间为同一个 创建TestThinkPHP 项目 Eclipse ...

  10. centos7制作本地yum源

    创建想要挂载的路径 mkdir /mnt/cdrom 挂载本地镜像到创建的目录 mount -t iso9660 /dev/cdrom /mnt/cdrom/ mount: /dev/sr0 is w ...