LeNet-5 pytorch+torchvision+visdom
# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
# -*- coding: utf-8 -*-
"""
Created on Sun May 26 22:53:52 2019 @author: jiangshan
"""
#A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict class LeNet5(nn.Module):
"""
Input - 1x32x32
C1 - 6@28x28 (5x5 kernel)
relu
S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
C3 - 16@10x10 (5x5 kernel, complicated shit)
relu
S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
C5 - 120@1x1 (5x5 kernel)
F6 - 84
relu
F7 - 10 (Output)
"""
def __init__(self):
super(LeNet5, self).__init__() self.convnet = nn.Sequential(OrderedDict([
('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
('relu1', nn.ReLU()),
('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
('relu3', nn.ReLU()),
('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
('relu5', nn.ReLU())
])) self.fc = nn.Sequential(OrderedDict([
('f6', nn.Linear(120, 84)),
('relu6', nn.ReLU()),
('f7', nn.Linear(84, 10)),
('sig7', nn.LogSoftmax(dim=-1))
])) def forward(self, img):
output = self.convnet(img)
output = output.view(img.size(0), -1)
output = self.fc(output)
return output viz = visdom.Visdom()
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8) net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3) cur_batch_win = None
cur_batch_win_opts = {
'title': 'Epoch Loss Trace',
'xlabel': 'Batch Number',
'ylabel': 'Loss',
'width': 1200,
'height': 600,
} def train(epoch):
global cur_batch_win
net.train()
loss_list, batch_list = [], []
for i, (images, labels) in enumerate(data_train_loader):
optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss_list.append(loss.detach().cpu().item())
batch_list.append(i+1) if i % 10 == 0:
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item())) # Update Visualization
if viz.check_connection():
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
win=cur_batch_win, name='current_batch_loss',
update=(None if cur_batch_win is None else 'replace'),
opts=cur_batch_win_opts)
loss.backward()
optimizer.step() def test():
net.eval()
total_correct = 0
avg_loss = 0.0
for i, (images, labels) in enumerate(data_test_loader):
output = net(images)
avg_loss += criterion(output, labels).sum()
pred = output.detach().max(1)[1]
total_correct += pred.eq(labels.view_as(pred)).sum() avg_loss /= len(data_test)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) def train_and_test(epoch):
train(epoch)
test() def main():
for e in range(1, 16):
train_and_test(e) if __name__ == '__main__':
main()
先开启visdom 进行可视化
python -m visdom.server
运行程序
python LeNet-5_main.py
打开浏览器查看live graph
LeNet-5 pytorch+torchvision+visdom的更多相关文章
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- Linux服务器配置GPU版本的pytorch Torchvision TensorFlow
最近在Linux服务器上配置项目,项目需要使用GPU版本的pytorch和TensorFlow,而且该项目内会同时使用TensorFlow的GPU和CPU. 在服务器上装环境,如果重新开始,就需要下载 ...
- pytorch的visdom启动不了、蓝屏
pytorch的visdom启动不了.蓝屏 问题描述:我是在ubuntu16.04上用python3.5安装的visdom.可是启动是蓝屏:在网上找了很久的解决方案:有三篇博文: https://bl ...
- 云服务器搭建anaconda pytorch torchvision
(因为在普通用户上安装有些权限问题安装出错,所以我在root用户下相对容易安装,但是anaconda官网说可以直接在普通用户下安装,不过,在root下安装,其他用户也是能用的. 访问Anaconda官 ...
- Pytorch Torchvision Transform
Torchvision.Transforms Transforms包含常用图像转换操作.可以使用Compose将它们链接在一起. 此外,还有torchvision.transforms.functio ...
- pytorch torchvision.ImageFolder的使用
参考:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/ torchvision.dataset ...
- pytorch torchvision对图像进行变换
class torchvision.transforms.Compose(转换) 多个将transform组合起来使用. class torchvision.transforms.CenterCrop ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
- PyTorch常用代码段整理合集
PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...
随机推荐
- LOJ2541. 「PKUWC2018」猎人杀 [概率,分治NTT]
传送门 思路 好一个神仙题qwq 首先,发现由于一个人死之后分母会变,非常麻烦,考虑用某种方法定住分母. 我们稍微改一改游戏规则:一个人被打死时只打个标记,并不移走,也就是说可以被打多次但只算一次.容 ...
- 如何在微信小程序中国引入fontawesome字体图标
fontawesome官网地址:http://fontawesome.dashgame.com/ 一. 二. 下载之后的字体图标 找到 文件中的如下图.ttf文件 三. 在https://transf ...
- Windows使用Latex
目录 安装Texlive 安装TeXstudio 编写简单的文章 教程 安装Texlive 到清华大学开源软件镜像站下载Texlive2019.iso文件 下载之后,如果有光驱就装载,没有的话就解压. ...
- Matlab基础:关于图像的基本操作
-- %% 学习目标:学习关于图像的基本操作 %% 通过抖动来增强图像的的色彩对比度 clear all; close all; I = imread('cameraman.tif');%读取灰度图像 ...
- SQL-W3School-函数:SQL MAX() 函数
ylbtech-SQL-W3School-函数:SQL MAX() 函数 1.返回顶部 1. MAX() 函数 MAX 函数返回一列中的最大值.NULL 值不包括在计算中. SQL MAX() 语法 ...
- 判断Activty是否在前台运行
/** * 判断某个界面是否在前台 * * @param context * @param className 某个界面名称 * */ public static boolean isActivity ...
- angular中子组件通过@Output 触发父组件的方 法
1. 子组件引入 Output 和 EventEmitter import { Component, OnInit ,Input,Output,EventEmitter} from '@angular ...
- ASP程序中调用Now()总显示“上午”和“下午”,如何解决?
ASP程序中调用Now()总显示这样的格式:“2007-4-20 下午 06:06:38”,我要的正确格式为“2007-4-20 18:06:38”,我已经通过控制面板==>区域和语言选项==& ...
- 阶段5 3.微服务项目【学成在线】_day16 Spring Security Oauth2_12-SpringSecurityOauth2研究-JWT研究-生成私钥和公钥
3.6.3 JWT入门 Spring Security 提供对JWT的支持,本节我们使用Spring Security 提供的JwtHelper来创建JWT令牌,校验JWT令牌 等操作. 3.6.3. ...
- 阶段5 3.微服务项目【学成在线】_day16 Spring Security Oauth2_10-SpringSecurityOauth2研究-校验令牌&刷新令牌
3.5校验令牌 Spring Security Oauth2提供校验令牌的端点,如下: Get: http://localhost:40400/auth/oauth/check_token?token ...