# -*- coding: utf-8 -*-

import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import numpy as np
import pickle
from PIL import Image mode = "inference"
char_size = 3755
epochs = 5
batch_size = 128
checkpoint_dir = '/aiml/code/'
#train_data_dir = 'D:/Yang/softwares/Spider_ws/WordRecognition/data/train/'
#test_data_dir = 'D:/Yang/softwares/Spider_ws/WordRecognition/data/test/' class DataIterator:
def __init__(self, data_dir):
self.image_names = []
for root, sub_folder, file_list in os.walk(data_dir):
self.image_names += [os.path.join(root, file_path) for file_path in file_list]
random.shuffle(self.image_names)
self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names] @property
def size(self):
return len(self.labels) def input_pipeline(self, batch_size, num_epochs=None):
images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs) labels = input_queue[1]
images_content = tf.read_file(input_queue[0])
images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
new_size = tf.constant([64, 64], dtype=tf.int32)
images = tf.image.resize_images(images, new_size)
image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
min_after_dequeue=10000)
return image_batch, label_batch def build_graph(top_k):
# with tf.device('/cpu:0'):
images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='input_image')
labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch') conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME') flatten = slim.flatten(max_pool_3)
fc1 = slim.fully_connected(flatten, 1024, activation_fn=tf.nn.tanh, scope='fc1')
logits = slim.fully_connected(fc1, char_size, activation_fn=None, scope='output_logit') loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32)) global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step = global_step)
probabilities = tf.nn.softmax(logits)
pred = tf.identity(probabilities, name = 'prediction') return {'images': images,
'labels': labels,
'global_step': global_step,
'train_op': train_op,
'loss': loss,
'accuracy': accuracy} def train():
train_feeder = DataIterator(data_dir=train_data_dir)
test_feeder = DataIterator(data_dir=test_data_dir)
with tf.Session() as sess:
train_images, train_labels = train_feeder.input_pipeline(batch_size)
test_images, test_labels = test_feeder.input_pipeline(batch_size)
graph = build_graph(top_k=1)
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
saver = tf.train.Saver() print (':::Training Start:::')
try:
while not coord.should_stop():
start_time = time.time()
train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
feed_dict = {graph['images']: train_images_batch,
graph['labels']: train_labels_batch}
_, loss_val, step = sess.run(
[graph['train_op'], graph['loss'], graph['global_step']],
feed_dict=feed_dict)
end_time = time.time()
if step % 10 == 1:
print ("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
if step > 200000:
break
if step % 50 == 1:
test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
feed_dict = {graph['images']: test_images_batch,
graph['labels']: test_labels_batch}
accuracy_test = sess.run(
graph['accuracy'],
feed_dict=feed_dict)
print ('===============Eval a batch=======================')
print ('the step {0} test accuracy: {1}'.format(step, accuracy_test))
print ('===============Eval a batch=======================')
if step % 200 == 1:
print ('Save the ckpt of {0}'.format(step))
saver.save(sess, os.path.join(checkpoint_dir, 'my-model'),
global_step=graph['global_step'])
except tf.errors.OutOfRangeError:
print ('==================Train Finished================')
saver.save(sess, os.path.join(checkpoint_dir, 'my-model'), global_step=graph['global_step'])
finally:
coord.request_stop()
coord.join(threads) def new_inference(predict_dir):
saver = tf.train.import_meta_graph( checkpoint_dir + "my-model-164152.meta", clear_devices=True)
image_list = []
new_file_list = []
for root, _, file_list in os.walk(predict_dir):
new_file_list += [file for file in file_list if ".nfs" not in file]
new_file_list.sort(key= lambda x:int(x[:-4]))
for file in new_file_list:
# print (new_file_list)
image = os.path.join(root, file)
temp_image = Image.open(image).convert('L')
temp_image = temp_image.resize((64, 64), Image.ANTIALIAS)
temp_image = np.asarray(temp_image) / 255.0
image_list.append(temp_image)
image_list = np.asarray(image_list)
temp_image = image_list.reshape([len(new_file_list), 64, 64, 1])
with tf.Session() as sess:
saver.restore(sess, checkpoint_dir + "my-model-164152") #读入模型参数
graph = tf.get_default_graph()
op = graph.get_tensor_by_name("prediction:0")
input_tensor = graph.get_tensor_by_name('input_image:0')
probs = sess.run(op,feed_dict = {input_tensor:temp_image})
result = []
for word in probs:
result.append(np.argsort(-word)[:3])
return result def main(): if mode == "train":
train()
if mode == 'inference':
word_dict = pickle.load(open("/aiml/code/word_dict", "rb"))
image_path = '/aiml/data/'
index = new_inference(image_path)
file = open("/aiml/result/result.txt", "w")
# print ("预测文字为: ")
pred_list = []
for i in index:
# print ("最大几率三个:")
# print (word_dict[str(i[0])],word_dict[str(i[1])],word_dict[str(i[2])])
pred_list.append(word_dict[str(i[0])])
file.write(word_dict[str(i[0])]) if __name__ == "__main__":
# tf.app.run()
main()

  

cnn汉字识别 tensorflow demo的更多相关文章

  1. 深度学习之卷积神经网络CNN及tensorflow代码实例

    深度学习之卷积神经网络CNN及tensorflow代码实例 什么是卷积? 卷积的定义 从数学上讲,卷积就是一种运算,是我们学习高等数学之后,新接触的一种运算,因为涉及到积分.级数,所以看起来觉得很复杂 ...

  2. 深度学习之卷积神经网络CNN及tensorflow代码实现示例

    深度学习之卷积神经网络CNN及tensorflow代码实现示例 2017年05月01日 13:28:21 cxmscb 阅读数 151413更多 分类专栏: 机器学习 深度学习 机器学习   版权声明 ...

  3. android应用市场、社区客户端、漫画App、TensorFlow Demo、歌词显示、动画效果等源码

    Android精选源码 MVP架构Android应用市场项目 android刻度盘控件源码 Android实现一个社区客户端 android商品详情页上拉查看详情 基于RxJava+Retrofit2 ...

  4. TensorFlow —— Demo

    import tensorflow as tf g = tf.Graph() # 创建一个Graph对象 在模型中有两个"全局"风格的Variable对象:global_step ...

  5. TensorFlow 在android上的Demo(1)

    转载时请注明出处: 修雨轩陈 系统环境说明: ------------------------------------ 操作系统 : ubunt 14.03 _ x86_64 操作系统 内存: 8GB ...

  6. tensorflow学习之(十)使用卷积神经网络(CNN)分类手写数字0-9

    #卷积神经网络cnn import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #数据包,如 ...

  7. TensorFlow Lite demo——就是为嵌入式设备而存在的,底层调用NDK神经网络API,注意其使用的tf model需要转换下,同时提供java和C++ API,无法使用tflite的见后

    Introduction to TensorFlow Lite TensorFlow Lite is TensorFlow’s lightweight solution for mobile and ...

  8. Caffe、TensorFlow、MXnet三个开源库对比

    库名称 开发语言 支持接口 安装难度(ubuntu) 文档风格 示例 支持模型 上手难易 Caffe c++/cuda c++/python/matlab *** * *** CNN ** MXNet ...

  9. YOLO2:实时目标检测视频教程,视频演示, Android Demo ,开源教学项目,论文。

    实时目标检测和分类 GIF 图: 视频截图: 论文: https://arxiv.org/pdf/1506.02640.pdf https://arxiv.org/pdf/1612.08242.pdf ...

随机推荐

  1. 几种支持动作模型格式的比较(MD2,MD5,sea3d) 【转】

    最近使用了几种不同的模型格式做人物动作的表现,记录一下优缺点   1) MD2 数据内容: 记录了所有动作顶点数据 数据格式: 二进制 动作文件: 动作文件合并在一个模型文件 文件大小: 动作多时很大 ...

  2. python 工具 图片批量合并

    注:代码两处设置 region = (4,3,x-3,y-5) 目的是crop剪去图片的白边,这个可以视情况改变 图片需要命名为   x_1.png   .....这样的格式 #encoding=ut ...

  3. chardet的使用

    http://blog.csdn.net/jy692405180/article/details/52496599

  4. Oracle-31-对视图DML操作

    一.对视图进行DML操作 1.创建一个视图v_person create or replace noforceview v_person as select *from person where id ...

  5. binary-tree-level-order-traversal I、II——输出二叉树的数字序列

    I Given a binary tree, return the level order traversal of its nodes' values. (ie, from left to righ ...

  6. Ajax 跨域难题 - 原生 JS 和 jQuery 的实现对比

    讲解顺序: AJAX 的概念及由来 JS 和 jQuery 中的 ajax 浏览器机制 AJAX 跨域 AJAX 的概念 在讲解 AJAX 的概念之前,我先提一个问题. 这是一个典型的 B/S 模式. ...

  7. C语言函数的递归和调用

    函数记住两点: (1)每个函数运行完才会返回调用它的函数:每个函数运行完才会返回调用它的函数,因此,你可以先看看这个函数不自我调用的条件,也就是fun()中if条件不成立的时候,对吧,不成立的时候就是 ...

  8. 电源滤波电容在PCB中正确的布线方法!

    电源滤波电容在PCB中正确的布线方法! 错误的电源滤波电容布线方法. 1.很多人朋友在设计的时候喜欢加宽这个电源的走,这个是一个很好的方法,但是他们如果一不小心就会忽略电容的布线. 下面的电容布线看起 ...

  9. Ubuntu 16.04下配置Golang开发环境

    安装之前先要明白两个变量,后面介绍安装时,会用这两个变量 GOROOT   , 这是go的工作目录,比如 /home/[替换为你的用户名]/go/work GOPATH    , 这是go的安装目录, ...

  10. svn 命令个

    svn 命令行下常用的几个命令 标签: svnpathdelete工作urlfile 2011-11-28 08:16 128627人阅读 评论(1) 收藏 举报  分类: 版本控制(8)  版权声明 ...