深度学习之 cnn 进行 CIFAR10 分类
深度学习之 cnn 进行 CIFAR10 分类
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()
import torch as t
import torch.nn as nn
import torch.nn.functional as F
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5, 0.5, 0.5)),
])
# 下载数据
trainset = tv.datasets.CIFAR10(root=".",train=True, download=True, transform=transform)
trainloader = t.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testset = tv.datasets.CIFAR10('.', train=False, download=True, transform=transform)
testloader = t.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
from torch import optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)
from torch.autograd import Variable
for epoch in range(2):
running_loss = 0.0
for i,data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.data[0]
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
# 测试
correct = 0
total = 0
for data in testloader:
images, labels = data
outputs = net(Variable(images))
# print(outputs.data)
_, predicted = t.max(outputs.data, 1)
print(outputs.data,_, predicted)
total += labels.size(0)
correct += (predicted == labels).sum()
print('10000张测式中: %d %%' % (100 * correct / total) )
深度学习之 cnn 进行 CIFAR10 分类的更多相关文章
- [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...
- 【深度学习】CNN 中 1x1 卷积核的作用
[深度学习]CNN 中 1x1 卷积核的作用 最近研究 GoogLeNet 和 VGG 神经网络结构的时候,都看见了它们在某些层有采取 1x1 作为卷积核,起初的时候,对这个做法很是迷惑,这是因为之前 ...
- 深度学习入门: CNN与LSTM(RNN)
1. 理解深度学习与CNN: 台湾李宏毅教授的入门视频<一天搞懂深度学习>:https://www.bilibili.com/video/av16543434/ 其中对CNN算法的矩阵卷积 ...
- 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...
- 深度学习笔记(一):logistic分类【转】
本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...
- PyTorch中使用深度学习(CNN和LSTM)的自动图像标题
介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...
- keras框架下的深度学习(二)二分类和多分类问题
本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...
- 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型
代码仓库: https://github.com/brandonlyg/cute-dl 目标 上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...
- Python深度学习案例1--电影评论分类(二分类问题)
我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...
随机推荐
- 初探WebSocket
初探WebSocket node websocket socket.io 我们平常开发的大部分web页面都是主动'拉'的形式,如果需要更新页面内容,则需要"刷新"一个,但Slack ...
- c++函数常用
isalnum 判断一个字符是否是字符类的数字或字母isalpha 判断一个字符是否是字母isblank 判断一个字符是否是空白字符(空格,水平制表符,TAB)iscntrl 判断一个控制符(ASCI ...
- POI读写Excel-操作包含合并单元格操作
在上篇博客中写到关于Excel操作解析成相关的类,下面将写入一种Excel对Excel表格读取和写入. 对于Excel表格操作,最重要的是创建workBook.其操作顺序是: 1.获得WorkBook ...
- 发个2012年用java写的一个控制台小游戏
时间是把杀狗刀 突然发现了12年用java写的控制台玩的一个文字游戏,有兴趣的可以下载试试哈汪~ 里面难点当时确实遇到过,在计算倒计时的时候用了多线程,当时还写了好久才搞定.很怀念那个时间虽然不会做游 ...
- 简单docker镜像修改方式
• 创建Dockerfile,文件内容如下: FROM nps:v1.0.1 ENTRYPOINT ["/usr/bin/init.sh"] • 启动基础镜像:docker run ...
- 分享基于Qt5开发的一款故障波形模拟软件
背景介绍 这是一款采用Qt5编写的用于生成故障模拟波形的软件.生成的波形数据用于下发到终端机器生成对应的故障类型,用于培训相关设备维护人员的故障排查技能.因此,在这款软件中实现了故障方案管理.故障波形 ...
- 微信公众号的localStorage的大坑
业务流程是:工厂端分享一个邀请合作的二维码,商户这边用手机扫一扫后,关注微信公众号(已关注的老用户自动进入公众号)然后进入到公众号在面板上收到消息,合作邀请(图文字有点不对,请忽略!) 接下来,在点击 ...
- (floyd)佛洛伊德算法
Floyd–Warshall(简称Floyd算法)是一种著名的解决任意两点间的最短路径(All Paris Shortest Paths,APSP)的算法.从表面上粗看,Floyd算法是一个非常简单的 ...
- 如何在C#项目中使用NHibernate
现代化大型项目通常使用独立的数据库来存储数据,其中以采用关系型数据库居多.用于开发项目的高级语言(C#.Java等)是面向对象的,而关系型数据库是基于关系的,两者之间的沟通需要一种转换,也就是对象/关 ...
- Anagram
Anagram poj-1256 题目大意:给你n个字符串,求每一个字符串所有字符的全排列,按照顺序输出所有全排列. 注释:每一个字符长度小于13,且字符排序的顺序是:A<a<B<b ...