tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用。

数据目录在data,data下放了汉字识别图片:

data$ ls
0  1  10  11  12  13  14  15  16  2  3  4  5  6  7  8  9
data$ ls 0
xxx.png yyy.png ....

代码:

如果将 get_model 里的模型层数加得非常深,训练时很可能不会收敛,精度会一直停留在 1% 以内。

# -*- coding: utf-8 -*-

from __future__ import division, print_function, absolute_import

import os
import pickle

import numpy as np
import tflearn
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tflearn.data_utils import to_categorical, shuffle
from tflearn.layers.conv import conv_2d, max_pool_2d, avg_pool_2d
from tflearn.layers.conv import highway_conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from tflearn.layers.merge_ops import merge
from tflearn.layers.normalization import local_response_normalization, batch_normalization


def resize_image(in_image, new_width, new_height, out_image=None,
                 resize_mode=Image.LANCZOS):
    """Resize an image.

    Arguments:
        in_image: `PIL.Image`. The image to resize.
        new_width: `int`. The image new width.
        new_height: `int`. The image new height.
        out_image: `str`. If specified, save the image to the given path.
        resize_mode: `PIL.Image` resampling filter. Defaults to LANCZOS,
            the filter formerly exposed as `Image.ANTIALIAS` (that alias
            was removed in Pillow 10).

    Returns:
        `PIL.Image`. The resized image.
    """
    img = in_image.resize((new_width, new_height), resize_mode)
    if out_image:
        img.save(out_image)
    return img


def convert_color(in_image, mode):
    """Convert image color with provided `mode`."""
    return in_image.convert(mode)


def pil_to_nparray(pil_image):
    """Convert a PIL.Image to a float32 numpy array."""
    pil_image.load()
    return np.asarray(pil_image, dtype="float32")


def iterbrowse(path):
    """Yield the path of every file found under `path`, recursively."""
    for home, dirs, files in os.walk(path):
        for filename in files:
            yield os.path.join(home, filename)


def directory_to_samples(directory, flags):
    """Read a directory and list the files of each sub-directory as one class.

    Arguments:
        directory: `str`. Root folder; each immediate sub-directory is a class.
        flags: list of `str`, or None. A file is kept when any flag occurs in
            its name (substring match, e.g. '.png'). None keeps every file.

    Returns:
        `(samples, targets)`: file paths and, for each path, its class label.
        Labels start at 0 and follow the *lexicographically sorted*
        sub-directory names, so a folder named '10' is labelled before '2'.
    """
    samples = []
    targets = []
    # The builtin next() works on both Python 2 and Python 3 generators,
    # so no try/except around .next() / .__next__() is needed.
    classes = sorted(next(os.walk(directory))[1])
    # Label classes from 0.
    for label, c in enumerate(classes):
        c_dir = os.path.join(directory, c)
        walk = next(os.walk(c_dir))
        for sample in walk[2]:
            if flags is None or any(flag in sample for flag in flags):
                samples.append(os.path.join(c_dir, sample))
                targets.append(label)
    return samples, targets


# Get the pixel from the given image
def get_pixel(image, i, j):
    """Return the pixel of `image` at column `i`, row `j`.

    Arguments:
        image: object exposing `size` as `(width, height)` and
            `getpixel((x, y))`, such as a `PIL.Image`.
        i: `int`. Horizontal coordinate.
        j: `int`. Vertical coordinate.

    Returns:
        The value returned by `image.getpixel((i, j))`, or None when the
        coordinate lies past the right or bottom edge of the image.
    """
    # Inside image bounds? Valid coordinates stop at width - 1 / height - 1,
    # so a coordinate equal to the size is already outside.
    width, height = image.size
    if i >= width or j >= height:
        return None

    # Get Pixel
    return image.getpixel((i, j))


# Create a Grayscale version of the image
def convert_to_one_channel(image):
    """Reduce an image array to a single channel, shaped (rows, cols, 1).

    The PNG inputs are assumed to be grayscale stored with identical
    R, G and B values, so only the first channel is kept. This assumption
    is not verified here.

    Arguments:
        image: `numpy.ndarray`. Either (rows, cols, channels) or a 2-D
            (rows, cols) array, e.g. from an image that is already
            single-channel.

    Returns:
        `numpy.ndarray` of shape (rows, cols, 1).
    """
    if image.ndim == 2:
        # Already single-channel: just add the trailing channel axis.
        return image[:, :, np.newaxis]
    # Just extract 1 channel of data, keeping the channel axis.
    return image[:, :, [0]]


def image_dirs_to_samples(directory, resize=None, convert_gray=False,
                          filetypes=None):
    """Load every image below `directory` as a normalized 1-channel array.

    Arguments:
        directory: `str`. Root folder with one sub-directory per class.
        resize: `(width, height)` or None. If given, each image is resized.
        convert_gray: currently ignored; the grayscale conversion step is
            disabled and the first channel is extracted instead.
        filetypes: list of `str`. Name fragments (e.g. '.png') used to
            select files; passed to `directory_to_samples` as `flags`.

    Returns:
        `(samples, targets)`: a list of float32 arrays scaled to [0, 1]
        and the list of integer class labels.
    """
    print("Starting to parse images...")
    samples, targets = directory_to_samples(directory, flags=filetypes)
    for i, s in enumerate(samples):
        print("Process %d th file %s" % (i+1, s))
        img = Image.open(s)  # Load an image, returns PIL.Image.
        if resize:
            img = resize_image(img, resize[0], resize[1])
        # TODO: convert to more data
        # if convert_gray:
        #     img = convert_color(img, 'L')
        arr = pil_to_nparray(img)
        arr /= 255.  # Normalize the sample image data to the range 0 to 1
        samples[i] = convert_to_one_channel(arr)  # just want 1 channel data
    print("Parsing Done!")
    return samples, targets


def load_data(dirname, resize_pics=(128, 128), shuffle_data=True):
    """Load the dataset, using `<dirname>/data.pkl` as a cache.

    If the cache cannot be read, the images under `dirname` are parsed,
    one-hot encoded, optionally shuffled, and the cache is written.

    Arguments:
        dirname: `str`. Folder holding one sub-directory per class.
        resize_pics: `(width, height)` every image is resized to.
        shuffle_data: `bool`. Shuffle samples when building from images.

    Returns:
        `(X, Y, org_labels)`: samples, one-hot labels, integer labels.

    Raises:
        ValueError: if no '.jpg' / '.png' image is found under `dirname`.
    """
    dataset_file = os.path.join(dirname, 'data.pkl')
    try:
        # NOTE: pickle executes arbitrary code on load; only use a cache
        # file that this script itself produced.
        with open(dataset_file, 'rb') as f:
            X, Y, org_labels = pickle.load(f)
    except Exception:
        # Deliberately broad: a missing or unreadable cache just means the
        # dataset is rebuilt from the image folders.
        # TODO: memory is too small to load all data
        X, Y = image_dirs_to_samples(dirname, resize_pics, False,
                                     ['.jpg', '.png'])
        if not Y:
            raise ValueError("No .jpg/.png images found under %r" % dirname)
        org_labels = Y
        # First class is '0'; convert the class vector (integers from 0 to
        # nb_classes) to one-hot rows.
        Y = to_categorical(Y, np.max(Y) + 1)
        if shuffle_data:
            X, Y, org_labels = shuffle(X, Y, org_labels)
        with open(dataset_file, 'wb') as f:
            pickle.dump((X, Y, org_labels), f)
    return X, Y, org_labels


class EarlyStoppingCallback(tflearn.callbacks.Callback):
    """Stop training once training and validation accuracy reach a threshold."""

    def __init__(self, val_acc_thresh):
        # Store a validation accuracy threshold, which we can compare against
        # the current validation accuracy at, say, each epoch, each batch step, etc.
        self.val_acc_thresh = val_acc_thresh

    def on_epoch_end(self, training_state):
        """Raise StopIteration when both accuracies reach the threshold.

        This is the final method called in trainer.py in the epoch loop.
        We can stop training and leave without losing any information with
        a simple exception.
        """
        val_acc = training_state.val_acc
        acc = training_state.acc_value
        if val_acc is None or acc is None:
            # No accuracy reported for this epoch; nothing to compare.
            return
        if val_acc >= self.val_acc_thresh and acc >= self.val_acc_thresh:
            # Only announce termination when we really terminate.
            print("Terminating training at the end of epoch", training_state.epoch)
            raise StopIteration

    def on_train_end(self, training_state):
        """Report the final training accuracy.

        tflearn calls this method right after training terminates (or when
        training ends regardless). This would be a good time to store any
        additional information that tflearn doesn't store already.
        """
        print("Successfully left training! Final model accuracy:", training_state.acc_value)


def get_model(width, height, classes=40):
    """Build the CNN classifier wrapped in a `tflearn.DNN`.

    Arguments:
        width: `int`. Input image width.
        height: `int`. Input image height.
        classes: `int`. Number of output classes (softmax units).

    Returns:
        `tflearn.DNN`. The untrained model.
    """
    # TODO, modify model
    # Real-time data preprocessing
    img_prep = tflearn.ImagePreprocessing()
    img_prep.add_featurewise_zero_center(per_channel=True)
    img_prep.add_featurewise_stdnorm()
    # Single-channel input; if RGB, 224,224,3
    network = input_data(shape=[None, width, height, 1], data_preprocessing=img_prep)
    network = conv_2d(network, 32, 3, activation='relu')
    network = max_pool_2d(network, 2)
    network = conv_2d(network, 64, 3, activation='relu')
    network = conv_2d(network, 64, 3, activation='relu')
    network = max_pool_2d(network, 2)
    network = fully_connected(network, 512, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, classes, activation='softmax')
    network = regression(network, optimizer='adam',
                         loss='categorical_crossentropy',
                         learning_rate=0.001)
    model = tflearn.DNN(network, tensorboard_verbose=0)
    return model


if __name__ == "__main__":
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    # Size the softmax layer from the one-hot labels so it always matches
    # the number of class folders actually present in the data directory.
    model = get_model(width, height, classes=len(Y[0]))

    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    # try:
    #     model.load(filename)
    #     print("Model loaded OK. Resume training!")
    # except:
    #     pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.9)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        print("OK, stop iterate!Good!")

    # Save whether training stopped early or ran through all epochs.
    model.save(filename)

    # predict all data and calculate confusion_matrix
    model.load(filename)
    pro_arr = model.predict(X)
    predict_labels = np.argmax(pro_arr, axis=1)
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))

运行效果:100个汉字2分钟就可以达到95%精度。

---------------------------------
Run id: cnn_handwrite
Log directory: /tmp/tflearn_logs/
---------------------------------
Preprocessing... Calculating mean over all dataset (this may take long)...
Mean: [ 0.89235026] (To avoid repetitive computation, add it to argument 'mean' of `add_featurewise_zero_center`)
---------------------------------
Preprocessing... Calculating std over all dataset (this may take long)...
STD: 0.192279 (To avoid repetitive computation, add it to argument 'std' of `add_featurewise_stdnorm`)
---------------------------------
Training samples: 19094
Validation samples: 4774
--
Training Step: 597 | total loss: 0.70288 | time: 40.959ss
| Adam | epoch: 001 | loss: 0.70288 - acc: 0.7922 | val_loss: 0.54380 - val_acc: 0.8460 -- iter: 19094/19094
--
Terminating training at the end of epoch 1
Training Step: 1194 | total loss: 0.48860 | time: 40.245s
| Adam | epoch: 002 | loss: 0.48860 - acc: 0.8783 | val_loss: 0.37020 - val_acc: 0.8923 -- iter: 19094/19094
--
Terminating training at the end of epoch 2
Training Step: 1791 | total loss: 0.35790 | time: 41.315ss
| Adam | epoch: 003 | loss: 0.35790 - acc: 0.9090 | val_loss: 0.34719 - val_acc: 0.9049 -- iter: 19094/19094
--
Terminating training at the end of epoch 3
Successfully left training! Final model accuracy: 0.908959209919
OK, stop iterate!Good!
precision recall f1-score support
0 1.00 0.99 0.99 239
1 0.95 0.96 0.96 237
2 0.91 0.98 0.94 240
3 0.90 0.98 0.94 239
4 0.96 0.98 0.97 239
5 0.94 0.97 0.96 239
6 0.98 0.98 0.98 239
7 0.84 0.99 0.91 240
8 0.99 0.87 0.93 239
9 0.95 0.98 0.96 239
10 0.97 0.94 0.96 240
11 0.95 0.98 0.97 240
12 0.92 0.99 0.95 240
13 0.95 0.96 0.96 239
14 0.94 0.94 0.94 236
15 0.94 0.97 0.96 240
16 0.94 0.98 0.96 240
17 0.97 0.99 0.98 240
18 0.94 0.93 0.94 240
19 1.00 0.95 0.98 239
20 0.96 0.98 0.97 240
21 0.98 0.91 0.95 239
22 0.97 0.95 0.96 239
23 1.00 0.97 0.98 239
24 0.94 0.98 0.96 240
25 0.98 0.98 0.98 237
26 0.91 1.00 0.95 239
27 0.91 0.96 0.93 239
28 0.97 0.88 0.92 239
29 1.00 0.98 0.99 240
30 0.99 0.94 0.96 239
31 0.97 0.97 0.97 237
32 0.94 0.98 0.96 236
33 0.94 0.96 0.95 239
34 0.98 0.99 0.98 239
35 0.99 0.98 0.99 240
36 0.96 0.92 0.94 239
37 1.00 0.93 0.96 240
38 0.96 0.99 0.98 238
39 0.98 0.97 0.97 238
40 0.92 0.90 0.91 240
41 0.96 0.97 0.96 237
42 0.98 0.97 0.97 240
43 0.95 0.96 0.95 239
44 0.97 0.96 0.97 239
45 0.95 0.94 0.95 239
46 0.93 0.96 0.94 232
47 0.98 0.91 0.94 237
48 0.95 0.97 0.96 239
49 0.97 0.80 0.88 226
50 0.90 0.95 0.92 240
51 0.98 0.99 0.99 236
52 0.96 0.90 0.93 240
53 0.99 0.96 0.97 235
54 0.97 0.93 0.95 240
55 0.99 0.98 0.99 240
56 0.97 0.92 0.95 239
57 0.97 0.97 0.97 239
58 1.00 0.98 0.99 238
59 0.92 0.98 0.95 240
60 0.99 0.90 0.94 240
61 1.00 0.99 0.99 238
62 0.92 0.95 0.94 239
63 0.92 0.98 0.95 238
64 0.98 0.92 0.95 240
65 0.99 0.92 0.95 239
66 0.98 0.99 0.99 240
67 0.95 0.95 0.95 240
68 0.96 0.98 0.97 239
69 0.97 0.97 0.97 239
70 0.98 0.94 0.96 240
71 0.91 0.96 0.93 239
72 0.98 0.97 0.97 239
73 0.99 0.89 0.94 240
74 0.97 0.99 0.98 237
75 0.89 0.97 0.92 240
76 0.97 0.96 0.97 241
77 0.89 0.91 0.90 240
78 1.00 0.89 0.94 239
79 0.90 0.98 0.94 239
80 0.89 0.96 0.92 240
81 0.96 0.71 0.81 225
82 0.95 1.00 0.97 238
83 0.67 0.96 0.79 239
84 0.97 0.85 0.91 240
85 0.95 0.98 0.96 239
86 0.99 0.93 0.96 240
87 0.98 0.91 0.94 239
88 0.97 0.97 0.97 240
89 0.97 0.94 0.95 239
90 0.97 0.96 0.96 236
91 0.91 0.97 0.94 239
92 0.98 0.95 0.96 240
93 0.98 0.97 0.98 239
94 0.98 0.95 0.97 240
95 0.98 0.99 0.99 239
96 0.95 0.97 0.96 240
97 0.98 0.97 0.98 239
98 0.95 0.98 0.97 237
99 0.97 0.96 0.97 239
avg / total 0.96 0.95 0.95 23868
[[237 0 0 ..., 0 0 0]
[ 0 228 0 ..., 0 0 0]
[ 0 0 235 ..., 0 0 0]
...,
[ 0 0 0 ..., 233 0 0]
[ 0 0 0 ..., 0 233 0]
[ 0 0 0 ..., 0 0 230]]

更多模型见:http://www.cnblogs.com/bonelee/p/8978060.html

将上述模型保存并给TensorFlow使用:只需在保存模型前加一句 del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:],这样保存的模型里只保留 inference 时用到的 OP(注意:训练相关的 OP 会被去掉,如果之后还需要 retrain,请另外保留一份未删除 TRAIN_OPS 的模型),如下:

    # `tf` is used below to edit the graph collections before the second save.
    import tensorflow as tf

    model = get_model(width, height, classes=100)

    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    # try:
    #     model.load(filename)
    #     print("Model loaded OK. Resume training!")
    # except:
    #     pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.8)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        print("OK, stop iterate!Good!")

    # First save: the full model, usable to resume training later.
    model.save(filename)

    # Empty the TRAIN_OPS collection in place so that the next save only
    # carries what inference needs.
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

    # Debug helper (disabled): print op names
    # with tf.Session() as sess:
    #     init_op = tf.initialize_all_variables()
    #     sess.run(init_op)
    #     for v in sess.graph.get_operations():
    #         print(v.name)

    # Second save: the inference-only variant, under a separate file name.
    filename = 'cnn_handwrite-acc0.8.infer.tflearn'
    model.save(filename)

参考:http://www.cnblogs.com/bonelee/p/8445261.html 里的脚本,修改:

output_node_names = "FullyConnected/Softmax"
通常为:
output_node_names = "FullyConnected/Softmax"
或者
output_node_names = "FullyConnected_1/Softmax"
output_node_names = "FullyConnected_2/Softmax"
具体用哪个名字取决于网络里全连接层的个数:上面三个名字分别对应网络中共有 1、2、3 个全连接层的情况(输出层是最后一个全连接层)。
最后,tensorflow里的使用:
def inference(image):
    """Classify one image file with the frozen graph `frozen_model.pb`.

    Arguments:
        image: `str`. Path of the image to classify.

    Returns:
        list with one entry per input image; each entry is a numpy array
        holding the indices of the 3 highest-scoring classes, best first.
    """
    print('inference')
    size = FLAGS.image_size
    temp_image = Image.open(image).convert('L')
    # LANCZOS is the filter formerly exposed as Image.ANTIALIAS (that alias
    # was removed in Pillow 10).
    temp_image = temp_image.resize((size, size), Image.LANCZOS)
    temp_image = np.asarray(temp_image) / 255.0
    # Reshape with the same size used for resizing, so a size mismatch with
    # the model surfaces as a feed error instead of a silently wrong batch.
    temp_image = temp_image.reshape([-1, size, size, 1])

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open("frozen_model.pb", "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # name="" keeps the original node names (no "import/" prefix).
        tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            # Debug helper (disabled): list every op output
            # for m in sess.graph.get_operations():
            #     print(m.values())
            op = sess.graph.get_tensor_by_name("FullyConnected_1/Softmax:0")
            input_tensor = sess.graph.get_tensor_by_name('InputData/X:0')
            probs = sess.run(op, feed_dict={input_tensor: temp_image})
            print(probs)

    # Negate so argsort orders by descending probability; keep the top 3.
    return [np.argsort(-word)[:3] for word in probs]


def main(_):
    """Run inference on a sample test image and log the top predictions."""
    image_path = './data/test/00098/104405.png'
    # image_path = '../data/00010/17724.png'
    final_predict_val = inference(image_path)
    logger.info('the result info label {0} predict index {1}'.format(98, final_predict_val))

一般来说,tflearn 模型的输入节点名默认为 InputData/X,但这只是 op 的名字;如果要取对应的 tensor,需要在名字后面加上输出序号 0,也就是:InputData/X:0

同理,FullyConnected_1/Softmax:0。

最后预测效果:

[[  8.42533936e-08   1.60850794e-11   2.60133332e-10   2.42555542e-14
4.96124599e-08 4.45251297e-15 3.98175590e-11 1.64476592e-11
7.03968351e-13 5.42319011e-12 8.55469237e-11 4.91866422e-13
1.77282828e-07 4.05237593e-10 3.13049003e-10 1.34780919e-11
2.05803235e-06 2.87827305e-07 1.47789994e-12 2.53391891e-11
3.77086790e-13 2.02639586e-10 9.03167027e-13 3.96698889e-11
1.30850096e-11 5.71980917e-12 3.03487374e-11 2.04132298e-14
6.25303683e-13 1.46122332e-07 2.17450633e-07 1.69623715e-09
6.80857757e-12 2.52643609e-13 6.56771096e-11 8.55152287e-16
1.34496514e-09 1.22644633e-06 1.12011307e-07 7.93476283e-05
8.24334611e-12 4.77531155e-14 9.39397757e-13 2.38438267e-14
2.11416329e-10 5.54395712e-08 2.30046147e-12 2.63584043e-10
4.70621564e-16 5.14432724e-12 6.42602327e-09 1.62485829e-13
7.39078274e-08 3.19146315e-12 5.25887156e-09 1.35877786e-13
1.39127886e-13 2.11998293e-13 9.09501097e-09 9.46486750e-07
2.47498733e-09 2.74523763e-12 1.02716433e-14 1.02069058e-17
3.09356682e-16 1.51022904e-15 9.34333665e-13 2.62195051e-14
3.38079781e-16 7.43019903e-13 1.92409091e-13 3.86611994e-13
2.61276265e-12 1.07969211e-09 1.30814548e-09 2.44038188e-14
9.79275905e-13 1.41007803e-10 6.15137758e-12 2.08893070e-10
1.34751668e-14 2.76824767e-15 7.84100464e-16 7.70873335e-15
5.45704757e-12 3.69386271e-18 2.06012223e-13 1.62567273e-14
1.54544960e-03 2.05292008e-06 1.31726174e-09 7.04993663e-09
4.11338266e-03 3.19344110e-07 3.96519717e-05 2.26919351e-12
2.39114349e-12 2.35558744e-07 9.94213998e-01 1.10125060e-11]]
the result info label 98 predict index [array([98, 92, 88])]
 
 

tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛的更多相关文章

  1. tflearn 中文汉字识别模型试验汇总

    def get_model(width, height, classes=40): # TODO, modify model # Building 'VGG Network' network = in ...

  2. TensorFlow 模型优化工具包  —  训练后整型量化

    模型优化工具包是一套先进的技术工具包,可协助新手和高级开发者优化待部署和执行的机器学习模型.自推出该工具包以来,  我们一直努力降低机器学习模型量化的复杂性 (https://www.tensorfl ...

  3. Tika结合Tesseract-OCR 实现光学汉字识别(简体、宋体的识别率百分之百)—附Java源码、测试数据和训练集下载地址

     OCR(Optical character recognition) —— 光学字符识别,是图像处理的一个重要分支,中文的识别具有一定挑战性,特别是手写体和草书的识别,是重要和热门的科学研究方向.可 ...

  4. python中文utf8编码后是占3个字符,unicode汉字为2字节

    一个中文utf8编码后是占3个字符,所以求长度的函数可以这样写 def str_len(str): try: row_l=len(str) utf8_l=len(str.encode('utf-8') ...

  5. python 将中文转拼音后填充到url做参数并写入excel

    闲着没事写了个小工具,将中文转拼音后填充到url做参数并写如excel 一.先看下演示,是个什么东西 二.代码 代码用到一个中文转拼音的库,库是网上下的,稍微做了下修改,已经找不原来下载的地址了,然后 ...

  6. Python之TensorFlow的模型训练保存与加载-3

    一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...

  7. DL4NLP——词表示模型(二)基于神经网络的模型:NPLM;word2vec(CBOW/Skip-gram)

    本文简述了以下内容: 神经概率语言模型NPLM,训练语言模型并同时得到词表示 word2vec:CBOW / Skip-gram,直接以得到词表示为目标的模型 (一)原始CBOW(Continuous ...

  8. 『高性能模型』Roofline Model与深度学习模型的性能分析

    转载自知乎:Roofline Model与深度学习模型的性能分析 在真实世界中,任何模型(例如 VGG / MobileNet 等)都必须依赖于具体的计算平台(例如CPU / GPU / ASIC 等 ...

  9. 学习笔记CB014:TensorFlow seq2seq模型步步进阶

    神经网络.<Make Your Own Neural Network>,用非常通俗易懂描述讲解人工神经网络原理用代码实现,试验效果非常好. 循环神经网络和LSTM.Christopher ...

随机推荐

  1. Linux下挂载指定分区下的某个文件夹到指定目录(mount)

    # 挂载 mount --bind olddir newdir # 卸载 umount newdir 参考: http://www.cnblogs.com/dabaopku/archive/2010/ ...

  2. Delphi Helper Record Class

    unit Unit1; {$DEFINE USESGUIDHELP} interface implementation {$IFDEF USESGUIDHELP} uses System.SysUti ...

  3. awk如何区分shell脚本传进来的参数和自身的参数?awk如何获取shell脚本传进来的参数;awk中如何执行shell命令

    问题:对于shell脚本,$0表示脚本本身,$1表示脚本的第一个参数,$2……依次类推:对于awk,$1表示分割后的第一个字段,$2……依次类推.那么对于shell脚本中的awk如何区分两者呢? 答案 ...

  4. C#开发ActiveX控件,.NET开发OCX控件案例 【转】

    http://xiaochen.2003.4.blog.163.com/blog/static/480409672012530227678/ 讲下什么是ActiveX控件,到底有什么作用?在网页中又如 ...

  5. C中的继承和多态

    昨天同学面试被问到这个问题,很有水平,以前都没有遇到过这个问题,一时自己也不知道怎么回答. 网上学习了一下,记录以备后用! C/C++ Internals : 里面的问题都写的不错,可以读读! Ref ...

  6. Android二维码工具zxing使用

    二维码在我们生活中随处可见.在我眼里简直能够用"泛滥"来形容啦.那怎样在我们Android项目中扫描识别二维码或生成二维码图片呢? 我们通常使用的开源框架是zxing.在githu ...

  7. PS 基础知识 CMYK全称是什么

    已解决 请问谁知道CMYK四色的英文全称? 悬赏分:20 - 解决时间:2006-9-6 16:23 C代表什么颜色?英文全称是什么? M代表什么颜色?英文全称是什么? Y代表什么颜色?英文全称是什么 ...

  8. opencv yuv420与Mat互转

    项目用到opencv 融合图片的功能,经过一天的调试,达到预期目标,先将如何调用opencv库实现YUV42与Mat互转记录下来. 一.下载opencv编译的库下载地址是:http://opencv. ...

  9. [Android5 系列—] 1. 构建一个简单的用户界面

    前言 安卓应用的用户界面是构建在View 和ViewGroup 这两个物件的层级之上的. View 就是一般的UI组件.像button,输入框等. viewGroup 是一些不可见的view的容器,用 ...

  10. ubuntu环境eclipse配置

    ubuntu环境eclipse配置 首先下载Eclipse和JDK: 然后将上边两个压缩包解压到安装文件夹(如;/home/linux/softwares/java).然后配置/etc/profile ...