数据集下载

这一部分比较简单,就不过多赘述了,把代码粘贴到自己的项目文件里,运行一下就可以下载了。

from torchvision import datasets, transforms

# 定义数据转换,将数据转换为张量并进行标准化
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 标准化
]) # 下载和加载训练集
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 下载和加载测试集
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

该代码运行效果如下图:

下载好的数据集可以将其中的图片保存,这里给出两个代码,分别采用matplotlib库和opencv库进行可视化和保存

# matplotlib
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os # 创建保存图片的文件夹
os.makedirs('mnist_images', exist_ok=True) # 定义数据转换(转换为Tensor)
transform = transforms.Compose([
transforms.ToTensor()
]) # 下载 MNIST 数据集
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True) # 获取前100张图片
for i in range(100):
image, _ = dataset[i]
image = image.squeeze() # 去掉单通道维度 plt.imshow(image, cmap='gray')
plt.axis('off') # 不显示坐标轴
plt.savefig(f'mnist_images/image_{i+1}.png', bbox_inches='tight', pad_inches=0) print("前 100 张图片已保存为 PNG 文件")
# opencv
import cv2
import numpy as np
from torchvision import datasets, transforms
import os # 创建保存图片的文件夹
os.makedirs('mnist_images', exist_ok=True) # 定义数据转换(转换为Tensor)
transform = transforms.Compose([
transforms.ToTensor()
]) # 下载 MNIST 数据集
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True) # 获取前100张图片
for i in range(100):
image, _ = dataset[i]
image = image.squeeze().numpy() # 去掉单通道维度,并转换为 numpy 数组 # OpenCV 需要图像的范围在 0 到 255 之间
image = (image * 255).astype(np.uint8) # 保存图像
cv2.imwrite(f'mnist_images/image_{i+1}.png', image) # 可选:显示图像
cv2.imshow('image_1', image)
cv2.waitKey(0)
cv2.destroyAllWindows() # 如果你启用了显示图像的功能,记得在最后调用以下代码:
cv2.destroyAllWindows()

网络训练

该代码运行效果如下图

import torch

'''=============== 数据集部分 ==============='''
# 定义数据转换
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
]) # 打开已经下载的训练集和测试集
from torchvision.datasets import MNIST
train_dataset = MNIST(root='./data', train=True, download=False, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=False, transform=transform) # 创建数据加载器
batch_size = 256
from torch.utils.data import random_split
from torch.utils.data import DataLoader # 将数据集分割为6000和剩余的数据
train_size = 6000
train_subset, _ = random_split(train_dataset, [train_size, len(train_dataset) - train_size]) train_loader = DataLoader(dataset=train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) '''=============== 网络定义 ==============='''
# 初始化网络
from net import CNN
net = CNN() # 初始化优化器、学习率调整器、评价函数
import torch.nn as nn
from torch import optim
learning_rate = 0.001 # 0.05 ~ 1e-6
weight_decay = 1e-4 # 1e-2 ~ 1e-8
momentum = 0.8 # 0.3~0.9
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
criterion = nn.CrossEntropyLoss() # GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device=device) '''=============== 模型信息管理 ==============='''
model_path = None if model_path is not None:
net.load_state_dict(torch.load(model_path, map_location=device)) '''=============== 网络训练 ==============='''
epochs = 50 def train(net, device, optimizer, scheduler, criterion):
net.train() for epoch in range(epochs):
epoch_loss = 0 # 集损失置0 for images, labels in train_loader:
''' ========== 数据获取和转移 ========== '''
images = images.to(device=device, dtype=torch.float32)
labels = labels.to(device=device, dtype=torch.long) ''' ========== 数据操作 ========== '''
outputs = net(images)
# net.forward()
loss = criterion(outputs, labels)
epoch_loss += loss.detach().item() ''' ========== 反向传播 ========== '''
optimizer.zero_grad()
loss.requires_grad_(True)
loss.backward() # 梯度裁剪
for param in net.parameters():
if param.grad is not None and param.grad.nelement() > 0:
nn.utils.clip_grad_value_([param], clip_value=0.1) optimizer.step() epoch_loss /= len(train_loader) # 输出每个 epoch 的平均损失
print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss}') train(net, device, optimizer, scheduler, criterion) '''=============== 网络保存 ==============='''
from datetime import datetime # 获取当前时间
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
model_path = f'./output/final_model_{current_time}.pth' # 保存模型
torch.save(net.state_dict(), model_path)

卷积神经网络CNN实战:MINST手写数字识别——数据集下载与网络训练的更多相关文章

  1. 卷积神经网络应用于tensorflow手写数字识别(第三版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  2. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

  3. 实现手写数字识别(数据集50000张图片)比较3种算法神经网络、灰度平均值、SVM各自的准确率—Jason niu

    对手写数据集50000张图片实现阿拉伯数字0~9识别,并且对结果进行分析准确率, 手写数字数据集下载:http://yann.lecun.com/exdb/mnist/ 首先,利用图片本身的属性,图片 ...

  4. MINST手写数字识别(三)—— 使用antirectifier替换ReLU激活函数

    这是一个来自官网的示例:https://github.com/keras-team/keras/blob/master/examples/antirectifier.py 与之前的MINST手写数字识 ...

  5. TensorFlow 卷积神经网络手写数字识别数据集介绍

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 手写数字识别 接下来将会以 MNIST 数据集为例,使用卷积层和池 ...

  6. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  7. 第三节,CNN案例-mnist手写数字识别

    卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...

  8. MINST手写数字识别(一)—— 全连接网络

    这是一个简单快速入门教程——用Keras搭建神经网络实现手写数字识别,它大部分基于Keras的源代码示例 minst_mlp.py. 1.安装依赖库 首先,你需要安装最近版本的Python,再加上一些 ...

  9. MINST手写数字识别(二)—— 卷积神经网络(CNN)

    今天我们的主角是keras,其简洁性和易用性简直出乎David 9我的预期.大家都知道keras是在TensorFlow上又包装了一层,向简洁易用的深度学习又迈出了坚实的一步. 所以,今天就来带大家写 ...

  10. NN:利用深度学习之神经网络实现手写数字识别(数据集50000张图片)—Jason niu

    import mnist_loader import network training_data, validation_data, test_data = mnist_loader.load_dat ...

随机推荐

  1. 推荐一款基于业务行为驱动开发(BDD)测试框架:Cucumber!

    大家好,我是狂师. 今天给大家介绍一款行为驱动开发测试框架:Cucumber. 1.介绍 Cucumber是一个行为驱动开发(BDD)工具,它结合了文本描述和自动化测试脚本.它使用一种名为Gherki ...

  2. 记录一次EF实体跟踪错误

    记录一次EF实体跟踪错误 前言 在我写文章编辑接口的,出现了一个实体跟踪的错误,详情如下 System.InvalidOperationException: The instance of entit ...

  3. UEFI与inf文件

    UEFI与inf文件 背景 学习高通UEFI中的LCD显示框架,看到有些博客对inf文件进行了介绍,因此整理了这方面的一些入门知识. 参考: https://blog.csdn.net/yunfeng ...

  4. JVM(Java虚拟机)整理(二):排错调优

    前言 上一篇内容:JVM(Java虚拟机)整理(一) Java 内存模型(JMM)详解 声明:本章节转载自 Info 上 深入理解Java内存模型.PDF文档下载 深入理解Java内存模型[程晓明] ...

  5. P9120 题解

    暴力容斥复活之路! \(k=1\) 这个你肯定会. \(k=2\) 大的放上去,小的放下来.简单贪心. \(k=3\) 考虑二分答案. 然后考虑判断是否合法. 令当前答案为 \(val\). 首先钦定 ...

  6. 痞子衡嵌入式:浅聊恩智浦i.MXRT官方SDK里关于串行Flash相关的驱动与例程资源(上篇)

    大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是恩智浦i.MXRT官方SDK里关于串行Flash相关的驱动与例程资源. 经常有同事以及 i.MXRT 客户咨询痞子衡,咱们恩智浦官方 S ...

  7. TypeScript 学习笔记 — 类型补充void,any, tuple ,enum,nerver, Symbol , BigInt ,unknown(三)

    目录 空值void 及(与Null 和 Undefined的区别) 任意值Any 元组类型 枚举类型 常量枚举 never 类型 1. 函数无法到达终点 2.通常校验逻辑的完整性,可以利用 never ...

  8. [oeasy]python0071_字符串类型_str_string_下标运算符_中括号

    回忆上次内容 上次 分辨了 静态类型 语言 动态类型 语言   python 属于 对类型要求 没有那么严格的 动态类型 语言   对 初学者很友好 不过很多时候 也容易 弄不清变量类型   直接 修 ...

  9. CF1363A 题解

    洛谷链接&CF 链接 题目简述 共有 \(T\) 组数据. 对于每组数据,给定 \(n,x\) 和 \(n\) 个数,问是否可以从 \(n\) 个数中选 \(x\) 个使其和为奇数,可以输出 ...

  10. WebAPI规范设计——违RESTful

    本文首先简单介绍了几种API设计风格(RPC.REST.GraphQL),然后根据实现项目经验提出WebAPI规范设计思路,一些地方明显违反了RESTful风格,供大家参考! 一.几种设计风格介绍 1 ...