数据集下载地址:http://www.nlpr.ia.ac.cn/databases/handwriting/download.html

chinese_write_detection.py

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import numpy as np
import pickle
from PIL import Image
from log_utils import get_logger

logger = get_logger("HandWritten Practice")
# NOTE(review): root_path is not referenced anywhere in the visible code.
root_path = 'D:/eclipse-workspace/sxzsb'

# Training-time augmentation switches (read by DataIterator.data_augmentation).
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")
# Number of classes used (labels 0 .. charset_size-1); lower it for a smaller experiment.
tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
# NOTE(review): 'gray' is not read by the visible code; images are always decoded with one channel.
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save")
tf.app.flags.DEFINE_string('checkpoint_dir', 'D:/eclipse-workspace/sxzsb/checkpoint', 'the checkpoint dir')
# NOTE(review): the two data-dir flags are not read by the visible code, which
# uses the relative paths '../data/train/' and '../data/test/' instead.
tf.app.flags.DEFINE_string('train_data_dir', 'D:/eclipse-workspace/sxzsb/data/train', 'the train dataset dir(containing png files)')
tf.app.flags.DEFINE_string('test_data_dir', 'D:/eclipse-workspace/sxzsb/data/test', 'the test dataset dir(containing png files)')
tf.app.flags.DEFINE_string('log_dir', 'D:/eclipse-workspace/sxzsb/log', 'the logging path')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
# NOTE(review): 'epoch' is not read by the visible code.
tf.app.flags.DEFINE_integer('epoch', 1, 'Number of epoches')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size for training and validation')
# Must list the modes main() actually dispatches on.
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "validation", "inference"}')
FLAGS = tf.app.flags.FLAGS


class DataIterator(object):
    """Index a class-per-folder PNG dataset and expose it as a TF1 input queue.

    ``data_dir`` holds one sub-folder per class, named with the class's
    zero-padded integer label (``00000``, ``00001``, ...). Every file below
    such a folder is a sample of that class. Only labels smaller than
    ``FLAGS.charset_size`` are kept, so the experiment can be shrunk by
    lowering that flag when compute is limited.

    Attributes:
        data_dir: the given directory with a guaranteed trailing separator.
        image_names: shuffled list of image file paths.
        labels: integer label of each entry of ``image_names`` (same order).
    """

    def __init__(self, data_dir):
        # A trailing separator makes ``root[len(self.data_dir):]`` start exactly
        # at the class folder name, even if the caller omitted the final "/".
        self.data_dir = os.path.join(data_dir, '')
        samples = []
        for root, _, file_list in os.walk(self.data_dir):
            label = self._folder_label(root)
            # Skip data_dir itself, non-numeric folders and classes beyond the
            # configured charset. The former lexicographic
            # ``root < truncate_path`` test let the first two through, after
            # which the int() label parse crashed.
            if label is None or label >= FLAGS.charset_size:
                continue
            samples.extend((os.path.join(root, name), label) for name in file_list)
        random.shuffle(samples)
        self.image_names = [path for path, _ in samples]
        self.labels = [label for _, label in samples]
        logger.info('Found {0} images in {1}'.format(len(self.image_names), self.data_dir))

    def _folder_label(self, root):
        """Return the integer label of the class folder that contains ``root``.

        Returns None for ``data_dir`` itself or for a folder whose name is not
        an integer. Files in nested sub-folders inherit the top-level label.
        """
        head = os.path.normpath(root[len(self.data_dir):]).split(os.sep)[0]
        return int(head) if head.isdigit() else None

    @property
    def size(self):
        """Number of samples in the dataset (read-only)."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Apply the random augmentations enabled by the ``random_*`` flags."""
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Build queue-based ops yielding shuffled ``(image_batch, label_batch)``.

        Args:
            batch_size: number of samples per batch.
            num_epochs: passes over the data before the queue raises
                ``tf.errors.OutOfRangeError``; None cycles forever. A finite
                value creates local state, so the caller must also run
                ``tf.local_variables_initializer()``.
            aug: whether to apply ``data_augmentation`` (training only).

        Returns:
            float32 images resized to ``FLAGS.image_size`` squared with one
            channel, and the matching int64 labels. The queue runners must be
            started (``tf.train.start_queue_runners``) before evaluating them.
        """
        # 1. Wrap the file list and labels in tensors and slice them into a queue.
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
        # 2. Dequeue one (path, label) pair, then read and decode the PNG.
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        # 3. Collect decoded samples into shuffled batches.
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    """Build the 3-conv / 2-FC classifier with its loss, optimiser and metrics.

    Args:
        top_k: k used for the top-k predictions and the top-k accuracy.

    Returns:
        dict of the input placeholders (``images``, ``labels``, ``keep_prob``)
        and the tensors/ops that train(), validation() and inference() run.
    """
    # Dropout keep probability: callers feed 0.8 while training and 1.0 otherwise.
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                            name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    # Three stages of conv(3x3, SAME) followed by max-pool(2x2, stride 2).
    # slim.conv2d(inputs, num_outputs, kernel_size, stride, ...)
    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')

    flatten = slim.flatten(max_pool_3)
    fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None,
                                  scope='fc2')

    # Softmax and cross-entropy are fused into a single op on the raw logits.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    # Float variable named "step" is kept (rather than get_or_create_global_step)
    # so existing checkpoints still restore; minimize() increments it.
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)

    # Explicit softmax only for reporting predictions; the loss uses the fused op.
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()

    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    """Train the CNN, evaluating on a test batch and checkpointing periodically.

    Stops once the global step exceeds ``FLAGS.max_steps`` and then saves a
    final checkpoint. Summaries go to ``FLAGS.log_dir`` + ``/train`` and
    ``/val``; checkpoints go to ``FLAGS.checkpoint_dir``.
    """
    print('Begin training')
    train_feeder = DataIterator(data_dir='../data/train/')
    test_feeder = DataIterator(data_dir='../data/test/')
    # tf.train.Saver.save does not create a missing parent directory.
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'my-model')
    with tf.Session() as sess:
        # Input pipelines must exist before the queue runners are started below.
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        if FLAGS.restore:
            # Resume from the latest checkpoint; global_step is restored with the weights.
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
            else:
                logger.warning('No checkpoint in {0}; training from scratch'.format(FLAGS.checkpoint_dir))
        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}  # keep 80% of activations
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'.format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, checkpoint_path, global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            # Only reachable if an input pipeline is given a finite num_epochs;
            # treated as normal end of training.
            pass
        finally:
            # Stop and join the queue threads whichever way the loop ended.
            coord.request_stop()
            coord.join(threads)
        logger.info('==================Train Finished================')
        # Previously only the OutOfRangeError path saved; the max_steps exit did not.
        saver.save(sess, checkpoint_path, global_step=graph['global_step'])
        train_writer.close()
        test_writer.close()


def validation():
    """Evaluate the latest checkpoint on one full pass over the test set.

    Returns:
        dict with ``prob`` (top-3 probabilities), ``indices`` (top-3 class
        indices) and ``groundtruth`` (true labels), one entry per evaluated
        sample. shuffle_batch drops the final partial batch, so up to
        ``batch_size - 1`` test images are not evaluated.
    """
    print('validation')
    test_feeder = DataIterator(data_dir='../data/test/')
    final_predict_val = []
    final_predict_index = []
    groundtruth = []
    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(3)
        sess.run(tf.global_variables_initializer())
        # num_epochs=1 keeps an epoch counter in a local variable.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt is not None:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('No checkpoint in {0}; evaluating an untrained model'.format(FLAGS.checkpoint_dir))
        logger.info(':::Start validation:::')
        num_batches = 0
        num_samples = 0
        correct_top_1, correct_top_k = 0.0, 0.0
        try:
            while not coord.should_stop():
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']],
                                                                      feed_dict=feed_dict)
                num_batches += 1
                batch_len = len(batch_labels)
                num_samples += batch_len
                # Accumulate correct-prediction counts, not per-batch means.
                correct_top_1 += acc_1 * batch_len
                correct_top_k += acc_k * batch_len
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(num_batches, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            # Divide by the samples actually scored, not test_feeder.size:
            # the dropped final partial batch would bias the result downwards.
            acc_top_1 = correct_top_1 / num_samples if num_samples else 0.0
            acc_top_k = correct_top_k / num_samples if num_samples else 0.0
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
        finally:
            coord.request_stop()
            coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    """Return the top-3 ``(probabilities, class indices)`` for one image file.

    Raises:
        ValueError: if ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    print('inference')
    with Image.open(image) as raw_image:
        temp_image = raw_image.convert('L')
    # LANCZOS is the same filter as the removed Image.ANTIALIAS alias (Pillow >= 10).
    # NOTE(review): training resizes with tf.image.resize_images instead, so
    # preprocessing differs slightly between training and inference.
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
    temp_image = np.asarray(temp_image) / 255.0
    temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        # The labels placeholder is not fed: the fetched tensors do not depend on it.
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not ckpt:
            # Without weights the run below would fail on uninitialised variables.
            raise ValueError('No checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
        saver.restore(sess, ckpt)
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    """Dispatch on ``FLAGS.mode``: train, validation (alias valid) or inference."""
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode in ('validation', 'valid'):
        # 'valid' is accepted because older flag help advertised it.
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = '../data/test/00159/75700.png'
        final_predict_val, final_predict_index = inference(image_path)
        # The true label is the name of the class folder (previously hard-coded as 190).
        true_label = int(os.path.basename(os.path.dirname(image_path)))
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(
            true_label, final_predict_index, final_predict_val))
    else:
        logger.error('Unknown mode {0!r}; expected train, validation or inference'.format(FLAGS.mode))


if __name__ == "__main__":
    # tf.app.run parses the command-line flags, then calls main().
    tf.app.run()

log_utils.py

# -*- coding:utf-8 -*-
import os, os.path as osp
import time


def strftime(t=None):
    """Format a timestamp as ``YYYYmmdd-HHMMSS`` in local time.

    Args:
        t: seconds since the epoch; None (the default) means "now".
    """
    # Compare with None explicitly: ``t or time.time()`` silently turned a
    # legitimate t=0 into the current time.
    return time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time() if t is None else t))


#################
# Logging
#################
import logging
from logging.handlers import TimedRotatingFileHandler

# One format shared by the console handler (basicConfig) and the file handler.
LOG_FORMAT = "[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"
logging.basicConfig(format=LOG_FORMAT)
DEFAULT_LEVEL = logging.INFO
# NOTE(review): "gcforest" looks inherited from another project; kept so log paths do not move.
DEFAULT_LOGGING_DIR = osp.join("logs", "gcforest")
# Module-wide file handler, created lazily by init_fh() and shared by every logger from get_logger().
fh = None


def init_fh():
    """Create the shared file handler ``fh`` on first use.

    Logs go to ``DEFAULT_LOGGING_DIR/<timestamp>.log``. Does nothing when the
    handler already exists or when ``DEFAULT_LOGGING_DIR`` is None.
    """
    global fh
    if fh is not None or DEFAULT_LOGGING_DIR is None:
        return
    # Create first and check afterwards, so a directory created concurrently
    # (between an exists() check and makedirs()) is not an error.
    try:
        os.makedirs(DEFAULT_LOGGING_DIR)
    except OSError:
        if not osp.isdir(DEFAULT_LOGGING_DIR):
            raise
    logging_path = osp.join(DEFAULT_LOGGING_DIR, strftime() + ".log")
    fh = logging.FileHandler(logging_path)
    fh.setFormatter(logging.Formatter(LOG_FORMAT))


def update_default_level(defalut_level):
    """Set the level get_logger() uses when called without ``level``."""
    # The misspelt parameter name is kept so keyword callers keep working.
    global DEFAULT_LEVEL
    DEFAULT_LEVEL = defalut_level


def update_default_logging_dir(default_logging_dir):
    """Set the log-file directory; None disables file logging.

    Only effective before the first get_logger() call, because the file
    handler is created once and then reused.
    """
    global DEFAULT_LOGGING_DIR
    DEFAULT_LOGGING_DIR = default_logging_dir


def get_logger(name="HandWrittenPractice", level=None):
    """Return logger ``name`` set to ``level`` (falsy -> DEFAULT_LEVEL) with the file handler attached."""
    level = level or DEFAULT_LEVEL
    logger = logging.getLogger(name)
    logger.setLevel(level)
    init_fh()
    if fh is not None:
        # Logger.addHandler ignores a handler that is already attached, so
        # repeated calls for the same name do not duplicate file output.
        logger.addHandler(fh)
    return logger

Train

python chinese_write_detection.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000

Validation

python chinese_write_detection.py --mode=validation

Inference

python chinese_write_detection.py --mode=inference

tensorflow创建cnn网络进行中文手写文字识别的更多相关文章

  1. Atitit s2018.2 s2 doc list on home ntpc.docx  \Atiitt uke制度体系 法律 法规 规章 条例 国王诏书.docx \Atiitt 手写文字识别 讯飞科大 语音云.docx \Atitit 代码托管与虚拟主机.docx \Atitit 企业文化 每日心灵 鸡汤 值班 发布.docx \Atitit 几大研发体系对比 Stage-Gat

    Atitit s2018.2 s2 doc list on home ntpc.docx \Atiitt uke制度体系  法律 法规 规章 条例 国王诏书.docx \Atiitt 手写文字识别   ...

  2. 5 TensorFlow入门笔记之RNN实现手写数字识别

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  3. Tensorflow项目实战一:MNIST手写数字识别

    此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个一维向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...

  4. CNN完成mnist数据集手写数字识别

    # coding: utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data d ...

  5. TensorFlow(十):卷积神经网络实现手写数字识别以及可视化

    上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  6. Google机器学习笔记(七)TF.Learn 手写文字识别

    转载请注明作者:梦里风林 Google Machine Learning Recipes 7 官方中文博客 - 视频地址 Github工程地址 https://github.com/ahangchen ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  9. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.conv2d介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

随机推荐

  1. canvas toBlob ,ie兼容

    /* canvas-toBlob.js * A canvas.toBlob() implementation. * 2016-05-26 * * By Eli Grey, http://eligrey ...

  2. Codeforces Round #608 (Div. 2) E - Common Number (二分 思维 树结构)

  3. 解决Windows2003 Server终端服务120天限制

    用过windows server 2003做服务器的人都知道,windows2003的性能安全性比以前的windows版本高出很多,但是也带来很多麻烦.其中服务器最重要的远程管理“终端服务”居然要求授 ...

  4. DDL DML DCL的理解

    DDL的操作对象是表,不会对具体的数据进行操作. DML的操作对象是记录, DCL的操作对象是数据库对象的权限.

  5. vue定义自定义事件方法、事件传值及事件对象

    1.自定义事件 例如v-on:click="run" 或者 @click="run" <template> <div id="app ...

  6. Java实践-远程调用Shell脚本并获取输出信息

    1.添加依赖 <dependency> <groupId>ch.ethz.ganymed</groupId> <artifactId>ganymed-s ...

  7. HDU6669 Game(思维,贪心)

    HDU6669 Game 维护区间 \([l,r]\) 为完成前 \(i\) 步使用最少步数后可能落在的区间. 初始时区间 \([l,r]\) 为整个坐标轴. 对于第 \(i\) 个任务区间 \([a ...

  8. HDU6668 Polynomial(模拟)

    HDU6668 Polynomial 顺序遍历找出最高次幂项的系数 分三种情况 \(1/0\).\(0/1\).\(f(x)/g(x)\) . 复杂度为 \(O(n)\) . #include< ...

  9. vue2.0 之 douban (六)axios的简单使用

    由于项目中用到了豆瓣api,涉及到跨域访问,就需要在config的index.js添加代理,例如 proxyTable: { // 设置代理,解决跨域问题 '/api': { target: 'htt ...

  10. MongoDb python连接

    方式一:简写 client = MongoClient() 方式二:指定端口和地址 client = MongoClient('localhost', 27017) 方式三:使用URI  --统一资源定位 ...