机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用
1. tf.train.Saver()
- tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
- TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。
- 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()
TensorFlow的保存和恢复分为两种:
- 保存和恢复变量
- 保存和恢复模型
saver.save()保存模型
#举例:
保存一个训练好的手写数据集识别模型
保存在当前路径的net文件夹中
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) #每个批次100张照片
batch_size = 100
#计算一个需要多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)
#代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print('Iter = ' + str(epoch) +', Testing Accuracy = ' + str(acc))
#保存模型
saver.save(sess, 'net/my_net.ckpt')
#保存路径中的文件为:
checkpoint:保存当前网络状态的文件
my_net.ckpt.data-00000-of-00001
my_net.ckpt.index
my_net.ckpt.meta:保存Graph结构的文件
#关于函数saver.save(),常用的参数就是前三个:
save(
sess, # 必需参数,Session对象
save_path, # 必需参数,存储路径
global_step=None, # 可以是Tensor, Tensor name, 整型数
latest_filename=None, # 协议缓冲文件名,默认为'checkpoint',不用管
meta_graph_suffix='meta', # 图文件的后缀,默认为'.meta',不用管
write_meta_graph=True, # 是否保存Graph
write_state=True, # 建议选择默认值True
strip_default_attrs=False # 是否跳过具有默认值的节点
saver.restore()加载已经训练好的模型
#举例:
通过加载刚才保存的训练好的手写数据集识别模型进行手写数据集的识别
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) init = tf.global_variables_initializer() correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
#执行结果: 0.098
0.9178 #直接得到的准确率相当低,通过加载训练好的模型,识别准确率大大提升。
2. 下载google图像识别网络inception-v3并查看结构
模型背景:
Inception(v3) 模型是Google 训练好的最新一个图像识别模型,我们可以利用它来对我们的图像进行识别。
下载地址:
https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
文件描述:
- classify_image_graph_def.pb 文件就是训练好的Inception-v3模型。
- imagenet_synset_to_human_label_map.txt是类别文件,包含人类标签和uid之间的映射的文件。
- imagenet_2012_challenge_label_map_proto.pbtxt是包含类号和uid之间的映射的文件。
代码实现
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
import tarfile
import requests #inception模型下载地址
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' #inception模型存放地址
inception_pretrain_model_dir = 'inception_model'
if not os.path.exists(inception_pretrain_model_dir):
os.makedirs(inception_pretrain_model_dir)
#获取文件名,以及文件路径
filename = inception_pretrain_model_url.split('/')[-1]
filepath = os.path.join(inception_pretrain_model_dir, filename) #下载模型
if not os.path.exists(filepath):
print('download: ', filename)
r = requests.get(inception_pretrain_model_url, stream=True)
with open(filepath, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
print('finish: ', filename)
#解压文件
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir) #模型结构存放文件
log_dir = 'inception_log'
if not os.path.exists(log_dir):
os.makedirs(log_dir) #classify_image_graph_def.pb为google训练好的模型
inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
with tf.Session() as sess:
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
#保存图的结构
writer = tf.summary.FileWriter(log_dir, sess.graph)
writer.close()
#在下载过程中,下的特别慢,不知道是网络原因还是什么
#程序总卡着不动
#所以我就手动下载压缩包并进行解压
下载结果
3. 使用inception-v3做各种图像的识别
#代码实现:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt #这部分是对标签号和类别号文件进行一个预处理 class NodeLookup(object):
def __init__(self):
label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
#加载分类字符串n********对应分类名称的文件
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human={}
#一行一行读取数据
for line in proto_as_ascii_lines:
#去掉换行符
line = line.strip('\n')
#按照‘\t’进行分割
parsed_items = line.split('\t')
#获取分类编号
uid = parsed_items[0]
#获取分类名称
human_string = parsed_items[1]
#保存编号字符串n********与分类名称的映射关系
uid_to_human[uid] = human_string #加载分类字符串n********对应分类编号1-1000的文件
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_uid = {}
for line in proto_as_ascii:
if line.startswith(' target_class:'):
#获取分类编号1-1000
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
#获取编号字符串nn********
target_class_string = line.split(': ')[1]
# 保存分类编号1-1000与编号字符串n********映射关系
node_id_to_uid[target_class] = target_class_string[1:-2]
# 建立分类编号1-1000对应分类名称的映射关系
node_id_to_name = {}
for key, val in node_id_to_uid.items():
#获取分类名称
name = uid_to_human[val]
# 建立分类编号1-1000到分类名称的映射关系
node_id_to_name[key] = name
return node_id_to_name
# 传入分类编号1-1000返回分类名称
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id] #创建一个图来存放google训练好的模型 with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='') with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
#遍历目录
for root, dirs, files in os.walk('images/'):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
predictions = np.squeeze(predictions)#把结果转为1维数据 #打印图片路径及名称
image_path = os.path.join(root, file)
print(image_path) # 显示图片
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show() #排序
top_k = predictions.argsort()[-5:][::-1]
node_lookup = NodeLookup()
for node_id in top_k:
# 获取分类名称
human_string = node_lookup.id_to_string(node_id)
# 获取该分类的置信度
score = predictions[node_id]
print('%s(score = %.5f)' % (human_string, score))
print()
#执行结果:
images/1.jpg
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265)
badger(score = 0.00260)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205)
brown bear, bruin, Ursus arctos(score = 0.00102)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099) images/2.jpg
French bulldog(score = 0.94474)
bull mastiff(score = 0.00559)
pug, pug-dog(score = 0.00352)
Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165)
boxer(score = 0.00116) images/3.jpg
zebra(score = 0.94011)
tiger, Panthera tigris(score = 0.00080)
pencil box, pencil case(score = 0.00066)
hartebeest(score = 0.00059)
tiger cat(score = 0.00042) images/4.jpg
hare(score = 0.87019)
wood rabbit, cottontail, cottontail rabbit(score = 0.04802)
Angora, Angora rabbit(score = 0.00612)
wallaby, brush kangaroo(score = 0.00181)
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056) images/5.jpg
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047)
marmot(score = 0.00265)
mongoose(score = 0.00217)
weasel(score = 0.00201)
mink(score = 0.00199)
机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用的更多相关文章
- TensorFlow:tf.train.Saver()模型保存与恢复
1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...
- tensorflow的tf.train.Saver()模型保存与恢复
将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...
- 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑
import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...
- TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model
TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...
- 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()
save = tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...
- tf.train.Saver()
1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删 ...
- tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)
tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...
- tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数
tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...
- 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数
原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...
随机推荐
- canvas绘制气泡
思路:使用Math.random()函数绘制是个不同位置,大小,颜色的圆形,然后设置定时器,前一个状态用一个与画布相同颜色的背景图片进行覆盖,改变圆形的位置,每次改变都是在这张空白的背景图片上面重新进 ...
- Kafka自带zookeeper报错INFO Got user-level KeeperException when processing xxx Error Path:/brokers Error:KeeperErrorCode = NodeExists for /brokers (org.apache.zookeeper.server.PrepRequestProcessor)
问题描述: 按照kafka官方文档的操作步骤,解压kafka压缩包后.依次启动zookeeper,和kafka服务 kafka服务启动后,查看到zookeeper日志里有以下异常 问题原因及解决办法: ...
- MySQL 详细学习笔记
Windows服务 -- 启动MySQL net start mysql -- 创建Windows服务 sc create mysql binPath= mysqld_bin_path(注意:等号与值 ...
- cisco PBR
access-list 2000 permit ip 10.11.50.0 0.0.0.255 anyaccess-list 2001 permit ip 10.11.50.0 0.0.0.255 1 ...
- 114. Flatten Binary Tree to Linked List 把二叉树变成链表
[抄题]: Given a binary tree, flatten it to a linked list in-place. For example, given the following tr ...
- pythone函数基础(7)第三方模块学习
一,time模块学习 import time # print(int(time.time()))#时间戳# res = time.strftime('%Y-%m-%d %H:%M:%S')#取当前格式 ...
- Xadmin添加,编辑,删除
Xadmin添加,编辑,删除 1.HTML 编辑和添加页面得内容相同,使用include将他们整合 {% include xxx.html %} 获取指定页面的所有内容 1.单独建个html存放编辑和 ...
- 解决安装xcode后git使用报错的问题
一.现象: htmlxdeMacBook-Pro:demo htmlx$ git status Agreeing to the Xcode/iOS license requires admin pri ...
- Eigen使用矩阵作为函数参数
1 使用矩阵作为函数参数介绍 文章来源Writing Functions Taking %Eigen Types as Parameters Eigen为了在函数中传递不同的类型使用了表达式模板技术. ...
- NC nc开发工具java虚拟机参数
-Dnc.exclude.modules=${FIELD_EX_MODULES} -Dnc.runMode=develop -Dnc.server.location=${FIELD_NC_HO ...