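How to train a model with TensorFlow, save it as a pb file, and load the trained model file
This post explains how to train a model with TensorFlow and export it as a pb file, and then how to load that pb file again.
First, here is my GitHub repository:
https://github.com/ppplinday/tensorflow-vgg16-train-and-test
The picture folder holds 15 dog images and 15 cat images, 30 in total (no teasing please, this is only a practice project). train is the training script and test is the test script.
I will walk through the steps one by one.
PS: I am a beginner too, so there are bound to be a few (or many) mistakes. If you spot any, please point them out so I can learn from them.
- train
Let's start with train. The first step, of course, is reading the images.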
def read_img(path):
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder + '/*.jpg'):
            print('reading the image: %s' % (im))
            img = io.imread(im)
            img = transform.resize(img, (w, h, c))
            imgs.append(img)
            labels.append(idx)
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)
data, label = read_img(path)
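io.imread reads each image, transform.resize scales it to the VGG input size (224, 224, 3), and the results are collected into data and label. The label of an image is the index of the class folder it came from.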
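One detail worth keeping in mind: read_img assigns label indices in whatever order os.listdir returns the class folders, and that order is not guaranteed to be the same on every machine. The following is a minimal sketch of a deterministic mapping, assuming the one-folder-per-class layout described above. The helper name list_classes and the demo folders are hypothetical and are not part of the original repository.

```python
import os
import tempfile

def list_classes(root):
    """Return the class folder names under `root` in sorted order (hypothetical helper)."""
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )

# Demo on a throwaway directory that mimics the picture/ layout.
demo_root = tempfile.mkdtemp()
for name in ("dog", "cat"):
    os.mkdir(os.path.join(demo_root, name))

classes = list_classes(demo_root)
label_of = {name: idx for idx, name in enumerate(classes)}
print(label_of)
```

With sorted folder names, the same class always gets the same index, which makes it possible to turn a predicted index back into a class name later.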
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
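This shuffles the images: build the index sequence 0, 1, ..., num_example - 1, shuffle it, and use it to reorder both data and label, so each image stays paired with its label.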
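The reason for shuffling an index array instead of the two arrays separately is that one permutation applied to both keeps the pairs aligned. A small sketch with toy arrays (the shapes, values, and fixed seed are assumptions made only for the demo):

```python
import numpy as np

# Toy stand-ins for the image array and the label array.
data = np.arange(6, dtype=np.float32).reshape(6, 1) * 10  # "image" i holds the value 10 * i
label = np.arange(6, dtype=np.int32)                        # label i belongs to image i

rng = np.random.RandomState(0)      # fixed seed, only to make the demo repeatable
perm = rng.permutation(len(data))   # a shuffled copy of 0..5

shuffled_data = data[perm]
shuffled_label = label[perm]

# Every image still sits next to its own label after shuffling.
assert np.array_equal(shuffled_data[:, 0], shuffled_label * 10)
```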
ratio = 0.8
s = int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
x_val = data[s:]
y_val = label[s:]
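80% of the data is used for training and the remaining 20% is held out for testing as x_val and y_val (even though there are only 30 images in total...).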
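The split can be wrapped in a small function so the sizes are easy to check. This is a sketch that assumes the arrays have already been shuffled; the function name split_train_val and the dummy arrays are illustrative only.

```python
import numpy as np

def split_train_val(data, label, ratio=0.8):
    """Split already-shuffled arrays into a training part and a held-out part."""
    s = int(len(data) * ratio)  # index of the first held-out sample
    return (data[:s], label[:s]), (data[s:], label[s:])

data = np.zeros((30, 4), dtype=np.float32)  # 30 dummy samples
label = np.zeros(30, dtype=np.int32)
(x_train, y_train), (x_val, y_val) = split_train_val(data, label)
print(len(x_train), len(x_val))  # prints "24 6"
```

With 30 images this gives 24 training and 6 held-out samples. The training loop later in this post iterates over range(12) and range(3), so it only touches part of each split.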
def build_network(height, width, channel):
    x = tf.placeholder(tf.float32, shape=[None, height, width, channel], name='input')
    y = tf.placeholder(tf.int64, shape=[None, 2], name='labels_placeholder')
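Now build the VGG model. This step is not hard, but it is best to give every layer a name, because the test script later looks tensors up by name. The x and y above are the input images and the corresponding labels.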
    finaloutput = tf.nn.softmax(output_fc8, name="softmax")
    # softmax_cross_entropy_with_logits applies softmax itself, so pass the raw fc8 output
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_fc8, labels=y))
    optimize = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)
    prediction_labels = tf.argmax(finaloutput, axis=1, name="output")
    # y is one-hot, so convert it to a class index before comparing
    read_labels = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(prediction_labels, read_labels)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
    return dict(
        x=x,
        y=y,
        optimize=optimize,
        correct_prediction=correct_prediction,
        correct_times_in_batch=correct_times_in_batch,
        cost=cost,
    )
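The end of build_network defines the loss. finaloutput is the final softmax output, cost is the loss, and optimize defines how the network is trained (here, the Adam optimizer). Also note the dict returned at the end, which the training code uses to reach these tensors.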
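tf.nn.softmax_cross_entropy_with_logits applies softmax internally, so it expects raw scores (logits) rather than probabilities. The NumPy sketch below is not the TensorFlow implementation; it only mirrors the math to show what the loss computes and what happens if softmax is applied twice. The example scores are made up.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy_from_logits(logits, onehot):
    """Mean cross-entropy, applying softmax to `logits` first."""
    log_p = np.log(softmax(logits))
    return float(-(onehot * log_p).sum(axis=1).mean())

logits = np.array([[4.0, 0.0]])  # the network strongly prefers class 0
onehot = np.array([[1.0, 0.0]])  # and class 0 is the true label

loss_ok = cross_entropy_from_logits(logits, onehot)
# Passing probabilities where logits are expected applies softmax twice.
loss_double = cross_entropy_from_logits(softmax(logits), onehot)
print(loss_ok, loss_double)
```

Even for this confident, correct prediction the doubly-softmaxed loss stays well above the true loss, because probabilities differ by at most 1 and the second softmax cannot become sharp. That flattens the gradients and slows training.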
Next comes the training loop.
def train_network(graph, batch_size, num_epochs, pb_file_path):
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            for i in range(12):
                sess.run([graph['optimize']], feed_dict={
                    graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                    graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                })
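That is essentially all of the training code: choose batch_size and num_epochs and run the loop. Note that each step here feeds a single image. The rest of the function, shown in the full listing at the end, mainly reports the accuracy every few epochs.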
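The loop above feeds one image per step and builds the one-hot label inline. If real mini-batches are wanted, the slicing and one-hot encoding can be done with NumPy before feeding. This is only a sketch: the helper names one_hot and iterate_minibatches are hypothetical, and the dummy images are deliberately tiny to keep the demo light.

```python
import numpy as np

def one_hot(labels, num_classes=2):
    """Turn integer labels of shape [N] into one-hot rows of shape [N, num_classes]."""
    return np.eye(num_classes, dtype=np.int64)[labels]

def iterate_minibatches(x, y, batch_size):
    """Yield consecutive (inputs, one-hot labels) slices of at most batch_size items."""
    for start in range(0, len(x), batch_size):
        end = start + batch_size
        yield x[start:end], one_hot(y[start:end])

x_demo = np.zeros((24, 8, 8, 3), dtype=np.float32)  # 24 small dummy images
y_demo = np.array([0, 1] * 12, dtype=np.int32)      # alternating class labels

shapes = [(xb.shape, yb.shape) for xb, yb in iterate_minibatches(x_demo, y_demo, 12)]
print(shapes)
```

Because both placeholders are declared with None as the first dimension, each yielded pair could be passed directly as the values for graph['x'] and graph['y'] in feed_dict, provided the images have the 224 x 224 x 3 shape the network expects.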
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
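These statements are the important part: they convert the trained variables into constants, keeping the nodes needed to compute "output", and save the model as a pb file. After the script finishes, a new pb file appears in the corresponding folder.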
- test
def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
This opens the pb file and imports its graph.
img = io.imread(jpg_path)
img = transform.resize(img, (224, 224, 3))
img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, 224, 224, 3])})
Read the image file, resize it, and feed it to the model's input. img_out_softmax then holds the model's output.
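The softmax output is a row of probabilities, one per class, so turning it into a readable answer takes an argmax plus a list of class names. A minimal sketch follows; the class order in CLASS_NAMES is an assumption and has to match the label indices used during training, and fake_softmax is a made-up stand-in for img_out_softmax.

```python
import numpy as np

CLASS_NAMES = ["cat", "dog"]  # assumed order; must match the training label indices

def decode_prediction(probs, class_names=CLASS_NAMES):
    """Return (class name, probability) for each row of a [N, num_classes] array."""
    idx = np.argmax(probs, axis=1)
    return [(class_names[i], float(probs[n, i])) for n, i in enumerate(idx)]

fake_softmax = np.array([[0.2, 0.8]])  # made-up values standing in for img_out_softmax
print(decode_prediction(fake_softmax))  # prints [('dog', 0.8)]
```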
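That is roughly the whole workflow. It was written as practice and probably contains quite a few small mistakes, so please point them out so I can learn from them.
Finally, here is the complete train and test code: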
train
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
import os
import glob
from skimage import io, transform
from tensorflow.python.framework import graph_util
import collections
path = '/home/zhoupeilin/vgg16/picture/'
w = 224
h = 224
c = 3
def read_img(path):
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    imgs = []
    labels = []
    for idx, folder in enumerate(cate):
        for im in glob.glob(folder + '/*.jpg'):
            print('reading the image: %s' % (im))
            img = io.imread(im)
            img = transform.resize(img, (w, h, c))
            imgs.append(img)
            labels.append(idx)
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)
data, label = read_img(path)
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
ratio = 0.8
s = int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
x_val = data[s:]
y_val = label[s:]
def build_network(height, width, channel):
    x = tf.placeholder(tf.float32, shape=[None, height, width, channel], name='input')
    y = tf.placeholder(tf.int64, shape=[None, 2], name='labels_placeholder')

    def weight_variable(shape, name="weights"):
        initial = tf.truncated_normal(shape, dtype=tf.float32, stddev=0.1)
        return tf.Variable(initial, name=name)

    def bias_variable(shape, name="biases"):
        initial = tf.constant(0.1, dtype=tf.float32, shape=shape)
        return tf.Variable(initial, name=name)

    def conv2d(input, w):
        return tf.nn.conv2d(input, w, [1, 1, 1, 1], padding='SAME')

    def pool_max(input):
        return tf.nn.max_pool(input,
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding='SAME',
                              name='pool1')

    def fc(input, w, b):
        return tf.matmul(input, w) + b

    # conv1
    with tf.name_scope('conv1_1') as scope:
        kernel = weight_variable([3, 3, 3, 64])
        biases = bias_variable([64])
        output_conv1_1 = tf.nn.relu(conv2d(x, kernel) + biases, name=scope)
    with tf.name_scope('conv1_2') as scope:
        kernel = weight_variable([3, 3, 64, 64])
        biases = bias_variable([64])
        output_conv1_2 = tf.nn.relu(conv2d(output_conv1_1, kernel) + biases, name=scope)
    pool1 = pool_max(output_conv1_2)

    # conv2
    with tf.name_scope('conv2_1') as scope:
        kernel = weight_variable([3, 3, 64, 128])
        biases = bias_variable([128])
        output_conv2_1 = tf.nn.relu(conv2d(pool1, kernel) + biases, name=scope)
    with tf.name_scope('conv2_2') as scope:
        kernel = weight_variable([3, 3, 128, 128])
        biases = bias_variable([128])
        output_conv2_2 = tf.nn.relu(conv2d(output_conv2_1, kernel) + biases, name=scope)
    pool2 = pool_max(output_conv2_2)

    # conv3
    with tf.name_scope('conv3_1') as scope:
        kernel = weight_variable([3, 3, 128, 256])
        biases = bias_variable([256])
        output_conv3_1 = tf.nn.relu(conv2d(pool2, kernel) + biases, name=scope)
    with tf.name_scope('conv3_2') as scope:
        kernel = weight_variable([3, 3, 256, 256])
        biases = bias_variable([256])
        output_conv3_2 = tf.nn.relu(conv2d(output_conv3_1, kernel) + biases, name=scope)
    with tf.name_scope('conv3_3') as scope:
        kernel = weight_variable([3, 3, 256, 256])
        biases = bias_variable([256])
        output_conv3_3 = tf.nn.relu(conv2d(output_conv3_2, kernel) + biases, name=scope)
    pool3 = pool_max(output_conv3_3)

    # conv4
    with tf.name_scope('conv4_1') as scope:
        kernel = weight_variable([3, 3, 256, 512])
        biases = bias_variable([512])
        output_conv4_1 = tf.nn.relu(conv2d(pool3, kernel) + biases, name=scope)
    with tf.name_scope('conv4_2') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv4_2 = tf.nn.relu(conv2d(output_conv4_1, kernel) + biases, name=scope)
    with tf.name_scope('conv4_3') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv4_3 = tf.nn.relu(conv2d(output_conv4_2, kernel) + biases, name=scope)
    pool4 = pool_max(output_conv4_3)

    # conv5
    with tf.name_scope('conv5_1') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv5_1 = tf.nn.relu(conv2d(pool4, kernel) + biases, name=scope)
    with tf.name_scope('conv5_2') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv5_2 = tf.nn.relu(conv2d(output_conv5_1, kernel) + biases, name=scope)
    with tf.name_scope('conv5_3') as scope:
        kernel = weight_variable([3, 3, 512, 512])
        biases = bias_variable([512])
        output_conv5_3 = tf.nn.relu(conv2d(output_conv5_2, kernel) + biases, name=scope)
    pool5 = pool_max(output_conv5_3)

    # fc6
    with tf.name_scope('fc6') as scope:
        shape = int(np.prod(pool5.get_shape()[1:]))
        kernel = weight_variable([shape, 4096])
        biases = bias_variable([4096])
        pool5_flat = tf.reshape(pool5, [-1, shape])
        output_fc6 = tf.nn.relu(fc(pool5_flat, kernel, biases), name=scope)

    # fc7
    with tf.name_scope('fc7') as scope:
        kernel = weight_variable([4096, 4096])
        biases = bias_variable([4096])
        output_fc7 = tf.nn.relu(fc(output_fc6, kernel, biases), name=scope)

    # fc8
    with tf.name_scope('fc8') as scope:
        kernel = weight_variable([4096, 2])
        biases = bias_variable([2])
        output_fc8 = tf.nn.relu(fc(output_fc7, kernel, biases), name=scope)

    finaloutput = tf.nn.softmax(output_fc8, name="softmax")
    # softmax_cross_entropy_with_logits applies softmax itself, so pass the raw fc8 output
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_fc8, labels=y))
    optimize = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)
    prediction_labels = tf.argmax(finaloutput, axis=1, name="output")
    # y is one-hot, so convert it to a class index before comparing
    read_labels = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(prediction_labels, read_labels)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
    return dict(
        x=x,
        y=y,
        optimize=optimize,
        correct_prediction=correct_prediction,
        correct_times_in_batch=correct_times_in_batch,
        cost=cost,
    )
def train_network(graph, batch_size, num_epochs, pb_file_path):
    # batch_size is kept for the call signature; every step below feeds a single image
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            # one optimisation step per image, over the first 12 training images
            for i in range(12):
                sess.run([graph['optimize']], feed_dict={
                    graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                    graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                })
            # every epoch_delta epochs, report accuracy and loss
            if epoch_index % epoch_delta == 0:
                total_batches_in_train_set = 0
                total_correct_times_in_train_set = 0
                total_cost_in_train_set = 0.
                for i in range(12):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                    })
                    # each "batch" holds one image, so batches and samples are counted alike
                    total_batches_in_train_set += 1
                    total_correct_times_in_train_set += return_correct_times_in_batch
                    total_cost_in_train_set += mean_cost_in_batch
                total_batches_in_test_set = 0
                total_correct_times_in_test_set = 0
                total_cost_in_test_set = 0.
                for i in range(3):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x']: np.reshape(x_val[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_val[i] == 0 else [[0, 1]])
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x']: np.reshape(x_val[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_val[i] == 0 else [[0, 1]])
                    })
                    total_batches_in_test_set += 1
                    total_correct_times_in_test_set += return_correct_times_in_batch
                    total_cost_in_test_set += mean_cost_in_batch
                acy_on_test = total_correct_times_in_test_set / float(total_batches_in_test_set)
                acy_on_train = total_correct_times_in_train_set / float(total_batches_in_train_set)
                print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(
                    epoch_index,
                    acy_on_test * 100.0,
                    total_correct_times_in_test_set,
                    total_batches_in_test_set,
                    total_cost_in_test_set,
                    acy_on_train * 100.0,
                    total_correct_times_in_train_set,
                    total_batches_in_train_set,
                    total_cost_in_train_set))
        # freeze the variables into constants and write the graph to the pb file
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
def main():
    batch_size = 12
    num_epochs = 50
    pb_file_path = "vggs.pb"
    g = build_network(height=224, width=224, channel=3)
    train_network(g, batch_size, num_epochs, pb_file_path)
main()
test
import tensorflow as tf
import numpy as np
import PIL.Image as Image
from skimage import io, transform
def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        # load the frozen graph from the pb file
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            # look the tensors up by the names given in build_network
            input_x = sess.graph.get_tensor_by_name("input:0")
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name("softmax:0")
            print(out_softmax)
            out_label = sess.graph.get_tensor_by_name("output:0")
            print(out_label)
            img = io.imread(jpg_path)
            img = transform.resize(img, (224, 224, 3))
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: np.reshape(img, [-1, 224, 224, 3])})
            print("img_out_softmax: %s" % (img_out_softmax,))
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label: %s" % (prediction_labels,))
recognize("vgg16/picture/dog/dog3.jpg", "vgg16/vggs.pb")