Dataset download: http://www.nlpr.ia.ac.cn/databases/handwriting/download.html
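The CASIA-HWDB download ships as binary .gnt files, while the script below expects per-class PNG directories named 00000, 00001, ... under the train/test data dirs (the directory name is also the integer label). Here is a minimal conversion sketch, assuming the standard HWDB1.x .gnt record layout (4-byte sample size, 2-byte GB-encoded tag code, 2-byte width, 2-byte height, then width*height grayscale bytes); the helper names are illustrative, not part of the original code:

# -*- coding: utf-8 -*-
# Hypothetical .gnt -> per-class PNG converter, assuming the HWDB1.x record layout.
import os
import struct

import numpy as np
from PIL import Image


def read_gnt(gnt_path):
    """Yield (character, grayscale image array) pairs from one .gnt file."""
    with open(gnt_path, 'rb') as f:
        while True:
            header = f.read(10)
            if len(header) < 10:
                break
            sample_size, = struct.unpack('<I', header[:4])   # total record size, unused here
            tagcode = header[4:6]                            # GB-encoded character label
            width, height = struct.unpack('<HH', header[6:10])
            bitmap = np.frombuffer(f.read(width * height), dtype=np.uint8)
            yield tagcode.decode('gb2312'), bitmap.reshape(height, width)


def convert(gnt_dir, out_dir, char_to_index=None):
    """Write every sample to out_dir/<00000-style class dir>/<counter>.png."""
    char_to_index = {} if char_to_index is None else char_to_index
    counter = 0
    for gnt_file in sorted(os.listdir(gnt_dir)):
        for char, img in read_gnt(os.path.join(gnt_dir, gnt_file)):
            index = char_to_index.setdefault(char, len(char_to_index))
            class_dir = os.path.join(out_dir, '%05d' % index)
            if not os.path.exists(class_dir):
                os.makedirs(class_dir)
            Image.fromarray(img).save(os.path.join(class_dir, '%d.png' % counter))
            counter += 1
    return char_to_index

The same char_to_index mapping must be reused when converting the test set (e.g. pickle it once after converting the training set), otherwise train and test labels will disagree.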

chinese_write_detection.py

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import numpy as np
import pickle
from PIL import Image
from log_utils import get_logger

logger = get_logger("HandWritten Practice")

root_path = 'D:/eclipse-workspace/sxzsb'

tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to randomly flip images up/down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "Whether to randomly adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "Whether to randomly adjust contrast")
tf.app.flags.DEFINE_integer('charset_size', 3755, "Use only the first `charset_size` characters in the experiment")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to be the same value as in training")
tf.app.flags.DEFINE_boolean('gray', True, "Whether to convert RGB images to grayscale")
tf.app.flags.DEFINE_integer('max_steps', 12002, "Maximum number of training steps")
tf.app.flags.DEFINE_integer('eval_steps', 50, "Evaluate every this many steps")
tf.app.flags.DEFINE_integer('save_steps', 2000, "Save a checkpoint every this many steps")
tf.app.flags.DEFINE_string('checkpoint_dir', 'D:/eclipse-workspace/sxzsb/checkpoint', "Checkpoint directory")
tf.app.flags.DEFINE_string('train_data_dir', 'D:/eclipse-workspace/sxzsb/data/train', "Training dataset directory (containing png files)")
tf.app.flags.DEFINE_string('test_data_dir', 'D:/eclipse-workspace/sxzsb/data/test', "Test dataset directory (containing png files)")
tf.app.flags.DEFINE_string('log_dir', 'D:/eclipse-workspace/sxzsb/log', "Logging directory")
tf.app.flags.DEFINE_boolean('restore', False, "Whether to restore from checkpoint")
tf.app.flags.DEFINE_integer('epoch', 1, "Number of epochs")
tf.app.flags.DEFINE_integer('batch_size', 128, "Batch size")
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "validation", "inference"}')
FLAGS = tf.app.flags.FLAGS


class DataIterator:
    def __init__(self, data_dir):
        # Set FLAGS.charset_size to a small value if available computing power is limited.
        truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
        print(truncate_path)
        self.image_names = []
        for root, sub_folder, file_list in os.walk(data_dir):
            if root < truncate_path:  # the first root is data_dir itself; its file_list is empty
                self.image_names += [os.path.join(root, file_path) for file_path in file_list]
        random.shuffle(self.image_names)
        # The class label is the directory name, e.g. int("00020") -> 20.
        self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names]

    @property
    def size(self):
        # @property exposes the method as a read-only attribute (a getter with no setter).
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        # 1. Convert the file names and labels to tensors and build the input queue.
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        # tf.convert_to_tensor() adds an op holding the passed-in data to the graph; evaluating the op yields a tensor.
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
        # 2. Dequeue one example at a time.
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])  # read the image file referenced by the queue entry
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        # 3. Collect examples from the queue into shuffled batches.
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    # with tf.device('/cpu:0'):
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    # slim.conv2d(inputs, num_outputs, kernel_size=[height, width], stride=1, padding='SAME')
    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')

    flatten = slim.flatten(max_pool_3)
    fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
    # logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc')

    # Fuses the softmax of the output layer with the cross-entropy computation for speed.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)  # non-trainable shared variable
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)  # one optimization step; increments global_step
    probabilities = tf.nn.softmax(logits)  # the loss above fuses softmax with cross-entropy; here we need the softmax output itself

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))  # fraction of samples whose true label is among the top_k predictions

    # Return the ops the callers need.
    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    print('Begin training')
    train_feeder = DataIterator(data_dir='../data/train/')
    test_feeder = DataIterator(data_dir='../data/test/')
    with tf.Session() as sess:
        # The queue runners must be started before running ops, or the input pipeline never loads data.
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        # 4. Start the queue-runner threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        start_step = 0
        if FLAGS.restore:  # load a saved checkpoint and resume training from its step count
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
                start_step += int(ckpt.split('-')[-1])
        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                print(start_time)
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                print(len(train_images_batch))
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}  # keep 80% of activations during dropout
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'.format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                               global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=graph['global_step'])
        finally:
            coord.request_stop()  # once any thread requests a stop, coord.should_stop() returns True and all threads wind down
            coord.join(threads)


def validation():
    print('validation')
    test_feeder = DataIterator(data_dir='../data/test/')

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(3)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # local variables hold the epoch counter used by num_epochs=1

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt is not None:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Start validation:::')
        try:
            i = 0
            acc_top_1, acc_top_k = 0.0, 0.0
            while not coord.should_stop():
                i += 1
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']], feed_dict=feed_dict)
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                acc_top_1 += acc_1
                acc_top_k += acc_k
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(i, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            acc_top_1 = acc_top_1 * FLAGS.batch_size / test_feeder.size
            acc_top_k = acc_top_k * FLAGS.batch_size / test_feeder.size
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
        finally:
            coord.request_stop()
            coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    print('inference')
    temp_image = Image.open(image).convert('L')
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.ANTIALIAS)
    temp_image = np.asarray(temp_image) / 255.0
    temp_image = temp_image.reshape([-1, 64, 64, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        # images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])
        # No label is fed here; the label placeholder does not affect the prediction ops.
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()  # the dict of probabilities, indices and ground truth returned above
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = '../data/test/00159/75700.png'
        final_predict_val, final_predict_index = inference(image_path)
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(190, final_predict_index,
                                                                                         final_predict_val))


if __name__ == "__main__":
    tf.app.run()  # a thin wrapper that parses flags and then dispatches to main()

log_utils.py

# -*- coding:utf-8 -*-
import os, os.path as osp
import time


def strftime(t=None):
    return time.strftime("%Y%m%d-%H%M%S", time.localtime(t or time.time()))


#################
# Logging
#################
import logging
from logging.handlers import TimedRotatingFileHandler

logging.basicConfig(format="[ %(asctime)s][%(module)s.%(funcName)s] %(message)s")

DEFAULT_LEVEL = logging.INFO
DEFAULT_LOGGING_DIR = osp.join("logs", "gcforest")
fh = None


def init_fh():
    global fh
    if fh is not None:
        return
    if DEFAULT_LOGGING_DIR is None:
        return
    if not osp.exists(DEFAULT_LOGGING_DIR):
        os.makedirs(DEFAULT_LOGGING_DIR)
    logging_path = osp.join(DEFAULT_LOGGING_DIR, strftime() + ".log")
    fh = logging.FileHandler(logging_path)
    fh.setFormatter(logging.Formatter("[ %(asctime)s][%(module)s.%(funcName)s] %(message)s"))


def update_default_level(default_level):
    global DEFAULT_LEVEL
    DEFAULT_LEVEL = default_level


def update_default_logging_dir(default_logging_dir):
    global DEFAULT_LOGGING_DIR
    DEFAULT_LOGGING_DIR = default_logging_dir


def get_logger(name="HandWrittenPractice", level=None):
    level = level or DEFAULT_LEVEL
    logger = logging.getLogger(name)
    logger.setLevel(level)
    init_fh()
    if fh is not None:
        logger.addHandler(fh)
    return logger
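
A quick usage sketch for this helper: basicConfig attaches a console handler to the root logger, and init_fh() lazily adds one shared file handler under logs/gcforest the first time a logger is requested, so every named logger writes to the same timestamped file.

from log_utils import get_logger

logger = get_logger("HandWritten Practice")
logger.info("goes to the console and to logs/gcforest/<timestamp>.log")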

Train

python chinese_write_detection.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000
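
Training writes loss and accuracy summaries to FLAGS.log_dir, so progress can be followed with TensorBoard (path shown for the default log_dir):

tensorboard --logdir=D:/eclipse-workspace/sxzsb/log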

Validation

python chinese_write_detection.py --mode=validation

Inference

python chinese_write_detection.py --mode=inference
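
inference() returns the top-3 softmax scores and class indices for a single image. The indices are the zero-padded directory numbers produced during data conversion, so mapping them back to characters needs the same char-to-index table. A hypothetical sketch, assuming that table was pickled during conversion as char_dict.pkl:

import pickle

# Hypothetical: invert the char-to-index table saved while converting the .gnt files.
with open('char_dict.pkl', 'rb') as f:
    index_to_char = {v: k for k, v in pickle.load(f).items()}

predict_val, predict_index = inference('../data/test/00159/75700.png')
for prob, idx in zip(predict_val[0], predict_index[0]):
    print('%s: %.4f' % (index_to_char[idx], prob))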
