1. tf.train.Saver()

  • tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
  • TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。
  • 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()

TensorFlow的保存和恢复分为两种:

  • 保存和恢复变量
  • 保存和恢复模型

saver.save()保存模型

#举例:
保存一个训练好的手写数据集识别模型
保存在当前路径的net文件夹中

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) #每个批次100张照片
batch_size = 100
#计算一个需要多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)
#代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print('Iter = ' + str(epoch) +', Testing Accuracy = ' + str(acc))
#保存模型
saver.save(sess, 'net/my_net.ckpt')
#保存路径中的文件为:
checkpoint:保存当前网络状态的文件
my_net.ckpt.data-00000-of-00001
my_net.ckpt.index
my_net.ckpt.meta:保存Graph结构的文件

#关于函数saver.save(),常用的参数就是前三个:
save(
sess, # 必需参数,Session对象
save_path, # 必需参数,存储路径
global_step=None, # 可以是Tensor, Tensor name, 整型数
latest_filename=None, # 协议缓冲文件名,默认为'checkpoint',不用管
meta_graph_suffix='meta', # 图文件的后缀,默认为'.meta',不用管
write_meta_graph=True, # 是否保存Graph
write_state=True, # 建议选择默认值True
strip_default_attrs=False # 是否跳过具有默认值的节点

saver.restore()加载已经训练好的模型

#举例:
通过加载刚才保存的训练好的手写数据集识别模型进行手写数据集的识别

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) init = tf.global_variables_initializer() correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
#执行结果:

0.098
0.9178 #直接得到的准确率相当低,通过加载训练好的模型,识别准确率大大提升。

2. 下载google图像识别网络inception-v3并查看结构

模型背景:
  Inception(v3) 模型是Google 训练好的最新一个图像识别模型,我们可以利用它来对我们的图像进行识别。

下载地址:
  https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip

文件描述:

  • classify_image_graph_def.pb 文件就是训练好的Inception-v3模型。
  • imagenet_synset_to_human_label_map.txt是类别文件,包含人类标签和uid之间的映射的文件。
  • imagenet_2012_challenge_label_map_proto.pbtxt是包含类号和uid之间的映射的文件。

代码实现

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
import tarfile
import requests #inception模型下载地址
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' #inception模型存放地址
inception_pretrain_model_dir = 'inception_model'
if not os.path.exists(inception_pretrain_model_dir):
os.makedirs(inception_pretrain_model_dir)
#获取文件名,以及文件路径
filename = inception_pretrain_model_url.split('/')[-1]
filepath = os.path.join(inception_pretrain_model_dir, filename) #下载模型
if not os.path.exists(filepath):
print('download: ', filename)
r = requests.get(inception_pretrain_model_url, stream=True)
with open(filepath, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
print('finish: ', filename)
#解压文件
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir) #模型结构存放文件
log_dir = 'inception_log'
if not os.path.exists(log_dir):
os.makedirs(log_dir) #classify_image_graph_def.pb为google训练好的模型
inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
with tf.Session() as sess:
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
#保存图的结构
writer = tf.summary.FileWriter(log_dir, sess.graph)
writer.close()
#在下载过程中,下的特别慢,不知道是网络原因还是什么
#程序总卡着不动
#所以我就手动下载压缩包并进行解压

下载结果

3. 使用inception-v3做各种图像的识别

#代码实现:

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt #这部分是对标签号和类别号文件进行一个预处理 class NodeLookup(object):
def __init__(self):
label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
#加载分类字符串n********对应分类名称的文件
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human={}
#一行一行读取数据
for line in proto_as_ascii_lines:
#去掉换行符
line = line.strip('\n')
#按照‘\t’进行分割
parsed_items = line.split('\t')
#获取分类编号
uid = parsed_items[0]
#获取分类名称
human_string = parsed_items[1]
#保存编号字符串n********与分类名称的映射关系
uid_to_human[uid] = human_string #加载分类字符串n********对应分类编号1-1000的文件
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_uid = {}
for line in proto_as_ascii:
if line.startswith(' target_class:'):
#获取分类编号1-1000
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
#获取编号字符串nn********
target_class_string = line.split(': ')[1]
# 保存分类编号1-1000与编号字符串n********映射关系
node_id_to_uid[target_class] = target_class_string[1:-2]
# 建立分类编号1-1000对应分类名称的映射关系
node_id_to_name = {}
for key, val in node_id_to_uid.items():
#获取分类名称
name = uid_to_human[val]
# 建立分类编号1-1000到分类名称的映射关系
node_id_to_name[key] = name
return node_id_to_name
# 传入分类编号1-1000返回分类名称
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id] #创建一个图来存放google训练好的模型 with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='') with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
#遍历目录
for root, dirs, files in os.walk('images/'):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
predictions = np.squeeze(predictions)#把结果转为1维数据 #打印图片路径及名称
image_path = os.path.join(root, file)
print(image_path) # 显示图片
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show() #排序
top_k = predictions.argsort()[-5:][::-1]
node_lookup = NodeLookup()
for node_id in top_k:
# 获取分类名称
human_string = node_lookup.id_to_string(node_id)
# 获取该分类的置信度
score = predictions[node_id]
print('%s(score = %.5f)' % (human_string, score))
print()

#执行结果:

images/1.jpg
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265)
badger(score = 0.00260)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205)
brown bear, bruin, Ursus arctos(score = 0.00102)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099) images/2.jpg
French bulldog(score = 0.94474)
bull mastiff(score = 0.00559)
pug, pug-dog(score = 0.00352)
Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165)
boxer(score = 0.00116) images/3.jpg
zebra(score = 0.94011)
tiger, Panthera tigris(score = 0.00080)
pencil box, pencil case(score = 0.00066)
hartebeest(score = 0.00059)
tiger cat(score = 0.00042) images/4.jpg
hare(score = 0.87019)
wood rabbit, cottontail, cottontail rabbit(score = 0.04802)
Angora, Angora rabbit(score = 0.00612)
wallaby, brush kangaroo(score = 0.00181)
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056) images/5.jpg
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047)
marmot(score = 0.00265)
mongoose(score = 0.00217)
weasel(score = 0.00201)
mink(score = 0.00199)

机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用的更多相关文章

  1. TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...

  2. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  3. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

  4. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  5. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

  6. tf.train.Saver()

    1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删 ...

  7. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  8. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  9. 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...

随机推荐

  1. javaScript+html5实现图片拖拽

    源码: <!DOCTYPE html><html><head> <meta charset="utf-8"/> <title& ...

  2. vue 关键词模糊查询

    页面html,绑定的列表数据为datas,关键词为 select_words,如下图 其中d.accounts和d.roleName是需要进行搜索的字段,也可以进行大小写都可以

  3. win7 win10下80端口被System进程占用的解决方法

    用如下方法可以解决System进程占用80端口的问题:打开RegEdit:开始-运行-输入regedit-调出注册表找到HKEY_LOCAL_MACHINE\SYSTEM\CurrentControl ...

  4. 9. Palindrome Number (JAVA)

    Determine whether an integer is a palindrome. An integer is a palindrome when it reads the same back ...

  5. 20175234 2018-2019-2 《Java程序设计》第三周学习总结

    20175234 2018-2019-2 <Java程序设计>第三周学习总结 教材学习内容重难点总结 关于驼峰式的认识 为了增加程序的可读性,除了在代码之间增加注释之外,程序员大都把代码中 ...

  6. C#,如何程序使用正则表达式如何使用匹配的位置的结果修改匹配到的值

    程序代码使用正则表达式如何修改匹配到的值: 代码一: using System; using System.Text.RegularExpressions; public class Example ...

  7. zabbix自定义监控

    有的时候zabbix提供的监控项目,不能满足我们生产环境下的监控需求,此时我们就要按照zabbix的规范自定义监控项目,达到监控的目的 zabbix_get:模拟zabbix_server和agent ...

  8. leetcode3:无重复字符的最长子串

    给定一个字符串,找出不含有重复字符的最长子串的长度. 示例: 给定 "abcabcbb" ,没有重复字符的最长子串是 "abc" ,那么长度就是3. 给定 &q ...

  9. 粒子动画——Pygame

    你是否也想做出下图这么漂亮的动态效果?想的话就跟着我一起做吧=.= 工具: Python--Pygame 仔细观察上图,你能发现哪些机制呢?再在下面对比一下是否跟你想的一样. 运行机制: 1.随机方向 ...

  10. Java 日志体系

    Java 日志体系 <java 日志和 SLF4J 随想>:http://ifeve.com/java-slf4j-think/ 一.常用的日志组件 名称 jar 描述 log4j log ...