手写数字识别-使用TensorFlow构建和训练一个简单的神经网络
下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。
MNIST数据集的训练集有60000个样本:
Python代码
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os
# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
# 构建神经网络模型
def create_model():
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
# 训练模型并保存
def train_and_save_model():
model = create_model()
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
model.save('mnist_model.h5')
# 保存训练历史记录
with open('training_history.json', 'w') as f:
json.dump(history.history, f)
# 加载模型和历史记录
def load_model_and_history():
model = tf.keras.models.load_model('mnist_model.h5')
with open('training_history.json', 'r') as f:
history = json.load(f)
return model, history
# 评估模型
def evaluate_model(model):
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy: {}".format(test_acc))
# 可视化训练过程
def plot_training_history(history):
plt.plot(history['accuracy'], label='accuracy')
plt.plot(history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
# 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
train_and_save_model()
model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)
代码解释
加载MNIST数据集:
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
预处理数据:
- 将图像数据调整为
(28, 28, 1)的形状。 - 将像素值标准化为
[0, 1]之间。
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
- 将图像数据调整为
构建神经网络模型:
- 使用
Sequential模型,按顺序添加层。 - 添加卷积层、池化层、全连接层。
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
- 使用
编译模型:
- 使用
adam优化器,损失函数为sparse_categorical_crossentropy,评估指标为accuracy。
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
- 使用
训练模型:
- 训练模型5个epochs,并使用验证数据集评估模型性能。
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
评估模型:
- 在测试集上评估模型性能,并打印测试准确率。
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")
可视化训练过程:
- 绘制训练和验证准确率随epoch变化的曲线。
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。
手写数字识别-使用TensorFlow构建和训练一个简单的神经网络的更多相关文章
- Softmax用于手写数字识别(Tensorflow实现)-个人理解
softmax函数的作用 对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...
- 基于卷积神经网络的手写数字识别分类(Tensorflow)
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- python-积卷神经网络全面理解-tensorflow实现手写数字识别
首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Tensorflow实战 手写数字识别(Tensorboard可视化)
一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...
随机推荐
- 在唯一密钥属性“name”设置为“XXX”时,无法添加类型为“add”的重复集合项
我是在调试时,更改了项目url出现的问题,没有改端口号,只是改了"/"后面的地址 这个是我是改哈端口号就好了,改了端口号就重新建立虚拟目录了. 感觉是因为端口号没变,但项目url变 ...
- 创建Django项目和应用
1.创建Django项目 django-admin startproject 项目名称 2.创建应用app python manage.py startapp app名称
- Vue cli之组件的嵌套
前面显示Home.vue页面组件的内容时,我们是在App.vue通过import导入使用的.这个过程就是组件的嵌套使用.那么我们除了App.vue可以导入其他页面以外,也可以通过在Home.vue中导 ...
- GitHub two-factor authentication开启教程
问题描述 最近登录GitHub个人页面动不动就有一个提示框"...... two-factor authentication will be required for your accoun ...
- 通过Webpack搭建react
安装解析react的相关babel和插件 nmp i -D babel-loader @babel/core @babel/preset-react @babel/preset-env 进行loade ...
- 设计模式:命令模式(Command Pattern)及实例
好家伙, 0.什么是命令模式 在软件系统中,"行为请求者"与"行为实现者"通常呈现一种"紧耦合". 但在某些场合,比如要对行为进行&q ...
- chrome edge CORS 允许跨域
edge: edge://flags/#block-insecure-private-network-requests chrome: 在谷歌浏览器地址栏输入"chrome://flags/ ...
- vs2019安装使用Python3.9教程
现在vs2019只支持到Python3.7,如果要使用3.9,需要自己下载Python3.9的包 步骤: 一.在开始菜单中找到Microsoft Store搜索"Python3.9" ...
- Scrapy框架(八)--CrawlSpider
CrawlSpider类,Spider的一个子类 - 全站数据爬取的方式 - 基于Spider:手动请求 - 基于CrawlSpider - CrawlSpider的使用: - 创建一个工程 - cd ...
- 使用shell脚本在Linux中管理Java应用程序
目录 前言 一.目录结构 二.脚本实现 1. 脚本内容 2. 使用说明 2.1 配置脚本 2.2 脚本部署 2.3 操作你的Java应用 总结 前言 在日常开发和运维工作中,管理基于Java的应用程序 ...