How to Train a Model into a pb File with TensorFlow, and How to Load the Trained Model Back
This post covers how to use TensorFlow to train a model and export it as a pb file, and then how to reload that pb file for inference.
First, the author's GitHub repo:
https://github.com/ppplinday/tensorflow-vgg16-train-and-test
The pitcute folder holds the images: 15 dogs and 15 cats, 30 in total (don't judge — it's just a toy dataset for practice). train is the training script and test is the testing script.
The steps are walked through one by one below.
!!! ps: The author is still a beginner, so a few (read: many) small mistakes are inevitable — corrections from more experienced readers are very welcome.
- train
First, the training script. The first step, naturally, is reading the images.
def read_img(path):
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder + '/*.jpg'):
            print('reading the image: %s' % (im))
            img = io.imread(im)
            img = transform.resize(img, (w, h, c))
            imgs.append(img)
            labels.append(idx)
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

data, label = read_img(path)
io.imread reads each image, transform.resize scales it to VGG's input size of (224, 224, 3), and each image and its folder index are appended to imgs and labels respectively.
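One detail worth noting: skimage's transform.resize not only resizes but also converts the image to floats in the [0, 1] range by default, so no separate normalization step is needed here. A quick check (with a hypothetical image path):

from skimage import io, transform
img = io.imread('picture/dog/dog1.jpg')  # hypothetical path
resized = transform.resize(img, (224, 224, 3))
print(resized.dtype, resized.min(), resized.max())  # float64, values within [0, 1]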
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
This shuffles the image order: np.arange generates the index array 0..num_example-1, np.random.shuffle permutes it, and indexing with it writes the shuffled order back into data and label.
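The same shuffle can be written more compactly with np.random.permutation, which builds and shuffles the index array in one call — an equivalent sketch:

import numpy as np
arr = np.random.permutation(num_example)  # shuffled indices 0..num_example-1
data, label = data[arr], label[arr]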
ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
x_val = data[s:]
y_val = label[s:]
80% of the data is used for training and the remaining 20% for validation (even though there are only 30 images in total...).
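If scikit-learn happens to be available, the shuffle and the split can be collapsed into a single call — an alternative sketch, not what the original script uses:

from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(data, label, test_size=0.2, random_state=42)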
def build_network(height, width, channel):
    x = tf.placeholder(tf.float32, shape=[None, height, width, channel], name='input')
    y = tf.placeholder(tf.int64, shape=[None, 2], name='labels_placeholder')
Next, build the VGG model itself. This step is not hard, but it pays to give every layer a name. The x and y above are the input placeholder and the label placeholder.
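The names matter because, once the graph is frozen into a pb file, tensors can only be looked up by name — no Python variables survive the export. The test script later does exactly this (':0' selects the first output of the named op):

input_x = sess.graph.get_tensor_by_name("input:0")        # the placeholder named 'input'
out_softmax = sess.graph.get_tensor_by_name("softmax:0")  # the node named 'softmax'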
finaloutput = tf.nn.softmax(output_fc8, name="softmax")
# softmax_cross_entropy_with_logits applies softmax internally, so it must be
# fed the raw fc8 output rather than the already-softmaxed finaloutput
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_fc8, labels=y))
optimize = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)
prediction_labels = tf.argmax(finaloutput, axis=1, name="output")
# convert the one-hot labels to class indices so the comparison is well-defined
read_labels = tf.argmax(y, axis=1)
correct_prediction = tf.equal(prediction_labels, read_labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
return dict(
    x=x,
    y=y,
    optimize=optimize,
    correct_prediction=correct_prediction,
    correct_times_in_batch=correct_times_in_batch,
    cost=cost,
)
At the end of build_network comes the loss computation. finaloutput is the softmaxed output used at inference time, cost is the cross-entropy loss, and optimize defines how training updates the weights (Adam with a 1e-4 learning rate). Two fixes over the original code: the loss is fed the raw logits output_fc8 (softmax_cross_entropy_with_logits applies softmax itself), and the one-hot labels are argmaxed before being compared with the predicted labels. Also note the return value — the dict of handles is what train_network uses later.
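To see what the loss computes, take a single example with fc8 output [2.0, -1.0] and true label [1, 0]: the loss is -log(softmax(logits)[0]). A minimal NumPy check with these toy values:

import numpy as np
logits = np.array([2.0, -1.0])
probs = np.exp(logits) / np.exp(logits).sum()  # softmax
loss = -np.log(probs[0])                       # cross-entropy against label [1, 0]
print(probs, loss)  # probs ~ [0.953, 0.047], loss ~ 0.049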
Next comes the training loop.
def train_network(graph, batch_size, num_epochs, pb_file_path):
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            for i in range(12):
                sess.run([graph['optimize']], feed_dict={
                    graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                    graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                })
That is really all the training code: loop over the epochs and feed the 12 training images one at a time, converting each label into a one-hot vector on the fly. The code that follows it (shown in the full listing below) checks the accuracy every few epochs.
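A condensed sketch of that evaluation block (the full version is in the complete listing below): every epoch_delta epochs it counts correct predictions over the training images and prints the running accuracy:

if epoch_index % epoch_delta == 0:
    correct = 0
    for i in range(12):  # the 12 training images, fed one at a time
        correct += sess.run(graph['correct_times_in_batch'], feed_dict={
            graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
            graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
        })
    print('epoch %d, train accuracy: %.2f%%' % (epoch_index, 100.0 * correct / 12))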
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())
These are the important lines: convert_variables_to_constants freezes the trained variables into the graph, keeping everything needed to compute the node named "output", and the serialized result is written out as the pb file. After the run finishes, the pb file shows up in the working directory.
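Once the file exists, a quick sanity check is to parse it back and list the node names — a minimal sketch, assuming the file was saved as vggs.pb; 'input', 'softmax' and 'output' should all appear in the list:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('vggs.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
print([node.name for node in graph_def.node])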
- test
def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
This opens the pb file and imports the frozen graph into the current default graph.
img = io.imread(jpg_path)
img = transform.resize(img, (224, 224, 3))
img_out_softmax = sess.run(out_softmax, feed_dict={input_x: np.reshape(img, [-1, 224, 224, 3])})
Read the image file, resize it to the network's input size, and feed it into the graph's input placeholder; img_out_softmax then holds the model's softmax output.
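The softmax output is a [1, 2] array of class probabilities. Mapping it back to a class name just mirrors the folder order that read_img saw during training — the order below (cat = 0, dog = 1) is an assumption, since it depends on what os.listdir returns:

prediction = np.argmax(img_out_softmax, axis=1)[0]
classes = ['cat', 'dog']  # assumed folder enumeration order
print('predicted: %s (p=%.3f)' % (classes[prediction], img_out_softmax[0][prediction]))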
That is roughly the whole pipeline. This was a practice project, so there are bound to be plenty of small mistakes — corrections are welcome, haha!
Finally, the complete train and test scripts:
train
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
import os
import glob
from skimage import io, transform
from tensorflow.python.framework import graph_util
import collections

path = '/home/zhoupeilin/vgg16/picture/'
w = 224
h = 224
c = 3

def read_img(path):
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder + '/*.jpg'):
            print('reading the image: %s' % (im))
            img = io.imread(im)
            img = transform.resize(img, (w, h, c))
            imgs.append(img)
            labels.append(idx)
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

data, label = read_img(path)

num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]

ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
x_val = data[s:]
y_val = label[s:]

def build_network(height, width, channel):
    x = tf.placeholder(tf.float32, shape=[None, height, width, channel], name='input')
    y = tf.placeholder(tf.int64, shape=[None, 2], name='labels_placeholder')

    def weight_variable(shape, name="weights"):
        initial = tf.truncated_normal(shape, dtype=tf.float32, stddev=0.1)
        return tf.Variable(initial, name=name)

    def bias_variable(shape, name="biases"):
        initial = tf.constant(0.1, dtype=tf.float32, shape=shape)
        return tf.Variable(initial, name=name)

    def conv2d(input, w):
        return tf.nn.conv2d(input, w, [1, 1, 1, 1], padding='SAME')

    def pool_max(input):
        return tf.nn.max_pool(input,
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding='SAME',
                              name='pool1')

    def fc(input, w, b):
        return tf.matmul(input, w) + b

    # conv1
    with tf.name_scope('conv1_1') as scope:
        kernel = weight_variable([3, 3, 3, 64])
        biases = bias_variable([64])
        output_conv1_1 = tf.nn.relu(conv2d(x, kernel) + biases, name=scope)

    with tf.name_scope('conv1_2') as scope:
        kernel = weight_variable([3, 3, 64, 64])
        biases = bias_variable([64])
        output_conv1_2 = tf.nn.relu(conv2d(output_conv1_1, kernel) + biases, name=scope)

    pool1 = pool_max(output_conv1_2)

    # conv2
    with tf.name_scope('conv2_1') as scope:
        kernel = weight_variable([3, 3, 64, 128])
        biases = bias_variable([128])
        output_conv2_1 = tf.nn.relu(conv2d(pool1, kernel) + biases, name=scope)

    with tf.name_scope('conv2_2') as scope:
        kernel = weight_variable([3, 3, 128, 128])
        biases = bias_variable([128])
        output_conv2_2 = tf.nn.relu(conv2d(output_conv2_1, kernel) + biases, name=scope)

    pool2 = pool_max(output_conv2_2)

    # conv3
    with tf.name_scope('conv3_1') as scope:
        kernel = weight_variable([3, 3, 128, 256])
        biases = bias_variable([256])
        output_conv3_1 = tf.nn.relu(conv2d(pool2, kernel) + biases, name=scope)

    with tf.name_scope('conv3_2') as scope:
        kernel = weight_variable([3, 3, 256, 256])
        biases = bias_variable([256])
        output_conv3_2 = tf.nn.relu(conv2d(output_conv3_1, kernel) + biases, name=scope)

    with tf.name_scope('conv3_3') as scope:
        kernel = weight_variable([3, 3, 256, 256])
        biases = bias_variable([256])
        output_conv3_3 = tf.nn.relu(conv2d(output_conv3_2, kernel) + biases, name=scope)

    pool3 = pool_max(output_conv3_3)

    # conv4
    with tf.name_scope('conv4_1') as scope:
        kernel = weight_variable([3, 3, 256, 512])
        biases = bias_variable([512])
        output_conv4_1 = tf.nn.relu(conv2d(pool3, kernel) + biases, name=scope)

    with tf.name_scope('conv4_2') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv4_2 = tf.nn.relu(conv2d(output_conv4_1, kernel) + biases, name=scope)

    with tf.name_scope('conv4_3') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv4_3 = tf.nn.relu(conv2d(output_conv4_2, kernel) + biases, name=scope)

    pool4 = pool_max(output_conv4_3)

    # conv5
    with tf.name_scope('conv5_1') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv5_1 = tf.nn.relu(conv2d(pool4, kernel) + biases, name=scope)

    with tf.name_scope('conv5_2') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv5_2 = tf.nn.relu(conv2d(output_conv5_1, kernel) + biases, name=scope)

    with tf.name_scope('conv5_3') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv5_3 = tf.nn.relu(conv2d(output_conv5_2, kernel) + biases, name=scope)

    pool5 = pool_max(output_conv5_3)

    # fc6
    with tf.name_scope('fc6') as scope:
        shape = int(np.prod(pool5.get_shape()[1:]))
        kernel = weight_variable([shape, 4096])
        biases = bias_variable([4096])
        pool5_flat = tf.reshape(pool5, [-1, shape])
        output_fc6 = tf.nn.relu(fc(pool5_flat, kernel, biases), name=scope)

    # fc7
    with tf.name_scope('fc7') as scope:
        kernel = weight_variable([4096, 4096])
        biases = bias_variable([4096])
        output_fc7 = tf.nn.relu(fc(output_fc6, kernel, biases), name=scope)

    # fc8
    with tf.name_scope('fc8') as scope:
        kernel = weight_variable([4096, 2])
        biases = bias_variable([2])
        output_fc8 = tf.nn.relu(fc(output_fc7, kernel, biases), name=scope)
    finaloutput = tf.nn.softmax(output_fc8, name="softmax")
    # softmax_cross_entropy_with_logits applies softmax internally, so it must be
    # fed the raw fc8 output rather than the already-softmaxed finaloutput
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_fc8, labels=y))
    optimize = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)
    prediction_labels = tf.argmax(finaloutput, axis=1, name="output")
    # convert the one-hot labels to class indices so the comparison is well-defined
    read_labels = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(prediction_labels, read_labels)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
    return dict(
        x=x,
        y=y,
        optimize=optimize,
        correct_prediction=correct_prediction,
        correct_times_in_batch=correct_times_in_batch,
        cost=cost,
    )
def train_network(graph, batch_size, num_epochs, pb_file_path):
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            for i in range(12):
                sess.run([graph['optimize']], feed_dict={
                    graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                    graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                })
            if epoch_index % epoch_delta == 0:
                # each feed_dict above and below carries a single image, so the
                # number of examples seen equals the number of "batches"
                total_batches_in_train_set = 0
                total_correct_times_in_train_set = 0
                total_cost_in_train_set = 0.
                for i in range(12):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                    })
                    total_batches_in_train_set += 1
                    total_correct_times_in_train_set += return_correct_times_in_batch
                    total_cost_in_train_set += mean_cost_in_batch

                total_batches_in_test_set = 0
                total_correct_times_in_test_set = 0
                total_cost_in_test_set = 0.
                for i in range(3):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x']: np.reshape(x_val[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_val[i] == 0 else [[0, 1]])
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x']: np.reshape(x_val[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_val[i] == 0 else [[0, 1]])
                    })
                    total_batches_in_test_set += 1
                    total_correct_times_in_test_set += return_correct_times_in_batch
                    total_cost_in_test_set += mean_cost_in_batch

                # accuracy over the number of examples actually fed (one per batch);
                # the original divided by batch_size as well, which under-reported it
                acy_on_test = total_correct_times_in_test_set / float(total_batches_in_test_set)
                acy_on_train = total_correct_times_in_train_set / float(total_batches_in_train_set)
                print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}), loss_on_test:{:6.2f}, '
                      'acy_on_train:{:6.2f}%({}/{}), loss_on_train:{:6.2f}'.format(
                          epoch_index, acy_on_test * 100.0,
                          total_correct_times_in_test_set, total_batches_in_test_set,
                          total_cost_in_test_set,
                          acy_on_train * 100.0,
                          total_correct_times_in_train_set, total_batches_in_train_set,
                          total_cost_in_train_set))

        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())

def main():
    batch_size = 12
    num_epochs = 50
    pb_file_path = "vggs.pb"
    g = build_network(height=224, width=224, channel=3)
    train_network(g, batch_size, num_epochs, pb_file_path)

main()
test
import tensorflow as tf
import numpy as np
import PIL.Image as Image
from skimage import io, transform

def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            # look up the tensors by the names assigned in build_network
            input_x = sess.graph.get_tensor_by_name("input:0")
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name("softmax:0")
            print(out_softmax)
            out_label = sess.graph.get_tensor_by_name("output:0")
            print(out_label)

            img = io.imread(jpg_path)
            img = transform.resize(img, (224, 224, 3))
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: np.reshape(img, [-1, 224, 224, 3])})

            print("img_out_softmax:", img_out_softmax)
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label:", prediction_labels)

recognize("vgg16/picture/dog/dog3.jpg", "vgg16/vggs.pb")