用pytorch搭建一个DNN网络,主要目的是熟悉pytorch的使用

"""
test Function
""" import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms class simpleNet(nn.Module):
''' define the 3 layers Network'''
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(simpleNet, self).__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
self.layer3 = nn.Linear(n_hidden_2, out_dim) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x class Activation_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Activation_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1), nn.ReLU(True)
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True)
)
self.layer3 = nn.Sequential(
nn.Linear(n_hidden_2, out_dim)
) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1) ,nn.ReLU(True)
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1,n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True)
)
self.layer3 = nn.Sequential(
nn.Linear(n_hidden_2, out_dim)
) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x batch_size = 64
learning_rate = 1e-2
num_epochs = 20 data_tf = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
) train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) model = Batch_Net(28*28, 300, 100, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # Training
epoch = 0
for data in train_loader:
img, label = data
img = img.view(img.size(0), -1)
img = Variable(img)
label = Variable(label)
out = model(img)
loss = criterion(out, label)
print_loss = loss.data.item() optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch += 1
if epoch % 50 == 0:
print('epoch:{}, loss:{:.4f}'.format(epoch, loss.data.item())) # Evalue
model.eval() # turn the model to test pattern, do some as dropout, batchNormalization
eval_loss = 0
eval_acc = 0
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
img = Variable(img) # 前向传播不需要保留缓存,释放掉内存,节约内存空间
label = Variable(label)
out = model(img)
loss = criterion(out, label) eval_loss += loss.data * label.size(0)
_, pred = torch.max(out, 1) # 返回每一行中最大值和对应的索引
s = (pred == label)
num_correct = (pred == label).sum()
eval_acc += num_correct.data.item()
print('Test Loss:{:6f}, Acc:{:.6f}'.format(eval_loss/len(test_dataset), eval_acc/len(test_dataset)))

pytorch-MNIST数据模型测试的更多相关文章

  1. Tensorflow MNIST 数据集测试代码入门

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50614444 测试代码已上传至GitH ...

  2. 深入MNIST code测试

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50624471 依照教程:深入MNIST ...

  3. highway network及mnist数据集测试

    先说结论:没经过仔细调参,打不开论文所说代码链接(fq也没打开),结果和普通卷积网络比较没有优势.反倒是BN对网络起着非常重要的作用,达到了99.17%的测试精度(训练轮数还没到过拟合). 论文为&l ...

  4. mxnet卷积神经网络训练MNIST数据集测试

    mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...

  5. 如何使用Pytorch迅速实现Mnist数据及分类器

    一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...

  6. Caffe初试(二)windows下的cafee训练和测试mnist数据集

    一.mnist数据集 mnist是一个手写数字数据库,由Google实验室的Corinna Cortes和纽约大学柯朗研究院的Yann LeCun等人建立,它有60000个训练样本集和10000个测试 ...

  7. 使用xshell+xmanager+pycharm搭建pytorch远程调试开发环境

    1. 相关软件版本 xshell: xmanager: pycharm: pycharm破解服务器:https://jetlicense.nss.im/ 2. 将相应的软件安装(pojie好) a&g ...

  8. Pytorch学习之源码理解:pytorch/examples/mnists

    Pytorch学习之源码理解:pytorch/examples/mnists from __future__ import print_function import argparse import ...

  9. [源码解析] PyTorch 分布式(4)------分布式应用基础概念

    [源码解析] PyTorch 分布式(4)------分布式应用基础概念 目录 [源码解析] PyTorch 分布式(4)------分布式应用基础概念 0x00 摘要 0x01 基本概念 0x02 ...

随机推荐

  1. 【redis持久化】redis持久化理解

    1.以下内容仅为个人理解和总结,仅供参考,万万不可全盘真信,内容会进行实时改进和修正 2.redis持久化: 参考链接1.https://redis.io/topics/persistence  -- ...

  2. 理解Java注解类型

    一. 理解Java注解 注解本质是一个继承了Annotation的特殊接口,其具体实现类是Java运行时生成的动态代理类.而我们通过反射获取注解时,返回的是Java运行时生成的动态代理对象$Proxy ...

  3. 【ThinkPHP】ThinkPHP环境的安装与配置

    ThinkPHP是一个免费开源的,快速.简单的面向对象的轻量级PHP开发框架. 严格来说,ThinkPHP无需安装过程,这里所说的安装其实就是把ThinkPHP框架放入WEB运行环境(前提是你的WEB ...

  4. [转]Kqueue与epoll机制

    首先介绍阻塞与非阻塞:阻塞是个什么概念呢?比如某个时候你在等快递,但是你不知道快递什么时候过来,而且你没有别的事可以干(或者说接下来的事要等快递来了才能做):那么你可以去睡觉了,因为你知道快递把货送来 ...

  5. Spring Framework 5.x 学习专栏

    Spring Framework 5.0 入门篇 Spring构建REST Web Service 消费一个RESTful Web Service 事务管理 Spring使用JDBC访问关系数据 任务 ...

  6. Atitit s2018.5 s5 doc list on com pc.docx  v2

    Atitit s2018.5 s5  doc list on com pc.docx  Acc  112237553.docx Acc Acc  112237553.docx Acc baidu ne ...

  7. atitit 各分公司ceo cao行政经理职责.docx

    1.1. 人员招募--分公司高层人员招募(每月招募四五人吧,每周一人平均) 1 1.2. 组织架构优化 1 1.3. 制度建设  健全并完善分公司内部管理机构设置,优化分公司业务管理流程: 1 1.4 ...

  8. 每天学习一个命令:find 查找文件

    查找的动作在平时使用的频率也还是很高的,所以知道并用好 find 这个命令也很重要.find 命令顾名思义,就是搜索特定文件夹内的文件. 基本使用 最基本的使用 find [path] [expres ...

  9. mac中安装lua5.1.5

    lua有一个工具lua-releng( https://github.com/openresty/openresty-devel-utils/blob/master/lua-releng) 用来检测代 ...

  10. Xcode 常用代码段

    weak_shortcut /** <#注释#> */ @property(nonatomic,weak) <#class#> *<#name#>; copy_sh ...