数据集下载地址:http://www.nlpr.ia.ac.cn/databases/handwriting/download.html

chinese_write_detection.py

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import numpy as np
import pickle
from PIL import Image
from log_utils import get_logger logger = get_logger("HandWritten Practice")
root_path = 'D:/eclipse-workspace/sxzsb'
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast") tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save") tf.app.flags.DEFINE_string('checkpoint_dir', 'D:/eclipse-workspace/sxzsb/checkpoint', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', 'D:/eclipse-workspace/sxzsb/data/train', 'the train dataset dir(containing png files)')
tf.app.flags.DEFINE_string('test_data_dir', 'D:/eclipse-workspace/sxzsb/data/test', 'the test dataset dir(containing png files)')
tf.app.flags.DEFINE_string('log_dir', 'D:/eclipse-workspace/sxzsb/log', 'the logging path)') tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('epoch', 1, 'Number of epoches')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Validation batch size')
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "valid", "test"}')
FLAGS = tf.app.flags.FLAGS class DataIterator: def __init__(self, data_dir):
# Set FLAGS.charset_size to a small value if available computation power is limited.
truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
print(truncate_path)
self.image_names = []
for root, sub_folder, file_list in os.walk(data_dir):
if root < truncate_path: # some problem here ,because the first root is contain inside ,and there is no file_list
self.image_names += [os.path.join(root, file_path) for file_path in file_list]
random.shuffle(self.image_names)
self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names] # int("00020") output:20 @property
def size(self): # @property,负责把一个方法变成属性调用的,还可以定义只读属性,只定义getter方法,不定义setter方法就是一个只读属性
return len(self.labels) @staticmethod
def data_augmentation(images):
if FLAGS.random_flip_up_down:
images = tf.image.random_flip_up_down(images)
if FLAGS.random_brightness:
images = tf.image.random_brightness(images, max_delta=0.3)
if FLAGS.random_contrast:
images = tf.image.random_contrast(images, 0.8, 1.2)
return images def input_pipeline(self, batch_size, num_epochs=None, aug=False):
# 1、convert images to a tensor 构造数据queue
images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
# 执行tf.convert_to_tensor()的时候,在图上生成了一个Op,Op中保存了传入参数的数据。op经过计算产生tensor
labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
# 2、 ## queue输出数据
labels = input_queue[1]
images_content = tf.read_file(input_queue[0]) # read images from the queue,refer to input_queue
images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
if aug:
images = self.data_augmentation(images)
new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
images = tf.image.resize_images(images, new_size)
# collect batches of images before processing
# 3、shuffle_batch批量从queu批量读取数据
image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
min_after_dequeue=10000) # produce shunffled batch
return image_batch, label_batch def build_graph(top_k):
# with tf.device('/cpu:0'):
keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')
labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch') conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
# (inputs,num_outputs,[卷积核个数] kernel_size,[卷积核的高度,卷积核的宽]stride=1,padding='SAME',)
max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME') flatten = slim.flatten(max_pool_3)
fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
# logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc')
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
# y表示的是实际类别,y_表示预测结果,这实际上面是把原来的神经网络输出层的softmax和cross_entrop何在一起计算,为了追求速度
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32)) global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False) # global_step interesting sharing varialbes
rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step) # train_op 包含了训练数据
probabilities = tf.nn.softmax(logits) # 上一个用logits是soft_max和cross_entropy一起算的,这次只是算了soft_max输出 tf.summary.scalar('loss', loss)
tf.summary.scalar('accuracy', accuracy)
merged_summary_op = tf.summary.merge_all()
predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32)) # 这个思路真是清奇!!!看来我回答对了 # return the operator
return {'images': images,
'labels': labels,
'keep_prob': keep_prob,
'top_k': top_k,
'global_step': global_step,
'train_op': train_op,
'loss': loss,
'accuracy': accuracy,
'accuracy_top_k': accuracy_in_top_k,
'merged_summary_op': merged_summary_op,
'predicted_distribution': probabilities,
'predicted_index_top_k': predicted_index_top_k,
'predicted_val_top_k': predicted_val_top_k} def train():
print('Begin training')
train_feeder = DataIterator(data_dir='../data/train/')
test_feeder = DataIterator(data_dir='../data/test/')
with tf.Session() as sess:
# session操作之前启动队列runners才能激活pipelines/input pipeline 并载入数据
train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True) # num_epochs what's refer to ?
test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
graph = build_graph(top_k=1) # very important
sess.run(tf.global_variables_initializer())
# 4、 ## 启动queue线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
saver = tf.train.Saver() train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
start_step = 0
if FLAGS.restore: # 这里是加载保存好的模型,的到step继续训练
ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if ckpt:
saver.restore(sess, ckpt)
print("restore from the checkpoint {0}".format(ckpt))
start_step += int(ckpt.split('-')[-1]) logger.info(':::Training Start:::')
try:
while not coord.should_stop(): ###----
start_time = time.time()
print(start_time)
train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
print(len(train_images_batch))
feed_dict = {graph['images']: train_images_batch,
graph['labels']: train_labels_batch,
graph['keep_prob']: 0.8} # keep 80% connection
_, loss_val, train_summary, step = sess.run(
[graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
feed_dict=feed_dict)
train_writer.add_summary(train_summary, step)
end_time = time.time()
logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
if step > FLAGS.max_steps:
break
if step % FLAGS.eval_steps == 1:
test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
feed_dict = {graph['images']: test_images_batch,
graph['labels']: test_labels_batch,
graph['keep_prob']: 1.0}
accuracy_test, test_summary = sess.run(
[graph['accuracy'], graph['merged_summary_op']],
feed_dict=feed_dict) # 这里的多层括号问题
test_writer.add_summary(test_summary, step)
logger.info('===============Eval a batch=======================')
logger.info('the step {0} test accuracy: {1}'
.format(step, accuracy_test))
logger.info('===============Eval a batch=======================')
if step % FLAGS.save_steps == 1:
logger.info('Save the ckpt of {0}'.format(step))
saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
global_step=graph['global_step'])
except tf.errors.OutOfRangeError:
logger.info('==================Train Finished================')
saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=graph['global_step'])
finally:
coord.request_stop() # 任何一个线程请求停止,则coord.should_stop()就会返回True ,然后都停下来
coord.join(threads) def validation():
print('validation')
test_feeder = DataIterator(data_dir='../data/test/') final_predict_val = []
final_predict_index = []
groundtruth = [] with tf.Session() as sess:
test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
graph = build_graph(3) sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) # initialize test_feeder's inside state coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord) saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if ckpt is not None:
saver.restore(sess, ckpt)
print("restore from the checkpoint {0}".format(ckpt)) logger.info(':::Start validation:::')
try:
i = 0
acc_top_1, acc_top_k = 0.0, 0.0
while not coord.should_stop():
i += 1
start_time = time.time()
test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
feed_dict = {graph['images']: test_images_batch,
graph['labels']: test_labels_batch,
graph['keep_prob']: 1.0}
batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
graph['predicted_val_top_k'],
graph['predicted_index_top_k'],
graph['accuracy'],
graph['accuracy_top_k']], feed_dict=feed_dict)
final_predict_val += probs.tolist()
final_predict_index += indices.tolist()
groundtruth += batch_labels.tolist()
acc_top_1 += acc_1
acc_top_k += acc_k
end_time = time.time()
logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
.format(i, end_time - start_time, acc_1, acc_k)) except tf.errors.OutOfRangeError:
logger.info('==================Validation Finished================')
acc_top_1 = acc_top_1 * FLAGS.batch_size / test_feeder.size
acc_top_k = acc_top_k * FLAGS.batch_size / test_feeder.size
logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
finally:
coord.request_stop()
coord.join(threads)
return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth} def inference(image):
print('inference')
temp_image = Image.open(image).convert('L')
temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.ANTIALIAS)
temp_image = np.asarray(temp_image) / 255.0
temp_image = temp_image.reshape([-1, 64, 64, 1])
with tf.Session() as sess:
logger.info('========start inference============')
# images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])
# Pass a shadow label 0. This label will not affect the computation graph.
graph = build_graph(top_k=3)
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if ckpt:
saver.restore(sess, ckpt)
predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
return predict_val, predict_index def main(_):
print(FLAGS.mode)
if FLAGS.mode == "train":
train()
elif FLAGS.mode == 'validation':
dct = validation() # thinking what is "dct"
result_file = 'result.dict'
logger.info('Write result into {0}'.format(result_file))
with open(result_file, 'wb') as f:
pickle.dump(dct, f)
logger.info('Write file ends')
elif FLAGS.mode == 'inference':
image_path = '../data/test/00159/75700.png'
final_predict_val, final_predict_index = inference(image_path) # figure out what is inference
logger.info('the result info label {0} predict index {1} predict_val {2}'.format(190, final_predict_index,
final_predict_val)) if __name__ == "__main__":
tf.app.run() # It's just a very quick wrapper that handles flag parsing and then dispatches to your own main.

log_utils.py

# -*- coding:utf-8 -*-
import os, os.path as osp
import time def strftime(t=None):
return time.strftime("%Y%m%d-%H%M%S", time.localtime(t or time.time())) #################
# Logging
#################
import logging
from logging.handlers import TimedRotatingFileHandler
logging.basicConfig(format="[ %(asctime)s][%(module)s.%(funcName)s] %(message)s") DEFAULT_LEVEL = logging.INFO
DEFAULT_LOGGING_DIR = osp.join("logs", "gcforest")
fh = None def init_fh():
global fh
if fh is not None:
return
if DEFAULT_LOGGING_DIR is None:
return
if not osp.exists(DEFAULT_LOGGING_DIR): os.makedirs(DEFAULT_LOGGING_DIR)
logging_path = osp.join(DEFAULT_LOGGING_DIR, strftime() + ".log")
fh = logging.FileHandler(logging_path)
fh.setFormatter(logging.Formatter("[ %(asctime)s][%(module)s.%(funcName)s] %(message)s")) def update_default_level(defalut_level):
global DEFAULT_LEVEL
DEFAULT_LEVEL = defalut_level def update_default_logging_dir(default_logging_dir):
global DEFAULT_LOGGING_DIR
DEFAULT_LOGGING_DIR = default_logging_dir def get_logger(name="HandWrittenPractice", level=None):
level = level or DEFAULT_LEVEL
logger = logging.getLogger(name)
logger.setLevel(level)
init_fh()
if fh is not None:
logger.addHandler(fh)
return logger

Train

python chinese_write_detection.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000

Validation

python chinese_write_detection.py --mode=validation

Inference

python chinese_write_detection.py --mode=inference

tensorflow创建cnn网络进行中文手写文字识别的更多相关文章

  1. Atitit s2018.2 s2 doc list on home ntpc.docx  \Atiitt uke制度体系 法律 法规 规章 条例 国王诏书.docx \Atiitt 手写文字识别 讯飞科大 语音云.docx \Atitit 代码托管与虚拟主机.docx \Atitit 企业文化 每日心灵 鸡汤 值班 发布.docx \Atitit 几大研发体系对比 Stage-Gat

    Atitit s2018.2 s2 doc list on home ntpc.docx \Atiitt uke制度体系  法律 法规 规章 条例 国王诏书.docx \Atiitt 手写文字识别   ...

  2. 5 TensorFlow入门笔记之RNN实现手写数字识别

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  3. Tensorflow项目实战一:MNIST手写数字识别

    此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...

  4. CNN完成mnist数据集手写数字识别

    # coding: utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data d ...

  5. TensorFlow(十):卷积神经网络实现手写数字识别以及可视化

    上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  6. Google机器学习笔记(七)TF.Learn 手写文字识别

    转载请注明作者:梦里风林 Google Machine Learning Recipes 7 官方中文博客 - 视频地址 Github工程地址 https://github.com/ahangchen ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  9. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

随机推荐

  1. 黑客已经瞄准5G网络,如何防止LTE网络攻击?

    黑客是如何攻击5G网络?即使5G进行大规模应用,LTE技术会被淘汰吗?那么我们应该如何防止LTE网络攻击? 5G-网络黑客 即将推出的5G网络也可能容易受到这些攻击,来自中国网络安全研究人员表示,尽管 ...

  2. 未来HTML5的发展前景如何?黑客专家是这样回答的

    如果你想进军IT行业,如果你准备好掌握一项新技术,那么就选择HTML5.近日,我们采访了国内知名网络黑客安全专家郭盛华,帮助您了解当今最重要的技术.在本篇文章中,黑客安全专家郭盛华回答了有关HTML5 ...

  3. springboot 整合Druid

    Druid连接池监控配置 1) 引入依赖 <!-- https://mvnrepository.com/artifact/com.alibaba/druid --> <depende ...

  4. 034:DTL常用过滤器(3)

    default过滤器: 如果值被评估为 False .比如 [] , "" , None , {} 等这些在 if 判断中为 False 的值,都会使用 default 过滤器提供 ...

  5. Test 6.24 T2 集合

    问题描述 有一个可重集合,一开始只有一个元素 0. 你可以进行若干轮操作,每轮你需要对集合中每个元素 x 执行以下三种操作之一: 将 x 变为 x+1; 选择两个非负整数 y,z 满足 y+z=x , ...

  6. Spark在MaxCompute的运行方式

    一.Spark系统概述 左侧是原生Spark的架构图,右边Spark on MaxCompute运行在阿里云自研的Cupid的平台之上,该平台可以原生支持开源社区Yarn所支持的计算框架,如Spark ...

  7. 【HDOJ6614】AND Minimum Spanning Tree(签到)

    题意:给定标号从1到n的n个点,链接两个点x,y的代价为x AND y,求最小生成树总代价与满足代价最小的前提下字典序最小的方案 n<=2e5 思路: #include<bits/stdc ...

  8. Win7 64位系统 注册 ocx控件

    32位系统注册ocx就不谈了.网上一搜一大把.下面说下win7 64位 旗舰版下如果注册ocx控件    1.首先复制 XXXX.OCX文件到“C:\Windows\SysWOW64”目录. (XXX ...

  9. 后端技术杂谈7:OpenStack的基石KVM

    Qemu,KVM,Virsh傻傻的分不清 本文转载自Itweet的博客 本系列文章将整理到我在GitHub上的<Java面试指南>仓库,更多精彩内容请到我的仓库里查看 https://gi ...

  10. 五一清北学堂培训之数据结构(Day 1&Day 2)

    Day 1 前置知识: 二进制 2.基本语法 3.时间复杂度 正片       1.枚举 洛谷P1046陶陶摘苹果  入门题没什么好说的 判断素数:从2枚举到sqrt(n),若出现n的因子,则n是合数 ...