TensorFlow 1.4利用Keras+Estimator API进行训练和预测
Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中的另一个高级API -- Estimator模型,然后就可以调用Dataset API进行对tfrecords进行操作用来训练/评估模型。而keras本身也用到了Estimator API并且提供了tf.keras.estimator.model_to_estimator函数将keras模型可以很方便的转换成Estimator模型,因此用Keras API搭建模型框架然后用Dataset API操作IO,然后用Estimator训练模型是一套比较方便高效的操作流程。
注:
tf.keras.estimator.model_to_estimator这个函数只在tf.keras下面有在原生的keras中是没有这个函数的。- Estimator训练的模型类型主要有
regressor和classifier两类,如果需要用自定义的模型类型,可以通过自定有model_fn来构建,具体操作可以查看这里 - Estimator模型可以通过
export_savedmodel()函数输出训练好的estimator模型,然后可以把模型创建服务接受输入数据并输出结果,这在大规模云端部署的时候会非常有用(具体操作流程可以看这里)。
1. 利用Keras搭建模型框架并转换成estimator模型
比如我们利用keras的ResNet50构建二分类模型:
import tensorflow as tf
import os
resnet = tf.keras.applications.resnet50
def my_model_fn():
base_model = resnet.ResNet50(include_top=True, # include fully layers or not
weights='imagenet', # pre-trained weights
input_shape=(224, 224, 3), # default input shape
classes=2)
base_model.summary()
optimizer = tf.keras.optimizers.RMSprop(lr=2e-3,
decay=0.9)
base_model.compile(optimizer=optimizer,
loss='categorical_crossentropy',
metrics=["accurary"])
# convert keras model to estimator model
model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), “train”)
est_model = tf.keras.estimator.model_to_estimator(base_model, model_dir=model_dir) # model save dir is 'train',
return est_model
注意:model_dir必须为全路径,使用相对路径的estimator在检索模型输入输出的时候可能会报错。
2. 利用Dateset API从tfrecords中读取数据并构建estimator模型的输入input_fn
比如tfrecords的图片和标签分别保存在“image/encoded”和“image/label”下:(如何写tfrecords可以参考这里)
def _tf_example_parser(record):
feature = {"image/encoded": tf.FixedLenFeature([], tf.string),
"image/class_id": tf.FixedLenFeature([], tf.int64)}
features = tf.parse_single_example(record, features=feature)
image = tf.decode_raw(features["image/encoded"], out_type=tf.uint8) # 写入tfrecords的时候是misc/cv2读取的ndarray
image = tf.cast(image, dtype=tf.float32)
image = tf.reshape(image, shape=(224, 224, 3)) # 如果输入图片不做resize,那么不同大小的图片是无法输入到同一个batch中的
label = tf.cast(features["image/class_id"], dtype=tf.int64)
return image, label
def input_fn(data_path, batch_size=64, is_training=True):
"""
Parse image and data in tfrecords file for training
Args:
data_path: a list or single tf records path
batch_size: size of images returned
is_training: is training stage or not
Returns:
image and labels batches of randomly shuffling tensors
"""
with K.name_scope("input_pipeline"):
if not isinstance(data_path, (tuple, list)):
data_path = [data_path]
dataset = tf.data.TFRecordDataset(data_path)
dataset = dataset.map(_tf_example_parser)
dataset = dataset.repeat(25) # num of epochs
dataset = dataset.batch(64) # batch size
if is_training:
dataset = dataset.shuffle(1000) # 对输入进行shuffle,buffer_size越大,内存占用越大,shuffle的时间也越长,因此可以在写tfrecords的时候就实现用乱序写入,这样的话这里就不需要用shuffle
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
# convert to onehot label
labels = tf.one_hot(labels, 2) # 二分类
# preprocess image: scale pixel values from 0-255 to 0-1
images = tf.image.convert_image_dtype(images, dtype=tf.float32) # 将图片像素从0-255转换成0-1,tf提供的图像操作大多需要0-1之间的float32类型
images /= 255.
images -= 0.5
images *= 2.
return dict({"input_1": images}), labels
3. 利用estimator API训练模型
def train(source_dir):
if tf.gfile.Exists(source_dir):
train_data_paths = tf.gfile.Glob(source_dir+"/train*tfrecord") # 所有train开头的tfrecords都用于模型训练
val_data_paths = tf.gfile.Glob(source_dir+"/val*tfrecord") # 所有val开头的tfrecords都用于模型评估
train_data_paths = val_data_paths
if not len(train_data_paths):
raise Exception("[Train Error]: unable to find train*.tfrecord file")
if not len(val_data_paths):
raise Exception("[Eval Error]: unable to find val*.tfrecord file")
else:
raise Exception("[Train Error]: unable to find input directory!")
est_model = my_model_fn()
train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(data_path=train_data_paths,
batch_size=_BATCH_SIZE,
is_training=True),
max_steps=300000)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(val_data_paths,
batch_size=_BATCH_SIZE,
is_training=False))
# train and evaluate model
tf.estimator.train_and_evaluate(estimator=est_model,
train_spec=train_spec,
eval_spec=eval_spec)
PS: 这里用lambda表示输入为函数,而非函数的返回值;也可以用partial函数进行包裹;而当没有输入变量的时候就可以直接用。
训练的时候,用途tensorboard监控train目录查看训练过程。
4. 利用estimator模型进行预测
由于estimator模型predict函数的输入与训练的时候一样为input_fn,但是此时直接从文件中读取,而非tfrecords,因此需要重新定义一个input_fn来用于predict。
def predict_input_fn(image_path):
images = misc.imread(image_path)
# preprocess image: scale pixel values from 0-255 to 0-1
images = tf.image.convert_image_dtype(images, dtype=tf.float32)
images /= 255.
images -= 0.5
images *= 2.
dataset = tf.data.Dataset.from_tensor_slices((images, ))
return dataset.batch(1).make_one_shot_iterator().get_next()
def predict(image_path):
est_model = my_model_fn()
result = est_model.predict(input_fn=lambda: predict_input_fn(image_path=image_path))
for r in result:
print(r)
参考:
Estimator:
- https://www.dlology.com/blog/an-easy-guide-to-build-new-tensorflow-datasets-and-estimator-with-keras-model/
- https://www.tensorflow.org/versions/master/programmers_guide/estimators
- https://www.tensorflow.org/extend/estimators
- https://medium.com/onfido-tech/higher-level-apis-in-tensorflow-67bfb602e6c0
- https://www.tensorflow.org/programmers_guide/saved_model
Dataset
- https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/programmers_guide/datasets.md
- https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
Tfredords:
http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.html
Keras:
https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html
TensorFlow 1.4利用Keras+Estimator API进行训练和预测的更多相关文章
- 手写数字识别——利用keras高层API快速搭建并优化网络模型
在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...
- Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测
3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算 ...
- tensorflow estimator API小栗子
TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...
- 【Python与机器学习】:利用Keras进行多类分类
多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...
- Python机器学习笔记:利用Keras进行分类预测
Keras是一个用于深度学习的Python库,它包含高效的数值库Theano和TensorFlow. 本文的目的是学习如何从csv中加载数据并使其可供Keras使用,如何用神经网络建立多类分类的数据进 ...
- TensorFlow数据读取方式:Dataset API
英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...
- 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型
人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...
- Tensorflow、Pytorch、Keras的多GPU使用
Tensorflow.Pytorch.Keras的多GPU的并行操作 方法一 :使用深度学习工具提供的 API指定 1.1 Tesorflow tensroflow指定GPU的多卡并行的时候,也是可以 ...
- CNN眼中的世界:利用Keras解释CNN的滤波器
转载自:https://keras-cn.readthedocs.io/en/latest/legacy/blog/cnn_see_world/ 文章信息 本文地址:http://blog.keras ...
随机推荐
- BZOJ3531[Sdoi2014]旅行——树链剖分+线段树
题目描述 S国有N个城市,编号从1到N.城市间用N-1条双向道路连接,满足从一个城市出发可以到达其它所有城市.每个城市信仰不同的宗教,如飞天面条神教.隐形独角兽教.绝地教都是常见的信仰.为了方便,我们 ...
- Flask请求request
Flask中的request是一个公共变量,需要导入request from flask import Flask,request 接收url中的参数 @app.route("/req&qu ...
- Python基础-字符串、集合类型、判断、深拷贝与浅拷贝、文件读写
字符串 1.定义三个变量: 2.交换两个变量值 1)引入第三个变量: 2)Python引入第三方变量: 3)不引入第三方变量: 3. isalpha 是否是汉字或字母 4.Isalnum 是否是汉字 ...
- SuperWebSocket与Cocos2dx通信时执行不了命令的问题
要修改WebSocketSession.cs 中的方法 string IWebSocketSession.GetAvailableSubProtocol(string protocol) { if ( ...
- 【模拟】[NOIP2014]螺旋矩阵[c++]
题目描述 一个n行n列的螺旋矩阵可由如下方法生成: 从矩阵的左上角(第1行第1列)出发,初始时向右移动:如果前方是未曾经过的格子,则继续前进,否则右转:重复上述操作直至经过矩阵中所有格子.根据经过顺序 ...
- mysql 创建用户命令-grant
我们在使用mysql的过程中,经常需要对用户授权(添加,修改,删除),在mysql当中有三种方式实现 分别是 INSERT USER表的方法.CREATE USER的方法.GRANT的方法.今天主要看 ...
- Javascript数组(一)排序
一.简介首先,我们来看一下JS中sort()和reverse()这两个函数的函数吧reverse();这个函数是用来进行倒序,这个没有什么可说的,所谓倒序就是大的在前面,小的在后面. 比如: var ...
- redis启动出错Creating Server TCP listening socket 127.0.0.1:6379: bind: No error(转)
redis启动出错Creating Server TCP listening socket 127.0.0.1:6379: bind: No error windows下安装Redis第一次启动报 ...
- python之使用set对列表去重,并保持列表原来顺序(转)
https://www.cnblogs.com/laowangball/p/8424432.html #原始方法,但是会打乱顺序 mylist = [1,2,2,2,2,3,3,3,4,4,4,4]m ...
- shell编程学习笔记(九):Shell中的case条件判断
除了可以使用if条件判断,还可以使用case 以下蓝色字体部分为Linux命令,红色字体的内容为输出的内容: # cd /opt/scripts # vim script08.sh 开始编写scrip ...