数据集

数据集下载MNIST

首先读取数据集, 并打印相关信息

包括

  • 图像的数量, 形状
  • 像素的最大, 最小值
  • 以及看一下第一张图片
path = 'MNIST/mnist.npz'
with np.load(path, allow_pickle=True) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test'] print(f'dataset info: shape: {x_train.shape}, {y_train.shape}')
print(f'dataset info: max: {x_train.max()}')
print(f'dataset info: min: {x_train.min()}') print("A sample:")
print("y_train: ", y_train[0])
# print("x_train: \n", x_train[0])
show_pic = x_train[0].copy()
show_pic = cv2.resize(show_pic, (28 * 10, 28 * 10))
cv2.imshow("A image sample", show_pic)
key = cv2.waitKey(0)
# 按 q 退出
if key == ord('q'):
cv2.destroyAllWindows()
print("show demo over")

转换为tf 数据集的格式, 并进行归一化

# convert to tf tensor
x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) // 255.
x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) // 255.
dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset_train = dataset_train.batch(batch_size).repeat(class_num)

定义网络

在这里定义一个简单的全连接网络

def build_simple_net():
net = Sequential([
layers.Dense(256, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(class_num)
])
net.build(input_shape=(None, 28 * 28))
# net.summary()
return net

训练

使用 SGD 优化器进行训练

def train(print_info_step=250):
net = build_simple_net()
# 优化器
optimizer = optimizers.SGD(lr=0.01)
# 计算准确率
acc = metrics.Accuracy() for step, (x, y) in enumerate(dataset_train):
with tf.GradientTape() as tape:
# [b, 28, 28] => [b, 784]
x = tf.reshape(x, (-1, 28 * 28))
# [b, 784] => [b, 10]
out = net(x)
# [b] => [b, 10]
y_onehot = tf.one_hot(y, depth=class_num)
# [b, 10]
loss = tf.square(out - y_onehot)
# [b]
loss = tf.reduce_sum(loss) / batch_size # 反向传播
acc.update_state(tf.argmax(out, axis=1), y)
grads = tape.gradient(loss, net.trainable_variables)
optimizer.apply_gradients(zip(grads, net.trainable_variables)) if acc.result() >= 0.90:
net.save_weights(save_path)
print(f'final acc: {acc.result()}, total step: {step}')
break if step % print_info_step == 0:
print(f'step: {step}, loss: {loss}, acc: {acc.result().numpy()}')
acc.reset_states() if step % 500 == 0 and step != 0:
print('save model')
net.save_weights(save_path)

验证

验证在测试集的模型效果, 这里仅取出第一张进行验证

def test_dataset():
net = build_simple_net()
# 加载模型
net.load_weights(save_path)
# 拿到测试集第一张图片
pred_image = x_test[0]
pred_image = tf.reshape(pred_image, (-1, 28 * 28))
pred = net.predict(pred_image)
# print(pred)
print(f'pred: {tf.argmax(pred, axis=1).numpy()}, label: {y_test[0]}')

应用

分割手写数字, 并进行逐一识别

  • 先将图像二值化
  • 找到轮廓
  • 得到数字的坐标
  • 转为模型的需要的输入格式, 并进行识别
  • 显示
def split_number(img):
result = []
net = build_simple_net()
# 加载模型
net.load_weights(save_path) image = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2GRAY)
ret, thresh = cv2.threshold(image, 127, 255, 0)
contours, hierarchy = cv2.findContours(thresh, 1, 2)
for cnt in contours[:-1]:
x, y, w, h = cv2.boundingRect(cnt) image = img[y:y+h, x:x+w]
image = cv2.resize(image, (28, 28)) pred_image = tf.convert_to_tensor(image, dtype=tf.float32) / 255.
pred_image = tf.reshape(pred_image, (-1, 28 * 28))
pred = net.predict(pred_image)
out = tf.argmax(pred, axis=1).numpy()
result = [out[0]] + result
img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2) cv2.imshow("demo", img)
print(result)
k = cv2.waitKey(0)
# 按 q 退出
if k == ord('q'):
pass
cv2.destroyAllWindows()

效果

单数字



多数字



附录

所有代码, 文件 tf2_mnist.py

import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Sequential, optimizers, metrics # 屏蔽通知信息和警告信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 每批几张图片
batch_size = 2
# 类别数
class_num = 10
# 保存模型的路径
save_path = "./models/mnist.ckpt"
# 展示样例
show_demo = False
# 验证测试集
evaluate_dataset = False
# 是否训练
run_train = False
# 图片路径, 仅用于 detect_image(), 当为False时不识别
image_path = 'images/36.png' path = 'MNIST/mnist.npz'
with np.load(path, allow_pickle=True) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test'] if show_demo:
print(f'dataset info: shape: {x_train.shape}, {y_train.shape}')
print(f'dataset info: max: {x_train.max()}')
print(f'dataset info: min: {x_train.min()}') print("A sample:")
print("y_train: ", y_train[0])
# print("x_train: \n", x_train[0])
show_pic = x_train[0].copy()
show_pic = cv2.resize(show_pic, (28 * 10, 28 * 10))
cv2.imshow("A image sample", show_pic)
key = cv2.waitKey(0)
if key == ord('q'):
cv2.destroyAllWindows()
print("show demo over") # convert to tf tensor
x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) // 255.
x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) // 255.
dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset_train = dataset_train.batch(batch_size).repeat(class_num) def build_simple_net():
net = Sequential([
layers.Dense(256, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(class_num)
])
net.build(input_shape=(None, 28 * 28))
# net.summary()
return net def train(print_info_step=250):
net = build_simple_net()
# 优化器
optimizer = optimizers.SGD(lr=0.01)
# 计算准确率
acc = metrics.Accuracy() for step, (x, y) in enumerate(dataset_train):
with tf.GradientTape() as tape:
# [b, 28, 28] => [b, 784]
x = tf.reshape(x, (-1, 28 * 28))
# [b, 784] => [b, 10]
out = net(x)
# [b] => [b, 10]
y_onehot = tf.one_hot(y, depth=class_num)
# [b, 10]
loss = tf.square(out - y_onehot)
# [b]
loss = tf.reduce_sum(loss) / batch_size # 反向传播
acc.update_state(tf.argmax(out, axis=1), y)
grads = tape.gradient(loss, net.trainable_variables)
optimizer.apply_gradients(zip(grads, net.trainable_variables)) if acc.result() >= 0.90:
net.save_weights(save_path)
print(f'final acc: {acc.result()}, total step: {step}')
break if step % print_info_step == 0:
print(f'step: {step}, loss: {loss}, acc: {acc.result().numpy()}')
acc.reset_states() if step % 500 == 0 and step != 0:
print('save model')
net.save_weights(save_path) def test_dataset():
net = build_simple_net()
# 加载模型
net.load_weights(save_path)
# 拿到测试集第一张图片
pred_image = x_test[0]
pred_image = tf.reshape(pred_image, (-1, 28 * 28))
pred = net.predict(pred_image)
# print(pred)
print(f'pred: {tf.argmax(pred, axis=1).numpy()}, label: {y_test[0]}') def split_number(img):
result = []
net = build_simple_net()
# 加载模型
net.load_weights(save_path) image = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2GRAY)
ret, thresh = cv2.threshold(image, 127, 255, 0)
contours, hierarchy = cv2.findContours(thresh, 1, 2)
for cnt in contours[:-1]:
x, y, w, h = cv2.boundingRect(cnt) image = img[y:y+h, x:x+w]
image = cv2.resize(image, (28, 28)) pred_image = tf.convert_to_tensor(image, dtype=tf.float32) / 255.
pred_image = tf.reshape(pred_image, (-1, 28 * 28))
pred = net.predict(pred_image)
out = tf.argmax(pred, axis=1).numpy()
result = [out[0]] + result
img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2) cv2.imshow("demo", img)
print(result)
k = cv2.waitKey(0)
if k == ord('q'):
pass
cv2.destroyAllWindows() if __name__ == '__main__':
if run_train:
train()
elif evaluate_dataset:
test_dataset()
elif image_path:
image = cv2.imread(image_path)
# detect_image(image)
split_number(image)

linux-基于tensorflow2.x的手写数字识别-基于MNIST数据集的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. 手写数字识别——基于LeNet-5卷积网络模型

    在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...

  3. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  4. 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)

    主要内容: 1.基于多层感知器的mnist手写数字识别(代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  5. TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)

    从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...

  6. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  7. 基于多层感知机的手写数字识别(Tensorflow实现)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  8. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  9. 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)

    博文主要内容有: 1.softmax regression的TensorFlow实现代码(教科书级的代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3 ...

随机推荐

  1. 测试开发【Mock平台】04实战:前后端项目初始化与登录鉴权实现

    [Mock平台]为系列测试开发教程,从0到1编码带你一步步使用Spring Boot 和 Antd React 框架完成搭建一个测试工具平台,希望作为一个实战项目能为你的测试开发学习有帮助. 一.后端 ...

  2. pod和容器(容易混淆的地方)

    在Kubenetes中,所有的容器均在 pod 中运行,一个pod可以承载一个或者多个相关的docker容器(或rkt,以及用户自定义容器),同一个Pod中的容器可以部署在同一个物理机器(可以叫宿主机 ...

  3. Figma禁封中国企业,下一个会是Postman吗?国产软件势在必行!

    ​ "新冷战"蔓延到生产力工具 著名 UI 设计软件 Figma 宣布制裁大疆! 近日,网上流传一份 Figma 发送给大疆的内部邮件.其中写道: "我们了解到,大疆在美 ...

  4. 华为组播实验,PIM-DM组播实验

    一.配置VLAN,并将端口加入VLAN LSW5: system vlan batch 10 to 100 int g 0/0/1 port link-type trunk port trunk al ...

  5. 公私钥 SSH 数字证书

    公私钥 SSH 数字证书 小菜鸟今天买了华为云一台服务器,在使用公私钥远程登录服务器的时候,忘记了相关公钥私钥的原理和一些应用了,今天复习一波做个记录. 相关概念 公钥:公钥用来给数据加密,用公钥加密 ...

  6. Python 交互式解释器的二三事

    学 Python 不知道何时起成了一种风尚.这里,我也随便聊聊跟Python 的交互式解释器的几个有意思的小问题. 如何进入 Python 交互解释器? 当你安装好 Python 后,如何进入 Pyt ...

  7. SpringAOP的源码解析

    一.SpringAOP的概念 一.AOP的基本概念 1.连接点(Joinpoint):可以被增强的方法. 2.切点(Pointcut):实际被增强的方法. 3.通知(Advice)(增强): 3.1. ...

  8. MySQL进阶之表的增删改查

    我的小站 修改表名 ALTER TABLE student RENAME TO stu; TO可以省略. ALTER TABLE 旧表名 RENAME 新表名; 此语句可以修改表的名称,其实一般我们在 ...

  9. Day 005:PAT练习--1047. 编程团体赛(20)

    编程团体赛的规则为:每个参赛队由若干队员组成:所有队员独立比赛:参赛队的成绩为所有队员的成绩和:成绩最高的队获胜.现给定所有队员的比赛成绩,请你编写程序找出冠军队. 输入格式: 输入第一行给出一个正整 ...

  10. HashMap源码理解一下?

    HashMap 是一个散列桶(本质是数组+链表),散列桶就是数据结构里面的散列表,每个数组元素是一个Node节点,该节点又链接着多个节点形成一个链表,故一个数组元素 = 一个链表,利用了数组线性查找和 ...