最近在学习tensorflow serving,但是就这样平淡看代码可能觉得不能真正思考,就想着写个文章看看,自己写给自己的,就像自己对着镜子演讲一样,写个文章也像自己给自己讲课,这样思考的比较深,学到的也比较多,有错欢迎揪出,

minist_saved_model.py 是tensorflow的第一个例子,里面有很多serving的知识,还不了解,现在看。下面是它的入口函数,然后直接跳转到main

if __name__ == '__main__':
tf.app.run()

在main函数里:

首先,是对一些参数取值等的合理性校验:

def main(_):
if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
print('Usage: mnist_export.py [--training_iteration=x] '
'[--model_version=y] export_dir')
sys.exit(-1)
if FLAGS.training_iteration <= 0:
print 'Please specify a positive value for training iteration.'
sys.exit(-1)
if FLAGS.model_version <= 0:
print 'Please specify a positive value for version number.'
sys.exit(-1)

然后,就开始train model,既然是代码解读加上自己能力还比较弱,简单的我得解读呀,牛人绕道。。。

# Train model
print 'Training model...'
#输入minist数据,这个常见的,里面的源码就是查看有没有数据,没有就在网上
下载下来,然后封装成一个个batch
mnist = mnist_input_data.read_data_sets(FLAGS.work_dir, one_hot=True) #这是创建一个session,Session是Graph和执行者之间的媒介,Session.run()实际
上将graph、fetches、feed_dict序列化到字节数组中进行计算
sess = tf.InteractiveSession() #定义一个占位符,为以后数据等输入留好接口
serialized_tf_example = tf.placeholder(tf.string, name='tf_example') #feature_configs 顾名思义,是特征配置,从形式上看这是一个字典,字典中
初始化key为‘x’,value 是 tf.FixedLenFeature(shape=[784], dtype=tf.float32)的返
回值,而该函数的作用是解析定长的输入特征feature相关配置
feature_configs = {'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32),} #parse_example 常用于稀疏输入数据
tf_example = tf.parse_example(serialized_tf_example, feature_configs) #
x = tf.identity(tf_example['x'], name='x') # use tf.identity() to assign name #因为输出是10类,所以y_设置成None×10
y_ = tf.placeholder('float', shape=[None, 10]) #定义权重变量
w = tf.Variable(tf.zeros([784, 10])) #定义偏置变量
b = tf.Variable(tf.zeros([10])) #对定义的变量进行参数初始化
sess.run(tf.global_variables_initializer()) #对输入的x和权重w,偏置b进行处理
y = tf.nn.softmax(tf.matmul(x, w) + b, name='y') #计算交叉熵
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) #配置优化函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引
values, indices = tf.nn.top_k(y, 10) #这函数返回一个将索引的Tensor映射到字符串的查找表
table = tf.contrib.lookup.index_to_string_table_from_tensor(
tf.constant([str(i) for i in xrange(10)])) #在tabel中查找索引
prediction_classes = table.lookup(tf.to_int64(indices)) #然后开始训练迭代啦
for _ in range(FLAGS.training_iteration):
#获取一个batch数据
batch = mnist.train.next_batch(50)
#计算train_step运算,train_step是优化函数的,这个执行带来的作用就是
根据学习率,最小化cross_entropy,执行一次,就调整参数权重w一次
train_step.run(feed_dict={x: batch[0], y_: batch[1]}) #将得到的y和y_进行对比
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) #对比结果计算准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float')) #运行sess,并使用更新后的最终权重,去做预测,并返回预测结果
print 'training accuracy %g' % sess.run(
accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels})
print 'Done training!'

上面就是训练的过程,就和普通情况下train模型是一样的道理,现在,我们看后面的model export

# Export model
# WARNING(break-tutorial-inline-code): The following code snippet is
# in-lined in tutorials, please update tutorial documents accordingly
# whenever code changes.
#export_path_base基本路径代表你要将model export到哪一个路径下面,
#它的值的获取是传入参数的最后一个,训练命令为:
bazel-bin/tensorflow_serving/example/mnist_saved_model /tmp/mnist_model
那输出的路径就是/tmp/mnist_model
export_path_base = sys.argv[-1] #export_path 真正输出的路径是在基本路径的基础上加上版本号,默认是version=1
export_path = os.path.join(
tf.compat.as_bytes(export_path_base),
tf.compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', export_path #官网解释:Builds the SavedModel protocol buffer and saves variables and assets.
builder = tf.saved_model.builder.SavedModelBuilder(export_path) # Build the signature_def_map.
# serialized_tf_example是上面提到的占位的输入,
#其当时定义为tf.placeholder(tf.string, name='tf_example') #tf.saved_model.utils.build_tensor_info 的作用是构建一个TensorInfo proto
#输入参数是张量的名称,类型,大小,这里是string,想应该是名称吧,毕竟
#代码还没全部看完,先暂时这么猜测。输出是,基于提供参数的a tensor protocol
# buffer
classification_inputs = tf.saved_model.utils.build_tensor_info(
serialized_tf_example) #函数功能介绍同上,这里不同的是输入参数是prediction_classes,
#其定义,prediction_classes = table.lookup(tf.to_int64(indices)),是一个查找表
#为查找表构建a tensor protocol buffer
classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
prediction_classes) #函数功能介绍同上,这里不同的是输入参数是values,
#其定义,values, indices = tf.nn.top_k(y, 10),是返回的预测值
#为预测值构建a tensor protocol buffer
classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values) #然后,继续看,下面那么多行都是一个语句,一个个结构慢慢解析
#下面可以直观地看到有三个参数,分别是inputs ,ouputs和method_name
#inputs ,是一个字典,其key是tensorflow serving 固定定义的接口,
#为: tf.saved_model.signature_constants.CLASSIFY_INPUTS,value的话
#就是之前build的a tensor protocol buffer 之 classification_inputs
#同样的,output 和method_name 也是一个意思,好吧,这部分就
#了解完啦。
classification_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={
tf.saved_model.signature_constants.CLASSIFY_INPUTS:
classification_inputs
},
outputs={
tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
classification_outputs_classes,
tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
classification_outputs_scores
},
method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)) #这两句话都和上面一样,都是构建a tensor protocol buffer
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y) 这个和上面很多行的classification_signature,一样的
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'images': tensor_info_x},
outputs={'scores': tensor_info_y},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) #这个不一样了,tf.group的官网解释挺简洁的
#Create an op that groups multiple operations.
#When this op finishes, all ops in input have finished. This op has no output.
#Returns:An Operation that executes all its inputs.
#我们看下另一个tf.tables_initializer():
#Returns:An Op that initializes all tables. Note that if there are not tables the returned Op is a NoOp
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') #下面是重点啦,怎么看出来的?因为上面都是定义什么的,下面是最后的操作啦
#就一个函数:builder.add_meta_graph_and_variables,
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'predict_images':
prediction_signature,
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
classification_signature,
},
legacy_init_op=legacy_init_op) builder.save()
print 'Done exporting!'
这里要从 tf.saved_model.builder.SavedModelBuilder 创建build开始,下面是看官网的,
可以直接参考:https://www.tensorflow.org/api_docs/python/tf/saved_model/builder/SavedModelBuilder

创建builder的是class SaveModelBuilder的功能是用来创建SaverModel

protocol buffer 并保存变量和资源,SaverModelBuilder类提供了创建

SaverModel protocol buffer 的函数方法

tensorflow serving 之minist_saved_model.py解读的更多相关文章

  1. tensorflow serving

    1.安装tensorflow serving 1.1确保当前环境已经安装并可运行tensorflow 从github上下载源码 git clone --recurse-submodules https ...

  2. tensorflow serving 中 No module named tensorflow_serving.apis,找不到predict_pb2问题

    最近在学习tensorflow serving,但是运行官网例子,不使用bazel时,发现运行mnist_client.py的时候出错, 在api文件中也没找到predict_pb2,因此,后面在网上 ...

  3. Tensorflow Serving 模型部署和服务

    http://blog.csdn.net/wangjian1204/article/details/68928656 本文转载自:https://zhuanlan.zhihu.com/p/233614 ...

  4. tensorflow serving 编写配置文件platform_config_file的方法

    1.安装grpc gRPC 的安装: $ pip install grpcio 安装 ProtoBuf 相关的 python 依赖库: $ pip install protobuf 安装 python ...

  5. Tensorflow Serving介绍及部署安装

    TensorFlow Serving 是一个用于机器学习模型 serving 的高性能开源库.它可以将训练好的机器学习模型部署到线上,使用 gRPC 作为接口接受外部调用.更加让人眼前一亮的是,它支持 ...

  6. 如何用 tensorflow serving 部署服务

    第一步,读一读这篇博客 https://www.jb51.net/article/138932.htm (浅谈Tensorflow模型的保存与恢复加载) 第二步: 参考博客: https://blog ...

  7. Tensorflow serving的编译

    Tensorflow serving提供了部署tensorflow生成的模型给线上服务的方法,包括模型的export,load等等. 安装参考这个 https://github.com/tensorf ...

  8. 谷歌发布 TensorFlow Serving

    TensorFlow服务是一个灵活的,高性能的机器学习模型的服务系统,专为生产环境而设计. TensorFlow服务可以轻松部署新的算法和实验,同时保持相同的服务器体系结构和API. TensorFl ...

  9. 学习笔记TF067:TensorFlow Serving、Flod、计算加速,机器学习评测体系,公开数据集

    TensorFlow Serving https://tensorflow.github.io/serving/ . 生产环境灵活.高性能机器学习模型服务系统.适合基于实际数据大规模运行,产生多个模型 ...

随机推荐

  1. WebView加载失败或网络异常时,替换WebView的错误界面;

    WebView在加载失败时会显示一个失败原因的界面,各个手机显示的界面还都不一样,部分手机还会把Url显示出来:我们要做的就是统一加载失败的界面: 大概思路:在WebView这个控件上面再覆盖一个Vi ...

  2. Hive快捷查询:不启用Mapreduce job启用Fetch task

    启用MapReduce Job是会消耗系统开销的.对于这个问题,从Hive0.10.0版本开始,对于简单的不需要聚合的类似SELECT <col> from <table> L ...

  3. 使用Redis数据库(String类型)

    一 String类型 首先使用启动服务器进程 : redis-server.exe 1. Set 设置Key对应的值为String 类型的value. 例子:向 Redis数据库中插入一条数据类型为S ...

  4. [python] 初学python,打卡签到

    自学python第一周,学了变量和简单的条件判断. 附上猜数游戏代码 #Author:shijt trueAge=40 count=0 while count<3: guessAge=int(i ...

  5. vim more

      启用鼠标 :set mouse=a 跳转到下一函数 下一个函数开头 ]] 当前函数末尾/下一个函数的末尾 ][ 当前函数开头/上一个函数的开头 [[ 选项可以按任何顺序生效,可以放在文件名前或后边 ...

  6. QTimer的一些注意事项和探索

    注意事项: 1.QTimer's accuracy depends on the underlying operating system and hardware.Most platforms sup ...

  7. SLD Related Gateway Serivces Unavaliable

    SAP NW 7.4 default switched on the ACL (access control list) in gateway service, so only local acces ...

  8. hive 动态分区实现 (hive-1.1.0)

    笔者使用的hive版本是hive-1.1.0 hive-1.1.0动态分区的默认实现是只有map没有reduce,通过执行计划就可以看出来.(执行计划如下) insert overwrite tabl ...

  9. 3. java.lang.UnsupportedClassVersionError: javax/annotation/ManagedBean : Unsupported major.minor version 51.0

    问题描述:

  10. antd-mobile使用报错

    在第一次使用时,按照官网的进行配置,完了报错找不到antd-mobile下面的css 解决方法来源于 :https://github.com/ant-design/ant-design-mobile/ ...