pytorch入门 - 基于AlexNet神经网络实现猫狗大战
基于之前的博客 pytorch入门 - AlexNet神经网络,并借助Kaggle 的 Dogs vs Cats Redux 数据集,实现一个基于 AlexNet 的二分类模型识别猫与狗。
完整流程涵盖数据准备、归一化、模型定义、训练增强、验证并可视化结果。
一、数据集准备与预处理
import os
import shutil
def split_data(ROOT_TRAIN):
cat_dir = os.path.join(ROOT_TRAIN, "cat")
dog_dir = os.path.join(ROOT_TRAIN, "dog")
os.makedirs(cat_dir, exist_ok=True)
os.makedirs(dog_dir, exist_ok=True)
for filename in os.listdir(ROOT_TRAIN):
if filename.startswith("cat") and filename.endswith(".jpg"):
shutil.move(os.path.join(ROOT_TRAIN, filename),
os.path.join(cat_dir, filename))
elif filename.startswith("dog") and filename.endswith(".jpg"):
shutil.move(os.path.join(ROOT_TRAIN, filename),
os.path.join(dog_dir, filename))
优化原因:
分类任务需明确标签与数据的对应关系。通过创建cat/dog子目录并移动图片,可直接利用PyTorch的ImageFolder自动生成标签,避免手动标注错误。
二、数据归一化参数计算
def compute_normalization_params(dataset_path):
transform = transforms.Compose([
transforms.Resize((227, 227)),
transforms.ToTensor()
])
dataset = ImageFolder(dataset_path, transform=transform)
loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False)
# 计算各通道均值和标准差
mean = 0.0
std = 0.0
for data, _ in loader:
batch_samples = data.size(0)
data = data.view(batch_samples, data.size(1), -1)
mean += data.mean(2).sum(0)
std += data.std(2).sum(0)
return mean / len(dataset), std / len(dataset)
关键点:
- 输入尺寸统一:AlexNet要求固定输入尺寸
227×227,需提前调整 - 通道级归一化:对RGB三通道分别计算均值和标准差,消除光照差异影响,加速模型收敛
- 离线计算:避免在训练时实时计算,提升数据加载效率
三、AlexNet模型针对性修改
class AlexNet(nn.Module):
def __init__(self):
super().__init__()
# 修改1:输入通道调整为3 (RGB)
self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4)
# ... (中间层省略)
# 修改2:输出层调整为2分类
self.fc3 = nn.Linear(4096, 2)
# 修改3:降低Dropout比例
self.dropout = nn.Dropout(0.2) # 原论文为0.5
优化逻辑:
- 输入通道适配:原始AlexNet针对ImageNet的1000类设计,此处调整为猫狗二分类,需修改输出层维度为2
- 降低过拟合风险:
- 猫狗数据集(25k张)远小于ImageNet(1400万张)
- 降低Dropout比例(0.5→0.2)可保留更多特征信息,避免模型欠拟合
- 权重初始化:采用Kaiming初始化,适配ReLU激活函数特性,缓解梯度消失
四、数据增强策略
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(10),
transforms.RandomResizedCrop(227, scale=(0.8, 1.0)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.488, 0.455, 0.417],
std=[0.226, 0.221, 0.221])
])
增强目的:
- 提升泛化能力:通过旋转、裁剪、色彩扰动模拟真实场景的多样性,防止模型记忆固定模式
- 克服数据局限:小数据集易导致过拟合,增强后等效扩大数据规模
- 对齐测试环境:测试阶段采用相同预处理,保证输入分布一致性
五、训练过程优化
# 1. 学习率调整
optimizer = optim.Adam(model.parameters(), lr=1e-4) # 原常用值0.001
# 2. 训练-验证集拆分
train_data, val_data = random_split(dataset, [0.8, 0.2])
# 3. 早停机制
if val_acc > best_acc:
best_model_wts = copy.deepcopy(model.state_dict())
关键技术点:
- 低学习率策略:
- 预训练模型特征已较完备,降低学习率(1e-4)避免破坏已有特征
- 微调阶段需精细调整参数,高学习率易导致震荡
- 验证集独立划分:
- 20%数据作为验证集,实时监控模型泛化能力
- 避免测试集参与训练,保证评估客观性
- 混合精度训练(可选):
使用torch.cuda.amp自动混合精度,提升训练速度30%+(需GPU支持)
关键优化总结
| 优化点 | 原始值 | 调整值 | 作用 |
|---|---|---|---|
| 输入通道 | 1 (灰度) | 3 (RGB) | 适配彩色图像 |
| 输出维度 | 1000 | 2 | 二分类需求 |
| Dropout率 | 0.5 | 0.2 | 防欠拟合 |
| 学习率 | 0.001 | 0.0001 | 稳定微调 |
| 数据增强 | 无 | 5种变换 | 提升泛化性 |
pytorch入门 - 基于AlexNet神经网络实现猫狗大战的更多相关文章
- Pytorch实现基于卷积神经网络的面部表情识别(详细步骤)
文章目录 一.项目背景 二.数据处理 1.标签与特征分离 2.数据可视化 3.训练集和测试集 三.模型搭建 四.模型训练 五.完整代码 一.项目背景数据集cnn_train.csv包含人类面部表情的图 ...
- 基于卷积神经网络的面部表情识别(Pytorch实现)----台大李宏毅机器学习作业3(HW3)
一.项目说明 给定数据集train.csv,要求使用卷积神经网络CNN,根据每个样本的面部图片判断出其表情.在本项目中,表情共分7类,分别为:(0)生气,(1)厌恶,(2)恐惧,(3)高兴,(4)难过 ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- 第一章:PyTorch 入门
第一章:PyTorch 入门 1.1 Pytorch 简介 1.1.1 PyTorch的由来 1.1.2 Torch是什么? 1.1.3 重新介绍 PyTorch 1.1.4 对比PyTorch和Te ...
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- PyTorch基础——机器翻译的神经网络实现
一.介绍 内容 "基于神经网络的机器翻译"出现了"编码器+解码器+注意力"的构架,让机器翻译的准确度达到了一个新的高度.所以本次主题就是"基于深度神经 ...
- PyTorch ImageNet 基于预训练六大常用图片分类模型的实战
微调 Torchvision 模型 在本教程中,我们将深入探讨如何对 torchvision 模型进行微调和特征提取,所有这些模型都已经预先在1000类的Imagenet数据集上训练完成.本教程将深入 ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
随机推荐
- 如何在linux中查看cpu信息、机器硬件型号
# cat /proc/cpuinfo | grep name | cut -f2 -d: | uniq -c 8 Intel(R) Xeon(R) CPU E5410 @ 2.33GHz (看到有8 ...
- 【Linux】速查手册
查看Linux系统信息 arch #显示机器的处理器架构(1) uname -m #显示机器的处理器架构(2) uname -r #显示正在使用的内核版本 dmidecode -q #显示硬件系统部件 ...
- Nginx 301永久性转移
我有个域名www.taadis.com, 想永久性转移到taadis.com. 前言 看到很多网友的做法是把taadis.com & www.taadis.com等多个域名放到一个server ...
- Java 中堆和栈的区别是什么?
Java 中堆和栈的区别 Java 中的堆(Heap)和栈(Stack)是两种不同的内存区域,它们有着不同的用途和特点.以下是它们的主要区别: 1. 存储内容 堆:用于存储对象实例以及类的实例变量.所 ...
- 一文速通Python并行计算:08 Python多进程编程-multiprocessing模块、进程的创建命名、获取进程ID、创建守护进程和进程的终止
一文速通 Python 并行计算:08 Python 多进程编程-multiprocessing 模块.进程的创建命名.获取进程 ID.创建守护进程和进程的终止 摘要: 本节介绍 Python 中 m ...
- toRefs 与 toRef 的详解
一.引言在 Vue 3 的响应式系统里,toRefs 和 toRef 是两个实用的工具函数,它们在处理响应式数据时发挥着重要作用.合理运用这两个函数,可以让我们在操作响应式对象和数组时更加灵活,避免一 ...
- 45分钟从零搭建私有MaaS平台和生产级的Qwen3模型服务
今天凌晨,阿里通义团队正式发布了 Qwen3,涵盖六款 Dense 模型(0.6B.1.7B.4B.8B.14B.32B)和两款 MoE 模型(30B-A3B 和 235B-A22B).其中的旗舰模型 ...
- 信息资源管理综合题之“公钥密码体系中同一个用户拥有的密钥特点 和 如何使用密钥加解密才能保证传输数据的机密性 和 如何身份认证 和 CA的作用”
一.公钥密码体制在认证技术中是广泛使用的.结合加密和认证技术知识回答以下问题: 1.公钥密码体系中同一个用户拥有的密钥的特点是什么? 2.假设A.B是公钥密码体系的用户,A向B发送数据,A.B之间如何 ...
- docker容器安装TensorFlow_gpu 版本遇到的坑。。。
运行并挂载docker镜像 docker run -it -v E:/workspace/docker:/dl -p 8888:8888 8d78dd1e1b64 /bin/bash 安装jupyte ...
- C#中的弱引用
弱引用保持的是一个GC"不可见"的引用,是指弱引用不会增加对象的引用计数,也不会阻止垃圾回收器对该对象进行回收.因此,弱引用的目标对象可以被垃圾回收器回收,而弱引用本身不会对垃圾回 ...