下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。

MNIST数据集的训练集有60000个样本:

Python代码

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os # 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data() # 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255 # 构建神经网络模型
def create_model():
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax')) # 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model # 训练模型并保存
def train_and_save_model():
model = create_model()
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
model.save('mnist_model.h5') # 保存训练历史记录
with open('training_history.json', 'w') as f:
json.dump(history.history, f) # 加载模型和历史记录
def load_model_and_history():
model = tf.keras.models.load_model('mnist_model.h5') with open('training_history.json', 'r') as f:
history = json.load(f) return model, history # 评估模型
def evaluate_model(model):
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy: {}".format(test_acc)) # 可视化训练过程
def plot_training_history(history):
plt.plot(history['accuracy'], label='accuracy')
plt.plot(history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show() # 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
train_and_save_model() model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)

代码解释

  1. 加载MNIST数据集

    mnist = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  2. 预处理数据

    • 将图像数据调整为 (28, 28, 1) 的形状。
    • 将像素值标准化为 [0, 1] 之间。
    train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
  3. 构建神经网络模型

    • 使用 Sequential 模型,按顺序添加层。
    • 添加卷积层、池化层、全连接层。
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
  4. 编译模型

    • 使用 adam 优化器,损失函数为 sparse_categorical_crossentropy,评估指标为 accuracy
    model.compile(optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
  5. 训练模型

    • 训练模型5个epochs,并使用验证数据集评估模型性能。
    history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
  6. 评估模型

    • 在测试集上评估模型性能,并打印测试准确率。
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Test accuracy: {test_acc}")
  7. 可视化训练过程

    • 绘制训练和验证准确率随epoch变化的曲线。
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.show()

通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。

手写数字识别-使用TensorFlow构建和训练一个简单的神经网络的更多相关文章

  1. Softmax用于手写数字识别(Tensorflow实现)-个人理解

    softmax函数的作用   对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...

  2. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  3. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  4. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

  5. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  6. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  7. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

  8. python-积卷神经网络全面理解-tensorflow实现手写数字识别

    首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...

  9. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  10. Tensorflow实战 手写数字识别(Tensorboard可视化)

    一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...

随机推荐

  1. Git三大区域

    1.工作区 2.暂存区 3.版本库

  2. Prism IoC 依赖注入

    现有2个项目,SinglePageApp是基于Prism创建的WPF项目,框架使用的是Prism.DryIoc,SinglePageApp.Services是C#类库,包含多种服务,下面通过使用Pri ...

  3. Wpf虚拟屏幕键盘

    在Wpf使用虚拟键盘有基于osk和tabtip,后者只能在win8和win10之后电脑使用,而且两者在wpf中调用时都必须提升为管理员权限,实际应用中还是不方便. 今天介绍的方法是使用第三方库oskl ...

  4. Pyomo基础学习笔记:建模组成要素的编写方法

    1.Pyomo 简介 pyomo文档[数学建模]优化模型建模语言 Pyomo 入门教程 - 知乎 (zhihu.com) Pyomo 是基于 Python 的开源软件包,主要功能是建立数学规划模型,包 ...

  5. mysql笔记第一天: 介绍和MySQL编译安装

    一.DBA的工作内容: ![](371eaced-e10b-46d9-89e2-f63f15503bb6_files/9edcd22a-ef2d-4c3e-8474-3049255610db.jpg) ...

  6. NOIP模拟54

    我觉得,不改变也很好. 前言 这题太难了,场上竟然无人切题..(听说别的学校切题的人不少.. T1 选择 解题思路 范围比较小,并且每个边的度也比较小,因此考虑 树形DP+状压 . 大概就是对于每一个 ...

  7. itest(爱测试) 4.5.1 发布,开源BUG 跟踪管理 & 敏捷测试管理软件

    本次发布一共6个更新(其中一个4.5.0的重大BUG,不得不先发布4.5.1).4.5.0中增加ldap 登录支持时,引入一个BUG,新增的itest本地用户不能登录,除非重启.  V4.5.1详情如 ...

  8. Linux Debian安装教程

    Debian 是一个免费的开源操作系统,是最古老的 Linux 发行版之一,于 1993 年由 Ian Murdock 创建.它采用了自由软件协议,并且由志愿者社区维护和支持.Debian 的目标是创 ...

  9. 在线IP归属地查询工具

    在线IP地址归属地查询工具,通过该工具可以查询指定IP的物理地址或域名服务器的IP和物理地址,及所在国家或城市IP归属地,数据为纯真IP. 在线IP归属地查询工具

  10. java8 lambda Predicate示例

    import java.util.Arrays; import java.util.List; import java.util.function.Predicate; public class Pr ...