# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
 # -*- coding: utf-8 -*-
"""
Created on Sun May 26 22:53:52 2019 @author: jiangshan
"""
#A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict class LeNet5(nn.Module):
"""
Input - 1x32x32
C1 - 6@28x28 (5x5 kernel)
relu
S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
C3 - 16@10x10 (5x5 kernel, complicated shit)
relu
S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
C5 - 120@1x1 (5x5 kernel)
F6 - 84
relu
F7 - 10 (Output)
"""
def __init__(self):
super(LeNet5, self).__init__() self.convnet = nn.Sequential(OrderedDict([
('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
('relu1', nn.ReLU()),
('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
('relu3', nn.ReLU()),
('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
('relu5', nn.ReLU())
])) self.fc = nn.Sequential(OrderedDict([
('f6', nn.Linear(120, 84)),
('relu6', nn.ReLU()),
('f7', nn.Linear(84, 10)),
('sig7', nn.LogSoftmax(dim=-1))
])) def forward(self, img):
output = self.convnet(img)
output = output.view(img.size(0), -1)
output = self.fc(output)
return output viz = visdom.Visdom()
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8) net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3) cur_batch_win = None
cur_batch_win_opts = {
'title': 'Epoch Loss Trace',
'xlabel': 'Batch Number',
'ylabel': 'Loss',
'width': 1200,
'height': 600,
} def train(epoch):
global cur_batch_win
net.train()
loss_list, batch_list = [], []
for i, (images, labels) in enumerate(data_train_loader):
optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss_list.append(loss.detach().cpu().item())
batch_list.append(i+1) if i % 10 == 0:
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item())) # Update Visualization
if viz.check_connection():
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
win=cur_batch_win, name='current_batch_loss',
update=(None if cur_batch_win is None else 'replace'),
opts=cur_batch_win_opts)
loss.backward()
optimizer.step() def test():
net.eval()
total_correct = 0
avg_loss = 0.0
for i, (images, labels) in enumerate(data_test_loader):
output = net(images)
avg_loss += criterion(output, labels).sum()
pred = output.detach().max(1)[1]
total_correct += pred.eq(labels.view_as(pred)).sum() avg_loss /= len(data_test)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) def train_and_test(epoch):
train(epoch)
test() def main():
for e in range(1, 16):
train_and_test(e) if __name__ == '__main__':
main()

先开启visdom 进行可视化

python -m visdom.server

运行程序

python LeNet-5_main.py

打开浏览器查看live graph

http://localhost:8097

LeNet-5 pytorch+torchvision+visdom的更多相关文章

  1. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  2. Linux服务器配置GPU版本的pytorch Torchvision TensorFlow

    最近在Linux服务器上配置项目,项目需要使用GPU版本的pytorch和TensorFlow,而且该项目内会同时使用TensorFlow的GPU和CPU. 在服务器上装环境,如果重新开始,就需要下载 ...

  3. pytorch的visdom启动不了、蓝屏

    pytorch的visdom启动不了.蓝屏 问题描述:我是在ubuntu16.04上用python3.5安装的visdom.可是启动是蓝屏:在网上找了很久的解决方案:有三篇博文: https://bl ...

  4. 云服务器搭建anaconda pytorch torchvision

    (因为在普通用户上安装有些权限问题安装出错,所以我在root用户下相对容易安装,但是anaconda官网说可以直接在普通用户下安装,不过,在root下安装,其他用户也是能用的. 访问Anaconda官 ...

  5. Pytorch Torchvision Transform

    Torchvision.Transforms Transforms包含常用图像转换操作.可以使用Compose将它们链接在一起. 此外,还有torchvision.transforms.functio ...

  6. pytorch torchvision.ImageFolder的使用

    参考:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/ torchvision.dataset ...

  7. pytorch torchvision对图像进行变换

    class torchvision.transforms.Compose(转换) 多个将transform组合起来使用. class torchvision.transforms.CenterCrop ...

  8. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

  9. PyTorch常用代码段整理合集

    PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...

随机推荐

  1. (7)打造简单OS-加载内核

    一.简要说明 我们在第五讲[(5)打造简单OS-进入保护模式]中的mbr.S 汇编文件有段这样的代码 mov eax, 0x2 ; 起始扇区lba地址,从间隔第二个扇区开始 mov bx, 0x900 ...

  2. P2320 [HNOI2006]鬼谷子的钱袋——进制(没事就别看这个了)

    就是n可以被1到n/2的所有数表示出来: 我一开始写了个把二进制数里的1拿出来,但是WA了两个点: 分治? 好多人说数据有问题,我也不知道,也不想知道: %:include<cstdio> ...

  3. 【模板】杜教筛(Sum)

    传送门 Description 给定一个正整数\(N(N\le2^{31}-1)\) 求 \[ans1=\sum_{i=1}^n \varphi(i)\] \[ans_2=\sum_{i=1}^n \ ...

  4. Java Heap dump文件分析工具jhat简介

    jhat 是Java堆分析工具(Java heap Analyzes Tool). 在JDK6u7之后成为标配. 使用该命令需要有一定的Java开发经验,官方不对此工具提供技术支持和客户服务. 用法: ...

  5. Spring MVC 三大组件

    ㈠ HandlerMapping 处理器映射(一般通过扫描包配置) 通过处理器映射,你可以将Web 请求映射到正确的处理器 Controller 上.当接收到请求时,DispactherServlet ...

  6. MySQL 中视图和表的区别以及联系是什么?

    两者的区别: (1)视图是已经编译好的 SQL 语句,是基于 SQL 语句的结果集的可视化的表,而表不是. (2)视图没有实际的物理记录,而基本表有. (3)表是内容,视图是窗口. (4)表占用物理空 ...

  7. javascript 的垃圾回收机制讲一下

    定义:指一块被分配的内存既不能使用,又不能回收,直到浏览器进程结束. 像 C 这样的编程语言,具有低级内存管理原语,如 malloc()和 free().开发人员使用这些原语显式地对操作系统的内存进行 ...

  8. Logback 输出 JPA SQL日志 到文件

    Logback 输出 JPA SQL日志 到文件 使用Spring Boot 配置 JPA 时可以指定如下配置在控制台查看执行的SQL语句 spring.jpa.show-sql=true Sprin ...

  9. LC 991. Broken Calculator

    On a broken calculator that has a number showing on its display, we can perform two operations: Doub ...

  10. Spring Boot属性配置&自定义属性配置

    一.修改默认配置 例1.spring boot 开发web应用的时候,默认tomcat的启动端口为8080,如果需要修改默认的端口,则需要在application.properties 添加以下记录: ...