深度学习--实战 LeNet5

数据集

数据集选用CIFAR-10的数据集,Cifar-10 是由 Hinton 的学生 Alex Krizhevsky、Ilya Sutskever 收集的一个用于普适物体识别的计算机视觉数据集,它包含 60000 张 32 X 32 的 RGB 彩色图片,总共 10 个分类。其中,包括 50000 张用于训练集,10000 张用于测试集。

模型实现

模型需要继承nn.module

import torch
from torch import nn class Lenet5(nn.Module):
"""
for cifar10 dataset.
"""
def __init__(self):
super(Lenet5,self).__init__() self.conv_unit = nn.Sequential(
#input:[b,3,32,32] ===> output:[b,6,x,x]
#Conv2d(Input_channel:输入的通道数,kernel_channels:卷积核的数量,输出的通道数,kernel_size:卷积核的大小,stride:步长,padding:边缘补足)
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0), #池化
nn.MaxPool2d(kernel_size=2,stride=2,padding=0), #卷积层
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0), #池化
nn.AvgPool2d(kernel_size=2,stride=2,padding=0) #output:[b,16,5,5]
) #flatten #Linear层
self.fc_unit=nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
) #测试卷积输出到全连接层的输入
#tmp = torch.rand(2,3,32,32)
#out = self.conv_unit(tmp)
#print("conv_out:",out.shape) #Loss评价 Cross Entropy Loss 分类 在其中包含一个softmax()操作
#self.criteon = nn.MSELoss() 回归
#self.criteon = nn.CrossEntropyLoss() def forward(self,x):
""" :param x:[b,3,32,32]
:return:
"""
batchsz = x.size(0)
#[b,3,32,32]=>[b,16,5,5]
x = self.conv_unit(x)
#[b,16,5,5]=>[b,16*5*5]
x = x.view(batchsz,16*5*5)
#[b,16*5*5]=>[b,10]
logits = self.fc_unit(x) return logits # [b,10]
# pred = F.softmax(logits,dim=1) 这步在CEL中包含了,所以不需要再写一次
#loss = self.criteon(logits,y) def main():
net = Lenet5()
tmp = torch.rand(2,3,32,32)
out = net(tmp)
print("lenet_out:",out.shape) if __name__ == '__main__':
main()

训练与测试

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from lenet5 import Lenet5
import torch.nn.functional as F
from torch import nn,optim def main(): batch_size = 32
epochs = 1000
learn_rate = 1e-3 #导入图片,一次只导入一张
cifer_train = datasets.CIFAR10('cifar',train=True,transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]),download=True) #加载图
cifer_train = DataLoader(cifer_train,batch_size=batch_size,shuffle=True) #导入图片,一次只导入一张
cifer_test = datasets.CIFAR10('cifar',train=False,transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]),download=True) #加载图
cifer_test = DataLoader(cifer_test,batch_size=batch_size,shuffle=True) #iter迭代器,__next__()方法可以获得数据
x, label = iter(cifer_train).__next__()
print("x:",x.shape,"label:",label.shape)
#x: torch.Size([32, 3, 32, 32]) label: torch.Size([32]) device = torch.device('cuda')
model = Lenet5().to(device)
print(model)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(),lr=learn_rate) for epoch in range(epochs):
model.train()
for batchidx,(x,label) in enumerate(cifer_train):
x,label = x.to(device),label.to(device) logits = model(x)
#logits:[b,10] loss = criteon(logits,label) #backprop
optimizer.zero_grad() #梯度清零
loss.backward()
optimizer.step() #梯度更新
#
print(epoch,loss.item()) model.eval()
with torch.no_grad():
#test
total_correct = 0
total_num = 0
for x,label in cifer_test:
x,label = x.to(device),label.to(device)
#[b,10]
logits = model(x)
#[b]
pred =logits.argmax(dim=1) #[b] vs [b] => scalar tensor
total_correct += torch.eq(pred,label).float().sum().item()
total_num += x.size(0) acc = total_correct/total_num
print("epoch:",epoch,"acc:",acc) if __name__ == '__main__':
main()

深度学习--实战 LeNet5的更多相关文章

  1. 深度学习实战篇-基于RNN的中文分词探索

    深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...

  2. 学习Keras:《Keras快速上手基于Python的深度学习实战》PDF代码+mobi

    有一定Python和TensorFlow基础的人看应该很容易,各领域的应用,但比较广泛,不深刻,讲硬件的部分可以作为入门人的参考. <Keras快速上手基于Python的深度学习实战>系统 ...

  3. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  4. 『深度应用』NLP机器翻译深度学习实战课程·零(基础概念)

    0.前言 深度学习用的有一年多了,最近开始NLP自然处理方面的研发.刚好趁着这个机会写一系列NLP机器翻译深度学习实战课程. 本系列课程将从原理讲解与数据处理深入到如何动手实践与应用部署,将包括以下内 ...

  5. 『深度应用』NLP机器翻译深度学习实战课程·壹(RNN base)

    深度学习用的有一年多了,最近开始NLP自然处理方面的研发.刚好趁着这个机会写一系列NLP机器翻译深度学习实战课程. 本系列课程将从原理讲解与数据处理深入到如何动手实践与应用部署,将包括以下内容:(更新 ...

  6. TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN

    前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...

  7. 【神经网络与深度学习】深度学习实战——caffe windows 下训练自己的网络模型

    1.相关准备 1.1 手写数字数据集 这篇博客上有.jpg格式的图片下载,附带标签信息,有需要的自行下载,博客附带百度云盘下载地址(手写数字.jpg 格式):http://blog.csdn.net/ ...

  8. Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理

    前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ...

  9. 一箭N雕:多任务深度学习实战

    1.多任务学习导引 多任务学习是机器学习中的一个分支,按1997年综述论文Multi-task Learning一文的定义:Multitask Learning (MTL) is an inducti ...

  10. TensorFlow深度学习实战---图像识别与卷积神经网络

    全连接层网络结构:神经网络每两层之间的所有结点都是有边相连的. 卷积神经网络:1.输入层 2.卷积层:将神经网络中的每一个小块进行更加深入地分析从而得到抽象程度更高的特征. 3 池化层:可以认为将一张 ...

随机推荐

  1. DBeaver通过phoenix连接云主机的hbase

    准备 1.云主机上已经安装好jdk.hadoop.hbase.zookeeper.phoenix,并且在主机上测试连接成功.可参考 https://blog.csdn.net/shangxindeku ...

  2. Excel 去除合并并保留原值的办法

    部分Excel中,对行进行了合并.这个方便展示,但是筛选后数据展示会出现问题,需要去除合并,并在每行中保留原来的值. 1.先选择整行,并"取消单元格合并" 操作后出现大量的空值行. ...

  3. 将 Sql Server 表信息 C# 对象化 小工具_ 张光荣 的 正能量

    注: a.此程序所得到的结果是根据本人个人习惯生成,所以,肯定不完全适合所有人使用,重在想法...然后个人根据个人需求作出更进...b.程序中可能会出现数据库连接的错误提示[原因概是在sql 连接过程 ...

  4. Spyglass CDC工具使用(三)

    最近一直在搞CDC (clock domain crossing) 方面的事情,现在就CDC的一些知识点进行总结. 做CDC检查使用的是Spyglass工具.以下内容转载自:Spyglass之CDC检 ...

  5. webpack 3/4踩坑,我太难了,从安装、卸载、到使用,各相应的版本号,sass-loader报错-版本的原因,webpack -v 不识别,没卸载干净

     -先说卸载: wabpack@4对应的每个插件的版本号都在最后 1 全局安装的话,npm uninstall webpack -g 有时候并不能卸载干净, 2 webpack -v 可判断是否安装成 ...

  6. 修改密码 MVC

    控制器site public function actionPassword(){ $model = new PasswordForm(); /*判断请求属性 if ($request->isA ...

  7. .net Core使用Knife4jUI更换Swagger皮肤

    Knife4j的前身是swagger-bootstrap-ui,前身swagger-bootstrap-ui是一个纯swagger-ui的ui皮肤项目 官网实战指南:https://doc.xiaom ...

  8. 集群与iptables

    Iptables 五链四表执行关系如图所示,容器环境最常用的就是filter和nat表 加上各种自定义的链插入到各个环节,拦截流量做各种控制 filter表:匹配数据包以进行过滤 nat表:修改数据包 ...

  9. Redis Stream Commands 命令学习-1 XADD XRANGE XREVRANGE

    概况 A Redis stream is a data structure that acts like an append-only log. You can use streams to reco ...

  10. redis.clients.jedis.exceptions.JedisConnectionException: Failed connecting to "xxxxx"

    Java 连接 Redis所遇问题 1. 检查Linux是否关闭防火墙,或对外开放redis默认端口6379 关闭防火墙. systemctl stop firewalld 对外开放端口.firewa ...