LeNet-5 pytorch+torchvision+visdom
# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
# -*- coding: utf-8 -*-
"""
Created on Sun May 26 22:53:52 2019 @author: jiangshan
"""
#A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict class LeNet5(nn.Module):
"""
Input - 1x32x32
C1 - 6@28x28 (5x5 kernel)
relu
S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
C3 - 16@10x10 (5x5 kernel, complicated shit)
relu
S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
C5 - 120@1x1 (5x5 kernel)
F6 - 84
relu
F7 - 10 (Output)
"""
def __init__(self):
super(LeNet5, self).__init__() self.convnet = nn.Sequential(OrderedDict([
('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
('relu1', nn.ReLU()),
('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
('relu3', nn.ReLU()),
('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
('relu5', nn.ReLU())
])) self.fc = nn.Sequential(OrderedDict([
('f6', nn.Linear(120, 84)),
('relu6', nn.ReLU()),
('f7', nn.Linear(84, 10)),
('sig7', nn.LogSoftmax(dim=-1))
])) def forward(self, img):
output = self.convnet(img)
output = output.view(img.size(0), -1)
output = self.fc(output)
return output viz = visdom.Visdom()
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8) net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3) cur_batch_win = None
cur_batch_win_opts = {
'title': 'Epoch Loss Trace',
'xlabel': 'Batch Number',
'ylabel': 'Loss',
'width': 1200,
'height': 600,
} def train(epoch):
global cur_batch_win
net.train()
loss_list, batch_list = [], []
for i, (images, labels) in enumerate(data_train_loader):
optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss_list.append(loss.detach().cpu().item())
batch_list.append(i+1) if i % 10 == 0:
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item())) # Update Visualization
if viz.check_connection():
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
win=cur_batch_win, name='current_batch_loss',
update=(None if cur_batch_win is None else 'replace'),
opts=cur_batch_win_opts)
loss.backward()
optimizer.step() def test():
net.eval()
total_correct = 0
avg_loss = 0.0
for i, (images, labels) in enumerate(data_test_loader):
output = net(images)
avg_loss += criterion(output, labels).sum()
pred = output.detach().max(1)[1]
total_correct += pred.eq(labels.view_as(pred)).sum() avg_loss /= len(data_test)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) def train_and_test(epoch):
train(epoch)
test() def main():
for e in range(1, 16):
train_and_test(e) if __name__ == '__main__':
main()
先开启visdom 进行可视化
python -m visdom.server
运行程序
python LeNet-5_main.py
打开浏览器查看live graph
LeNet-5 pytorch+torchvision+visdom的更多相关文章
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- Linux服务器配置GPU版本的pytorch Torchvision TensorFlow
最近在Linux服务器上配置项目,项目需要使用GPU版本的pytorch和TensorFlow,而且该项目内会同时使用TensorFlow的GPU和CPU. 在服务器上装环境,如果重新开始,就需要下载 ...
- pytorch的visdom启动不了、蓝屏
pytorch的visdom启动不了.蓝屏 问题描述:我是在ubuntu16.04上用python3.5安装的visdom.可是启动是蓝屏:在网上找了很久的解决方案:有三篇博文: https://bl ...
- 云服务器搭建anaconda pytorch torchvision
(因为在普通用户上安装有些权限问题安装出错,所以我在root用户下相对容易安装,但是anaconda官网说可以直接在普通用户下安装,不过,在root下安装,其他用户也是能用的. 访问Anaconda官 ...
- Pytorch Torchvision Transform
Torchvision.Transforms Transforms包含常用图像转换操作.可以使用Compose将它们链接在一起. 此外,还有torchvision.transforms.functio ...
- pytorch torchvision.ImageFolder的使用
参考:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/ torchvision.dataset ...
- pytorch torchvision对图像进行变换
class torchvision.transforms.Compose(转换) 多个将transform组合起来使用. class torchvision.transforms.CenterCrop ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
- PyTorch常用代码段整理合集
PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...
随机推荐
- [luogu] zpl的数学题1
https://www.luogu.org/problemnew/show/U16887 $f[1] + f[2] + f[3] + .... + f[n] = f[n + 2] - 1$ 矩阵快速幂 ...
- [转]Linux下的常见信号总结
转自 https://www.cnblogs.com/gaorong/p/6430905.html 在linux下有很多信号,按可靠性分为可靠信号和非可靠信号,按时间分为实时信号和非实时信号,linu ...
- [HNOI2004]L语言 字典树 记忆化搜索
[HNOI2004]L语言 字典树 记忆化搜索 给出\(n\)个字符串作为字典,询问\(m\)个字符串,求每个字符串最远能匹配(字典中的字符串)到的位置 容易想到使用字典树维护字典,然后又发现不能每步 ...
- 【原创】go语言学习(五)函数详解1
目录 1.函数介绍 2.多返回值和可变参数 3.defer语句 4.内置函数介绍 1.函数介绍 1.1定义: 有输⼊入.有输出,⽤用来执⾏行行⼀一个指定任务的代码块. func functionnam ...
- Thread 相关函数和属性
t=Thread(target=func) # 启动子线程t.start() # 阻塞子线程,待子线程结束后,再往下执行t.join() # 判断线程是否在执行状态,在执行返回True,否则返回Fal ...
- [APIO2017]商旅——分数优化+floyd+SPFA判负环+二分答案
题目链接: [APIO2017]商旅 枚举任意两个点$(s,t)$,求出在$s$买入一个物品并在$t$卖出的最大收益. 新建一条从$s$到$t$的边,边权为最大收益,长度为原图从$s$到$t$的最短路 ...
- commit 没有提交图片,但是出现了commit的修改
.gitignore里面写上 image/cache/ 就好了
- mysql的select语句
参考: https://www.cnblogs.com/xiaoshen666/p/10824117.html https://www.cnblogs.com/zouwangblog/archive/ ...
- [RK3288] 外接USB设备出现丢数
CPU:RK3288 系统:Android 5.1 主板外接 USB 接口的外设,经常会出现丢数的现象,这种问题在很多 USB 接口的外设上都遇到过,例如:USB读卡器.USB扫描枪等 有一个共同点是 ...
- IIS部署常见错误
1.404.17 2.402.2 3.401.3 4.未能加载文件或程序集“System.Data.SQLite”或它的某一个依赖项”的解决方法