在TensorFlow中,保存模型与加载模型所用到的是tf.train.Saver()这个类。我们一般的想法就是,保存模型之后,在另外的文件中重新将模型导入,我可以利用模型中的operation和variable来测试新的数据。


什么是TensorFlow中的模型

首先,我们先来理解一下TensorFlow里面的模型是什么。在保存模型后,一般会出现下面四个文件:

meta graph:保存了TensorFlow的graph。包括all variables,operations,collections等等。这个文件就是上面的.meta文件。

checkpoint files:二进制文件,保存了所有weights,biases,gradient and all the other variables的值。也就是上图中的.data-00000-of-00001和.index文件。.data文件包含了所有的训练变量。以前的TensorFlow版本是一个ckpt文件,现在就是这两个文件了。与此同时,Tensorflow还有一个名为checkpoint的文件,只保存最新检查点文件的记录,即最新的保存路径。


保存一个TensorFlow的模型

在TensorFlow中,如果想保存一个图(graph)或者所有的参数的值,那么就需要用到tf.train.Saver()这个类。

import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
saver.save(sess, 'my_test_model')

  

上面这段代码最后一句就是保存模型,第二个参数是一个路径(包含模型的名字)。当然还有其他的形参,我们接下来讲:
global_step:给一个数字,用于保存文件时tensorflow帮你命名。主要是说明了迭代多次后保存了。
write_meta_graph:bool型,说明要不要把TensorFlow的图保存下来。
关于save函数更多的说明请参考:
https://www.tensorflow.org/api_docs/python/tf/train/Saver#save

例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model') # This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

  


导入一个训练好的模型

前门讲了如何保存一个模型,现在要把模型导出来用了。

训练好的模型,.meta文件中已经保存了整个graph,我们无需重建,只要导入.meta文件即可。

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')#这个函数就是讲graph导出来

  

下面用一个例子来说明一下,直接上完整代码:

第一个文件,训练模型并保存模型:

#定义模型
X = tf.placeholder(tf.float32,shape = [None,x_dim],name = 'X')
Y = tf.placeholder(tf.float32,shape = [None,1], name = 'Y')
W = tf.Variable(tf.random_normal([x_dim,1]),name='weight')
b = tf.Variable(tf.random_normal([1]),name='bias')
hypothesis = tf.sigmoid(tf.matmul(X,W)+b)
cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1-Y)*tf.log(1-hypothesis))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost) #假如想要保存hypothesis和cost,以便在保存模型后,重新导入模型时可以使用。
tf.add_to_collection('hypothesis',hypothesis)#必须有个名字,即第一个参数
tf.add_to_collection('cost',cost) mysaver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(50):
avg_cost, _ = sess.run([cost,train],feed_dict = {X:x_data,Y:y_data}) mysaver.save(sess, '../model/model_LR_test') #保存模型

  

第二个文件,加载模型,并利用训练好的模型预测:

sess = tf.Session()
#本来我们需要重新像上一个文件那样重新构建整个graph,但是利用下面这个语句就可以加载整个graph了,方便
new_saver = tf.train.import_meta_graph('../model/model_LR_test.meta')
new_saver.restore(sess,'../model/model_LR_test')#加载模型中各种变量的值,注意这里不用文件的后缀 #对应第一个文件的add_to_collection()函数
hyp = tf.get_collection('hypothesis')[0] #返回值是一个list,我们要的是第一个,这也说明可以有多个变量的名字一样。 graph = tf.get_default_graph()
X = graph.get_operation_by_name('X').outputs[0]#为了将placeholder加载出来 pred = sess.run(hyp,feed_dict = {X:x_valid})
print('auc:',auc(y_valid,pred))

是这样的,使用TensorFlow构建模型的时候,如果一些operation想要在加载模型时用到。那么需要使用add_to_collection()函数来将operation存起来。然后再加载模型后可以调用。当然tensorflow无论怎样都需要给每个东西一个名字(string型),只有通过名字才可以找到对应的operation。

TensorFlow 模型保存和导入、加载的更多相关文章

  1. tensorflow 模型保存后的加载路径问题

    import tensorflow as tf #保存模型 saver = tf.train.Saver() saver.save(sess, "e://code//python//test ...

  2. Tensorflow 模型持久化saver及加载图结构

    主要内容: 1. 直接保存,加载模型; (可以指定加载,保存的var_list) 2. 加载,保存指定变量的模型 3. slim加载模型使用 4. 加载模型图结构和参数等 tensorflow 恢复部 ...

  3. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  4. Tensorflow模型保存与加载

    在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...

  5. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  6. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  7. TensorFlow模型保存和提取方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...

  8. TensorFlow 模型保存/载入

    我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.jobl ...

  9. Unity3d-WWW实现图片资源显示以及保存和本地加载

    本文固定连接:http://blog.csdn.net/u013108312/article/details/52712844 WWW实现图片资源显示以及保存和本地加载 using UnityEngi ...

随机推荐

  1. Crash使用参考

    整理自man 8 crash 1.简介 Crash工具可以用来分析一个正在运行的内核,也可以用来分析一个内核的crash dump文件,这里说的是内核代码异常产生的crash dump文件,不是应用层 ...

  2. <Effective C++>读书摘要--Introduction

    Introduction 1.Learning the fundamentals of a programming language is one thing; learning how to des ...

  3. 微信小程序项目笔记以及openId体验版获取问题

    公司一直说要搞小程序,说了几个月,最近才算落地,一个很小的项目,就结果来讲,勉强让自己窥得小程序门径. 下面总结一下,为了弄好小程序,所学到的知识,以及项目中遇到的问题以及解决的办法.纯属个人见解. ...

  4. win7 安装 MongoDB 及简单操作

    下载地址 http://dl.mongodb.org/dl/win32/x86_64 这里用的版本是 mongodb-latest-signed.msi 同时下载 mongodb-compass 下载 ...

  5. Print之modile, level

    一般print打印的design都会引入module, level. xxxprint(module, level, arg,...)... 每个Module都可以有各自的bitmap,代表这个mod ...

  6. array to object

    array to object native js & ES6 https://stackoverflow.com/questions/4215737/convert-array-to-obj ...

  7. BZOJ 1211 树的计数(purfer序列)

    首先考虑无解的情况, 根据purfer序列,当dee[i]=0并且n!=1的时候,必然无解.否则为1. 且sum(dee[i]-1)!=n-2也必然无解. 剩下的使用排列组合即可推出公式.需要注意的是 ...

  8. BZOJ 1076 奖励关(状压期望DP)

    当前得分期望=(上一轮得分期望+这一轮得分)/m dp[i,j]:第i轮拿的物品方案为j的最优得分期望 如果我们正着去做,会出现从不合法状态(比如前i个根本无法达到j这种方案),所以从后向前推 如果当 ...

  9. ARC077C pushpush 递推

    ---题面--- 题解: 貌似一般c题都是递推... 观察到最后一个插入的数一定在第一个,倒数第二个插入的数一定在倒数第一个,倒数第三个插入的数一定在第2个,倒数第四个插入的数一定在倒数第2个…… O ...

  10. ContestHunter暑假欢乐赛 SRM 03

    你们也没人提醒我有atcoderQAQ... A题曼哈顿距离=欧拉距离就是在同一行或者同一列,记录下i,j出现过的次数,减去就行,直接map过. B题一开始拿衣服了,一直以为排序和不排序答案是一个样的 ...