手写数字识别-使用TensorFlow构建和训练一个简单的神经网络
下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。
MNIST
数据集的训练集有60000个样本:
Python代码
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os
# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
# 构建神经网络模型
def create_model():
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
# 训练模型并保存
def train_and_save_model():
model = create_model()
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
model.save('mnist_model.h5')
# 保存训练历史记录
with open('training_history.json', 'w') as f:
json.dump(history.history, f)
# 加载模型和历史记录
def load_model_and_history():
model = tf.keras.models.load_model('mnist_model.h5')
with open('training_history.json', 'r') as f:
history = json.load(f)
return model, history
# 评估模型
def evaluate_model(model):
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy: {}".format(test_acc))
# 可视化训练过程
def plot_training_history(history):
plt.plot(history['accuracy'], label='accuracy')
plt.plot(history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
# 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
train_and_save_model()
model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)
代码解释
加载MNIST数据集:
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
预处理数据:
- 将图像数据调整为
(28, 28, 1)
的形状。 - 将像素值标准化为
[0, 1]
之间。
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
- 将图像数据调整为
构建神经网络模型:
- 使用
Sequential
模型,按顺序添加层。 - 添加卷积层、池化层、全连接层。
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
- 使用
编译模型:
- 使用
adam
优化器,损失函数为sparse_categorical_crossentropy
,评估指标为accuracy
。
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
- 使用
训练模型:
- 训练模型5个epochs,并使用验证数据集评估模型性能。
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
评估模型:
- 在测试集上评估模型性能,并打印测试准确率。
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")
可视化训练过程:
- 绘制训练和验证准确率随epoch变化的曲线。
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。
手写数字识别-使用TensorFlow构建和训练一个简单的神经网络的更多相关文章
- Softmax用于手写数字识别(Tensorflow实现)-个人理解
softmax函数的作用 对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...
- 基于卷积神经网络的手写数字识别分类(Tensorflow)
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- python-积卷神经网络全面理解-tensorflow实现手写数字识别
首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Tensorflow实战 手写数字识别(Tensorboard可视化)
一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...
随机推荐
- sqlerver 报错5120 无法为该请求检索数据 系统找不到指定路径
背景: 数据库mdf文件所在盘符F盘被删除了,也就是文件不存在了,sqlserver管理器打开就报错5120,并且正常路径的数据库也不显示出来. 要让正常的数据库显示出来,就需要删除掉已经没有的数据库 ...
- 电脑临时文件清理2个方法?%temp% cleanmgr
按住电脑快捷键win+R,打开运行框 输入代码 %temp%,点击回车enter或者点击确定,打开temp文件夹[此处存放的都是系统无用的缓存垃圾] 按快捷键Ctrl + A ,点击delete,删除 ...
- Lakehouse 还是 Warehouse?(1/2)
Onehouse 创始人/首席执行官 Vinoth Chandar 于 2022 年 3 月在奥斯汀数据委员会发表了这一重要演讲.奥斯汀数据委员会是"世界上最大的独立全栈数据会议" ...
- HTML——标签语法
<标签名 属性1="属性值1" 属性2="属性值2"-->内容部分</标签名> <标签名 属性1="属性值1" ...
- Git三大区域
1.工作区 2.暂存区 3.版本库
- 记录一次由nginx配置引发出来的一系列的缓存问题
问题描述: 在做一个企业微信的移动端项目时,每次修改代码后并且打包.部署完毕,再次打开页面总是会有上一次的缓存,一开始以为是cookie和webStorage缓存导致的,然后每次清除还是有缓存,后来把 ...
- 阿里云ECS后台CPU占用100%,top却找不到
上周公司阿里云服务器后台报警,CPU占用瞬间飙升到100%: 首先想到使用top命令查询CPU占用详情: 发现进程占用CPU都比较低,在CPU占用一栏发现只有ni的占用比较高. 先了解一下CPU相关监 ...
- work11
1,简述String类中的equals方法与Object类中的equals方法的不同点. /* Object 类 1,它是所有类的一个根类 2,其他类默认继承Object类 常用方法: 1,toStr ...
- Apollo启动配置排查,超时时间的配置
Apollo启动配置排查 1.排查下来是 本地的服务 apollo 配置fake发布到线上去了.2.或者是引用的apollo jar包中指向的apollo服务器地址是否正确. 3.超时时间的配置 ## ...
- 阿里也出手了!Spring CloudAlibaba AI问世了
写在前面 在之前的文章中我们有介绍过SpringAI这个项目.SpringAI 是Spring 官方社区项目,旨在简化 Java AI 应用程序开发, 让 Java 开发者想使用 Spring 开发普 ...