# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
 # -*- coding: utf-8 -*-
"""
Created on Sun May 26 22:53:52 2019 @author: jiangshan
"""
#A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict class LeNet5(nn.Module):
"""
Input - 1x32x32
C1 - 6@28x28 (5x5 kernel)
relu
S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
C3 - 16@10x10 (5x5 kernel, complicated shit)
relu
S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
C5 - 120@1x1 (5x5 kernel)
F6 - 84
relu
F7 - 10 (Output)
"""
def __init__(self):
super(LeNet5, self).__init__() self.convnet = nn.Sequential(OrderedDict([
('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
('relu1', nn.ReLU()),
('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
('relu3', nn.ReLU()),
('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
('relu5', nn.ReLU())
])) self.fc = nn.Sequential(OrderedDict([
('f6', nn.Linear(120, 84)),
('relu6', nn.ReLU()),
('f7', nn.Linear(84, 10)),
('sig7', nn.LogSoftmax(dim=-1))
])) def forward(self, img):
output = self.convnet(img)
output = output.view(img.size(0), -1)
output = self.fc(output)
return output viz = visdom.Visdom()
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8) net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3) cur_batch_win = None
cur_batch_win_opts = {
'title': 'Epoch Loss Trace',
'xlabel': 'Batch Number',
'ylabel': 'Loss',
'width': 1200,
'height': 600,
} def train(epoch):
global cur_batch_win
net.train()
loss_list, batch_list = [], []
for i, (images, labels) in enumerate(data_train_loader):
optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss_list.append(loss.detach().cpu().item())
batch_list.append(i+1) if i % 10 == 0:
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item())) # Update Visualization
if viz.check_connection():
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
win=cur_batch_win, name='current_batch_loss',
update=(None if cur_batch_win is None else 'replace'),
opts=cur_batch_win_opts)
loss.backward()
optimizer.step() def test():
net.eval()
total_correct = 0
avg_loss = 0.0
for i, (images, labels) in enumerate(data_test_loader):
output = net(images)
avg_loss += criterion(output, labels).sum()
pred = output.detach().max(1)[1]
total_correct += pred.eq(labels.view_as(pred)).sum() avg_loss /= len(data_test)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) def train_and_test(epoch):
train(epoch)
test() def main():
for e in range(1, 16):
train_and_test(e) if __name__ == '__main__':
main()

先开启visdom 进行可视化

python -m visdom.server

运行程序

python LeNet-5_main.py

打开浏览器查看live graph

http://localhost:8097

LeNet-5 pytorch+torchvision+visdom的更多相关文章

  1. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  2. Linux服务器配置GPU版本的pytorch Torchvision TensorFlow

    最近在Linux服务器上配置项目,项目需要使用GPU版本的pytorch和TensorFlow,而且该项目内会同时使用TensorFlow的GPU和CPU. 在服务器上装环境,如果重新开始,就需要下载 ...

  3. pytorch的visdom启动不了、蓝屏

    pytorch的visdom启动不了.蓝屏 问题描述:我是在ubuntu16.04上用python3.5安装的visdom.可是启动是蓝屏:在网上找了很久的解决方案:有三篇博文: https://bl ...

  4. 云服务器搭建anaconda pytorch torchvision

    (因为在普通用户上安装有些权限问题安装出错,所以我在root用户下相对容易安装,但是anaconda官网说可以直接在普通用户下安装,不过,在root下安装,其他用户也是能用的. 访问Anaconda官 ...

  5. Pytorch Torchvision Transform

    Torchvision.Transforms Transforms包含常用图像转换操作.可以使用Compose将它们链接在一起. 此外,还有torchvision.transforms.functio ...

  6. pytorch torchvision.ImageFolder的使用

    参考:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/ torchvision.dataset ...

  7. pytorch torchvision对图像进行变换

    class torchvision.transforms.Compose(转换) 多个将transform组合起来使用. class torchvision.transforms.CenterCrop ...

  8. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

  9. PyTorch常用代码段整理合集

    PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...

随机推荐

  1. TensorFlow(四):手写数字识别

    一:数据集 采用MNIST数据集:-->官网 数据集被分成两部分:60000行的训练数据集和10000行的测试数据集. 其中每一张图片包含28*28个像素,我们把这个数组展开成一个向量,长度为2 ...

  2. LOJ2541. 「PKUWC2018」猎人杀 [概率,分治NTT]

    传送门 思路 好一个神仙题qwq 首先,发现由于一个人死之后分母会变,非常麻烦,考虑用某种方法定住分母. 我们稍微改一改游戏规则:一个人被打死时只打个标记,并不移走,也就是说可以被打多次但只算一次.容 ...

  3. bootstrap中tab切换的使用

    文档地址:https://v3.bootcss.com/javascript/#tabs 简单实例: <!DOCTYPE html> <html lang="en" ...

  4. 前端代码规范-CSS

    CSS规范 一.命名规范BEM(Block Element Modifier) 1.Block name -- 实体名称中的单词之间用连字符分隔(-) HTML <div class=" ...

  5. GFS中元数据的管理

    GFS 元数据(metadata)中包含三部分: GFS元数据的管理方式: 1.文件的命名空间和块的命名空间: 采用持久化的方式. 对于文件和块的命名空间以及从文件到块的映射:通过向操作日志登记修改而 ...

  6. 线程sleep方法的demo详解

    sleep:超时等待指定时间,时间到了之后,重新回到就绪状态,抢到CPU资源后,立马进入运行状态: package com.roocon.thread.t1; public class NewThre ...

  7. Choosing a fast unique identifier (UUID) for Lucene——有时间再看下

    Most search applications using Apache Lucene assign a unique id, or primary key, to each indexed doc ...

  8. oracle 导入导出功能

    关于expdp和impdp 使用EXPDP和IMPDP时应该注意的事项: EXP和IMP是客户端工具程序,它们既可以在客户端使用,也可以在服务端使用. - EXPDP和IMPDP是服务端的工具程序,他 ...

  9. html 获取地址栏信息

    <!DOCTYPE html> <html lang="zh-CN"> <head> <meta charset="UTF-8& ...

  10. javascript的正则表达式总结

    网上正则表达式的教程够多了,但由于javascript的历史比较悠久,也比较古老,因此有许多特性是不支持的.我们先从最简单地说起,文章所演示的正则基本都是perl方式. 元字符 ( [ { \ ^ $ ...