# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
 # -*- coding: utf-8 -*-
"""
Created on Sun May 26 22:53:52 2019 @author: jiangshan
"""
#A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict class LeNet5(nn.Module):
"""
Input - 1x32x32
C1 - 6@28x28 (5x5 kernel)
relu
S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
C3 - 16@10x10 (5x5 kernel, complicated shit)
relu
S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
C5 - 120@1x1 (5x5 kernel)
F6 - 84
relu
F7 - 10 (Output)
"""
def __init__(self):
super(LeNet5, self).__init__() self.convnet = nn.Sequential(OrderedDict([
('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
('relu1', nn.ReLU()),
('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
('relu3', nn.ReLU()),
('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
('relu5', nn.ReLU())
])) self.fc = nn.Sequential(OrderedDict([
('f6', nn.Linear(120, 84)),
('relu6', nn.ReLU()),
('f7', nn.Linear(84, 10)),
('sig7', nn.LogSoftmax(dim=-1))
])) def forward(self, img):
output = self.convnet(img)
output = output.view(img.size(0), -1)
output = self.fc(output)
return output viz = visdom.Visdom()
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8) net = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3) cur_batch_win = None
cur_batch_win_opts = {
'title': 'Epoch Loss Trace',
'xlabel': 'Batch Number',
'ylabel': 'Loss',
'width': 1200,
'height': 600,
} def train(epoch):
global cur_batch_win
net.train()
loss_list, batch_list = [], []
for i, (images, labels) in enumerate(data_train_loader):
optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss_list.append(loss.detach().cpu().item())
batch_list.append(i+1) if i % 10 == 0:
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item())) # Update Visualization
if viz.check_connection():
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
win=cur_batch_win, name='current_batch_loss',
update=(None if cur_batch_win is None else 'replace'),
opts=cur_batch_win_opts)
loss.backward()
optimizer.step() def test():
net.eval()
total_correct = 0
avg_loss = 0.0
for i, (images, labels) in enumerate(data_test_loader):
output = net(images)
avg_loss += criterion(output, labels).sum()
pred = output.detach().max(1)[1]
total_correct += pred.eq(labels.view_as(pred)).sum() avg_loss /= len(data_test)
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) def train_and_test(epoch):
train(epoch)
test() def main():
for e in range(1, 16):
train_and_test(e) if __name__ == '__main__':
main()

先开启visdom 进行可视化

python -m visdom.server

运行程序

python LeNet-5_main.py

打开浏览器查看live graph

http://localhost:8097

LeNet-5 pytorch+torchvision+visdom的更多相关文章

  1. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  2. Linux服务器配置GPU版本的pytorch Torchvision TensorFlow

    最近在Linux服务器上配置项目,项目需要使用GPU版本的pytorch和TensorFlow,而且该项目内会同时使用TensorFlow的GPU和CPU. 在服务器上装环境,如果重新开始,就需要下载 ...

  3. pytorch的visdom启动不了、蓝屏

    pytorch的visdom启动不了.蓝屏 问题描述:我是在ubuntu16.04上用python3.5安装的visdom.可是启动是蓝屏:在网上找了很久的解决方案:有三篇博文: https://bl ...

  4. 云服务器搭建anaconda pytorch torchvision

    (因为在普通用户上安装有些权限问题安装出错,所以我在root用户下相对容易安装,但是anaconda官网说可以直接在普通用户下安装,不过,在root下安装,其他用户也是能用的. 访问Anaconda官 ...

  5. Pytorch Torchvision Transform

    Torchvision.Transforms Transforms包含常用图像转换操作.可以使用Compose将它们链接在一起. 此外,还有torchvision.transforms.functio ...

  6. pytorch torchvision.ImageFolder的使用

    参考:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/ torchvision.dataset ...

  7. pytorch torchvision对图像进行变换

    class torchvision.transforms.Compose(转换) 多个将transform组合起来使用. class torchvision.transforms.CenterCrop ...

  8. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

  9. PyTorch常用代码段整理合集

    PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...

随机推荐

  1. 【优化算法】变邻域搜索算法解决0-1背包问题(Knapsack Problem)代码实例 已

    01 前言 经过小编这几天冒着挂科的风险,日日修炼,终于赶在考试周中又给大家更新了一篇干货文章.关于用变邻域搜索解决0-1背包问题的代码.怎样,大家有没有很感动? 02 什么是0-1背包问题? 0-1 ...

  2. ros平台下python脚本控制机械臂运动

    在使用moveit_setup_assistant生成机械臂的配置文件后可以使用roslaunch demo.launch启动demo,在rviz中可以通过拖动机械臂进行运动学正逆解/轨迹规划等仿真运 ...

  3. maven 安装 jar 包

    注意 jar 包路径,和版本号 mvn install:install-file -Dfile=E:\ChromeBox\box\kaptcha-2.3.2\kaptcha-2.3.2.jar -Dg ...

  4. 基于Shiro的登录功能 设计思路

    认证流程 Shiro的认证流程可以看作是个“有窗户黑盒”, 整个流程都有框架控制,对外的入口只有subject.login(token);,这代表“黑盒” 流程中的每一个组件,都可以使用Spring ...

  5. mongodb的状态分析

    1.借助工具 mongostat 分析mongodb运行状况 C:\Users\Administrator>mongostat --help //查看帮助 View live MongoDB p ...

  6. 【java8新特性】日期和时间

    Java 8 (又称为 jdk 1.8) 是 Java 语言开发的一个主要版本. Oracle 公司于 2014 年 3 月 18 日发布 Java 8 ,它支持函数式编程,新的 JavaScript ...

  7. python 时间等待

    #coding=utf- import time t1=time.time() time.sleep() t2=time.time() print(t2-t1) 输出 3.00304102898

  8. 去掉BigDecimal类型变量小数点后多余的零

           业务背景:mysql中A表中的B字段的类型是decimal类型,小数位数是三位,某一条数据的值是3000000,在Java中查询出来的结果是3000000.000,这样显示在页面中不太好 ...

  9. Tkinter 之ListBox列表标签

    一.参数说明 参数 作用 background (bg) 设置背景颜色 borderwidth (bd) 指定 Listbox 的边框宽度,通常是 2 像素 cursor  指定当鼠标在 Listbo ...

  10. 可视化图表库--goJS

    GoJS是Northwoods Software的产品.Northwoods Software创立于1995年,专注于交互图控件和类库.旗下四款产品: GoJS:用于在HTML上创建交互图的纯java ...