下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。

MNIST数据集的训练集有60000个样本:

Python代码

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os # 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data() # 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255 # 构建神经网络模型
def create_model():
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax')) # 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model # 训练模型并保存
def train_and_save_model():
model = create_model()
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
model.save('mnist_model.h5') # 保存训练历史记录
with open('training_history.json', 'w') as f:
json.dump(history.history, f) # 加载模型和历史记录
def load_model_and_history():
model = tf.keras.models.load_model('mnist_model.h5') with open('training_history.json', 'r') as f:
history = json.load(f) return model, history # 评估模型
def evaluate_model(model):
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy: {}".format(test_acc)) # 可视化训练过程
def plot_training_history(history):
plt.plot(history['accuracy'], label='accuracy')
plt.plot(history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show() # 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
train_and_save_model() model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)

代码解释

  1. 加载MNIST数据集

    mnist = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  2. 预处理数据

    • 将图像数据调整为 (28, 28, 1) 的形状。
    • 将像素值标准化为 [0, 1] 之间。
    train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
  3. 构建神经网络模型

    • 使用 Sequential 模型,按顺序添加层。
    • 添加卷积层、池化层、全连接层。
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
  4. 编译模型

    • 使用 adam 优化器,损失函数为 sparse_categorical_crossentropy,评估指标为 accuracy
    model.compile(optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
  5. 训练模型

    • 训练模型5个epochs,并使用验证数据集评估模型性能。
    history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
  6. 评估模型

    • 在测试集上评估模型性能,并打印测试准确率。
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Test accuracy: {test_acc}")
  7. 可视化训练过程

    • 绘制训练和验证准确率随epoch变化的曲线。
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.show()

通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。

手写数字识别-使用TensorFlow构建和训练一个简单的神经网络的更多相关文章

  1. Softmax用于手写数字识别(Tensorflow实现)-个人理解

    softmax函数的作用   对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...

  2. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  3. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  4. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

  5. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  6. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  7. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

  8. python-积卷神经网络全面理解-tensorflow实现手写数字识别

    首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...

  9. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  10. Tensorflow实战 手写数字识别(Tensorboard可视化)

    一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...

随机推荐

  1. 本地项目文件上传到git

    初始化项目: git init 与服务器项目关联:git remote add origin "http://**************************/r/ruoyi.git&q ...

  2. 记一次bug排除心得

    问题背景 要做一个需求,大概是检测到某输入重启,于是写一个demo调试一下 c语言程序,交叉编译后在adb shell下运行 思路 用 am 命令直接重启 我们先手动验证一下,发现这个设备不支持am命 ...

  3. ubuntu安装之后要做的10件事

    部分内容整理自网络,如果侵权还请联系 基础配置 换源 换源 [ubuntu清华源镜像站] ctrl+click,进入镜像站链接,选择合适的版本,将镜像地址粘贴到本地文件里,对于: <24.04的 ...

  4. 使用Visual Studio分析.NET Dump

    前言 内存泄漏和高CPU使用率是在日常开发中经常遇到的问题,它们可能会导致应用程序性能下降甚至崩溃.今天我们来讲讲如何使用Visual Studio 2022分析.NET Dump,快速找到程序内存泄 ...

  5. xpath提取不到值(iframe嵌套)的问题

    爬取http://xgj.xiangyang.gov.cn/zwgk/gkml/?itemid=2471的时候遇到frame嵌套,内部的a标签获取不到. 网上也有人遇到了同样的问题.https://b ...

  6. 如何基于R包做GO分析?实现秒出图

    GO分析 基因本体论(Gene Ontology, GO)是一个用于描述基因和基因产品属性的标准术语体系.它提供了一个有组织的方式来表示基因在生物体内的各种角色.基因本体论通常从三个层面对基因进行描述 ...

  7. RSS 解析:全球内容分发的利器及使用技巧

    使用 RSS 可以将最新的网络内容从一个网站分发到全球数千个其他网站. RSS 允许快速浏览新闻和更新. RSS 文档示例 <?xml version="1.0" encod ...

  8. word文档生成视频,自动配音、背景音乐、自动字幕,另类创作工具

    简介 不同于别的视频创作工具,这个工具创作视频只需要在word文档中打字,插入图片即可.完事后就能获得一个带有配音.字幕.背景音乐.视频特效滤镜的优美作品. 这种不要门槛,没有技术难度的视频创作工具, ...

  9. jwt 加密和解密demo

    jwt 加密和解密demo JSON Web Token(JWT)是一个非常轻巧的规范.这个规范允许我们使用 JWT 在用户和服务器之间传递安全可靠的信息.导入jar <dependency&g ...

  10. java8 lambda Predicate示例

    import java.util.Arrays; import java.util.List; import java.util.function.Predicate; public class Pr ...