深度学习之 cnn 进行 CIFAR10 分类

import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()
import torch as t
import torch.nn as nn
import torch.nn.functional as F transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5, 0.5, 0.5)),
]) # 下载数据
trainset = tv.datasets.CIFAR10(root=".",train=True, download=True, transform=transform)
trainloader = t.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testset = tv.datasets.CIFAR10('.', train=False, download=True, transform=transform) testloader = t.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net() from torch import optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)
from torch.autograd import Variable for epoch in range(2):
running_loss = 0.0
for i,data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels) optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward() optimizer.step() running_loss += loss.data[0]
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training') # 测试
correct = 0
total = 0
for data in testloader:
images, labels = data
outputs = net(Variable(images))
# print(outputs.data)
_, predicted = t.max(outputs.data, 1)
print(outputs.data,_, predicted)
total += labels.size(0)
correct += (predicted == labels).sum() print('10000张测式中: %d %%' % (100 * correct / total) )

深度学习之 cnn 进行 CIFAR10 分类的更多相关文章

  1. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  2. 【深度学习】CNN 中 1x1 卷积核的作用

    [深度学习]CNN 中 1x1 卷积核的作用 最近研究 GoogLeNet 和 VGG 神经网络结构的时候,都看见了它们在某些层有采取 1x1 作为卷积核,起初的时候,对这个做法很是迷惑,这是因为之前 ...

  3. 深度学习入门: CNN与LSTM(RNN)

    1. 理解深度学习与CNN: 台湾李宏毅教授的入门视频<一天搞懂深度学习>:https://www.bilibili.com/video/av16543434/ 其中对CNN算法的矩阵卷积 ...

  4. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  5. 深度学习笔记(一):logistic分类【转】

    本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...

  6. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

  7. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  8. 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型

    代码仓库: https://github.com/brandonlyg/cute-dl 目标         上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...

  9. Python深度学习案例1--电影评论分类(二分类问题)

    我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...

随机推荐

  1. 浅析git

    git是什么 简单来说,Git,它是一个快速的 分布式版本控制系统 (Distributed Version Control System,简称 DVCS) . 同传统的 集中式版本控制系统 (Cen ...

  2. 移动端tab滑动和上下拉刷新加载

    移动端tab滑动和上下拉刷新加载 查看demo(请在移动端模式下查看) 查看代码 开发该插件的初衷是,在做一个项目时发现现在实现移动端tab滑动的插件大多基于swiper,swiper的功能太强大而我 ...

  3. 实用的Docker入门

    1 Docker概述 Docker和虚拟机一样,都拥有环境隔离的能力,但它比虚拟机更加轻量级,可以使资源更大化地得到应用.首先来看Docker的架构图: 理解其中几个概念: Client(Docker ...

  4. 关于eclipse安装Genymotion插件的方法

    其实Genymotion的安装方法也有两种:在线安装和离线安装,不过推荐使用在线安装,这个更快.这里我只说在线安装的方法. 打开eclipse,点击help-install new software ...

  5. Intellij +Maven 报错: Dmaven.multiModuleProjectDirectory system property is not set. Check $M2_HOME environment variable and mvn script match.

    在intellij使用 Maven Project 测试时,运行test时看到log里的报错信息: -Dmaven.multiModuleProjectDirectory system propert ...

  6. 深度剖析PHP序列化和反序列化

    序列化 序列化格式 在PHP中,序列化用于存储或传递 PHP 的值的过程中,同时不丢失其类型和结构. 序列化函数原型如下: string serialize ( mixed $value ) 先看下面 ...

  7. oracle 10g数据库下的 XDB组件的重新安装

    emmmm,这是一个不做死就不会的过程!!! 今天在导出数据库时,遇到了报错信息,其实开发说这个报错没关系了,但作死如楼主,一定要把这个错给解决了,然后就有了下面的作死过程. 错误关键字是:packa ...

  8. 基于netcore实现mongodb和ElasticSearch之间的数据实时同步的工具(Mongo2Es)

    基于netcore实现mongodb和ElasticSearch之间的数据实时同步的工具 支持一对一,一对多,多对一和多对多的数据传输方式. 一对一 - 一个mongodb的collection对应一 ...

  9. canvas动画气球

    canvas小球的动画我用canvas画布实现的小球动画效果,可以参考下 我用canvas画布实现的小球动画效果,可以参考下 我用canvas画布实现的小球动画效果,可以参考下 我用canvas画布实 ...

  10. conda下载速度慢——添加源

    清华提供的anaconda镜像,使用以后真的很快!尤其在学校龟速的网络环境里提速非常明显. https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/ TU ...