1. tf.train.Saver()

  • tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
  • TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。
  • 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()

TensorFlow的保存和恢复分为两种:

  • 保存和恢复变量
  • 保存和恢复模型

saver.save()保存模型

#举例:
保存一个训练好的手写数据集识别模型
保存在当前路径的net文件夹中

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) #每个批次100张照片
batch_size = 100
#计算一个需要多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)
#代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
for epoch in range(11):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print('Iter = ' + str(epoch) +', Testing Accuracy = ' + str(acc))
#保存模型
saver.save(sess, 'net/my_net.ckpt')
#保存路径中的文件为:
checkpoint:保存当前网络状态的文件
my_net.ckpt.data-00000-of-00001
my_net.ckpt.index
my_net.ckpt.meta:保存Graph结构的文件

#关于函数saver.save(),常用的参数就是前三个:
save(
sess, # 必需参数,Session对象
save_path, # 必需参数,存储路径
global_step=None, # 可以是Tensor, Tensor name, 整型数
latest_filename=None, # 协议缓冲文件名,默认为'checkpoint',不用管
meta_graph_suffix='meta', # 图文件的后缀,默认为'.meta',不用管
write_meta_graph=True, # 是否保存Graph
write_state=True, # 建议选择默认值True
strip_default_attrs=False # 是否跳过具有默认值的节点

saver.restore()加载已经训练好的模型

#举例:
通过加载刚才保存的训练好的手写数据集识别模型进行手写数据集的识别

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) init = tf.global_variables_initializer() correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess:
sess.run(init)
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
#执行结果:

0.098
0.9178 #直接得到的准确率相当低,通过加载训练好的模型,识别准确率大大提升。

2. 下载google图像识别网络inception-v3并查看结构

模型背景:
  Inception(v3) 模型是Google 训练好的最新一个图像识别模型,我们可以利用它来对我们的图像进行识别。

下载地址:
  https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip

文件描述:

  • classify_image_graph_def.pb 文件就是训练好的Inception-v3模型。
  • imagenet_synset_to_human_label_map.txt是类别文件,包含人类标签和uid之间的映射的文件。
  • imagenet_2012_challenge_label_map_proto.pbtxt是包含类号和uid之间的映射的文件。

代码实现

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
import tarfile
import requests #inception模型下载地址
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' #inception模型存放地址
inception_pretrain_model_dir = 'inception_model'
if not os.path.exists(inception_pretrain_model_dir):
os.makedirs(inception_pretrain_model_dir)
#获取文件名,以及文件路径
filename = inception_pretrain_model_url.split('/')[-1]
filepath = os.path.join(inception_pretrain_model_dir, filename) #下载模型
if not os.path.exists(filepath):
print('download: ', filename)
r = requests.get(inception_pretrain_model_url, stream=True)
with open(filepath, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
print('finish: ', filename)
#解压文件
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir) #模型结构存放文件
log_dir = 'inception_log'
if not os.path.exists(log_dir):
os.makedirs(log_dir) #classify_image_graph_def.pb为google训练好的模型
inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
with tf.Session() as sess:
#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
#保存图的结构
writer = tf.summary.FileWriter(log_dir, sess.graph)
writer.close()
#在下载过程中,下的特别慢,不知道是网络原因还是什么
#程序总卡着不动
#所以我就手动下载压缩包并进行解压

下载结果

3. 使用inception-v3做各种图像的识别

#代码实现:

 import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt #这部分是对标签号和类别号文件进行一个预处理 class NodeLookup(object):
def __init__(self):
label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
#加载分类字符串n********对应分类名称的文件
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human={}
#一行一行读取数据
for line in proto_as_ascii_lines:
#去掉换行符
line = line.strip('\n')
#按照‘\t’进行分割
parsed_items = line.split('\t')
#获取分类编号
uid = parsed_items[0]
#获取分类名称
human_string = parsed_items[1]
#保存编号字符串n********与分类名称的映射关系
uid_to_human[uid] = human_string #加载分类字符串n********对应分类编号1-1000的文件
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_uid = {}
for line in proto_as_ascii:
if line.startswith(' target_class:'):
#获取分类编号1-1000
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
#获取编号字符串nn********
target_class_string = line.split(': ')[1]
# 保存分类编号1-1000与编号字符串n********映射关系
node_id_to_uid[target_class] = target_class_string[1:-2]
# 建立分类编号1-1000对应分类名称的映射关系
node_id_to_name = {}
for key, val in node_id_to_uid.items():
#获取分类名称
name = uid_to_human[val]
# 建立分类编号1-1000到分类名称的映射关系
node_id_to_name[key] = name
return node_id_to_name
# 传入分类编号1-1000返回分类名称
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id] #创建一个图来存放google训练好的模型 with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='') with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
#遍历目录
for root, dirs, files in os.walk('images/'):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
predictions = np.squeeze(predictions)#把结果转为1维数据 #打印图片路径及名称
image_path = os.path.join(root, file)
print(image_path) # 显示图片
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show() #排序
top_k = predictions.argsort()[-5:][::-1]
node_lookup = NodeLookup()
for node_id in top_k:
# 获取分类名称
human_string = node_lookup.id_to_string(node_id)
# 获取该分类的置信度
score = predictions[node_id]
print('%s(score = %.5f)' % (human_string, score))
print()

#执行结果:

images/1.jpg
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265)
badger(score = 0.00260)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205)
brown bear, bruin, Ursus arctos(score = 0.00102)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099) images/2.jpg
French bulldog(score = 0.94474)
bull mastiff(score = 0.00559)
pug, pug-dog(score = 0.00352)
Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165)
boxer(score = 0.00116) images/3.jpg
zebra(score = 0.94011)
tiger, Panthera tigris(score = 0.00080)
pencil box, pencil case(score = 0.00066)
hartebeest(score = 0.00059)
tiger cat(score = 0.00042) images/4.jpg
hare(score = 0.87019)
wood rabbit, cottontail, cottontail rabbit(score = 0.04802)
Angora, Angora rabbit(score = 0.00612)
wallaby, brush kangaroo(score = 0.00181)
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056) images/5.jpg
fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047)
marmot(score = 0.00265)
mongoose(score = 0.00217)
weasel(score = 0.00201)
mink(score = 0.00199)

机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用的更多相关文章

  1. TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...

  2. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  3. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

  4. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  5. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

  6. tf.train.Saver()

    1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删 ...

  7. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  8. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  9. 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...

随机推荐

  1. Scrollview包裹布局问题。

    输入框获取焦点,键盘弹出,背景图片上移: https://blog.csdn.net/wljian1/article/details/79962802 android:scrollbarThumbVe ...

  2. cdnbest如何在站点里开启强制缓存

    在站点设置中如下图设置: 强制缓存有两种方式,一种是文件类型,一种是url方式

  3. Python设计模式 - UML - 部署图(Deployment Diagram)

    简介 部署图也称配置图,用来显示系统中硬件和软件的物理架构.从中可以了解到软件和硬件组件之间的物理拓扑.连接关系以及处理节点的分布情况. 部署图建模步骤 - 找出需要进行部署的各类节点,如网络硬件设备 ...

  4. PHP安装linux

    PHP 7.1.26 安装 1. 下载安装包 例php-7.1.26.tar.gz 下载地址:http://cn2.php.net/downloads.php 2. 放入/root 3. 切换目录 # ...

  5. java_16Arrays类

    1sort():对数组进行升序排列 public static void main(String[] args) { int[] arr= {2,43,6,7}; Arrays.sort(arr); ...

  6. 文件上传的UI自动化

    from pywinauto.application import Application import win32gui handle = win32gui.FindWindow("#32 ...

  7. spring入门--spring入门案例

    spring是一个框架,这个框架可以干很多很多的事情.感觉特别吊.但是,对于初学者来说,很难理解spring到底是干什么的.我刚开始的时候也不懂,后来就跟着敲,在后来虽然懂了,但是依然说不明白它到底是 ...

  8. [字符串]TrBBnsformBBtion

    TrBBnsformBBtion Let us consider the following operations on a string consisting of A and B: Select ...

  9. ad_imh

    1. stc15    ad_1.7  9600 1t #include < #include <absacc.h> #include <stdio.h> #includ ...

  10. 黑马java课程2222

    课程叫做27天学通java零基础 java 安装: 必须装jdk-7u72-windows-i586.exe 注意必须安装32位的就是i586这个.因为x64的不向下兼容.会有意向不到的bug 配置P ...