用pytorch搭建一个DNN网络,主要目的是熟悉pytorch的使用

"""
test Function
""" import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms class simpleNet(nn.Module):
''' define the 3 layers Network'''
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(simpleNet, self).__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
self.layer3 = nn.Linear(n_hidden_2, out_dim) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x class Activation_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Activation_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1), nn.ReLU(True)
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True)
)
self.layer3 = nn.Sequential(
nn.Linear(n_hidden_2, out_dim)
) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1) ,nn.ReLU(True)
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1,n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True)
)
self.layer3 = nn.Sequential(
nn.Linear(n_hidden_2, out_dim)
) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x batch_size = 64
learning_rate = 1e-2
num_epochs = 20 data_tf = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
) train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) model = Batch_Net(28*28, 300, 100, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # Training
epoch = 0
for data in train_loader:
img, label = data
img = img.view(img.size(0), -1)
img = Variable(img)
label = Variable(label)
out = model(img)
loss = criterion(out, label)
print_loss = loss.data.item() optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch += 1
if epoch % 50 == 0:
print('epoch:{}, loss:{:.4f}'.format(epoch, loss.data.item())) # Evalue
model.eval() # turn the model to test pattern, do some as dropout, batchNormalization
eval_loss = 0
eval_acc = 0
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
img = Variable(img) # 前向传播不需要保留缓存,释放掉内存,节约内存空间
label = Variable(label)
out = model(img)
loss = criterion(out, label) eval_loss += loss.data * label.size(0)
_, pred = torch.max(out, 1) # 返回每一行中最大值和对应的索引
s = (pred == label)
num_correct = (pred == label).sum()
eval_acc += num_correct.data.item()
print('Test Loss:{:6f}, Acc:{:.6f}'.format(eval_loss/len(test_dataset), eval_acc/len(test_dataset)))

pytorch-MNIST数据模型测试的更多相关文章

  1. Tensorflow MNIST 数据集测试代码入门

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50614444 测试代码已上传至GitH ...

  2. 深入MNIST code测试

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50624471 依照教程:深入MNIST ...

  3. highway network及mnist数据集测试

    先说结论:没经过仔细调参,打不开论文所说代码链接(fq也没打开),结果和普通卷积网络比较没有优势.反倒是BN对网络起着非常重要的作用,达到了99.17%的测试精度(训练轮数还没到过拟合). 论文为&l ...

  4. mxnet卷积神经网络训练MNIST数据集测试

    mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...

  5. 如何使用Pytorch迅速实现Mnist数据及分类器

    一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...

  6. Caffe初试(二)windows下的cafee训练和测试mnist数据集

    一.mnist数据集 mnist是一个手写数字数据库,由Google实验室的Corinna Cortes和纽约大学柯朗研究院的Yann LeCun等人建立,它有60000个训练样本集和10000个测试 ...

  7. 使用xshell+xmanager+pycharm搭建pytorch远程调试开发环境

    1. 相关软件版本 xshell: xmanager: pycharm: pycharm破解服务器:https://jetlicense.nss.im/ 2. 将相应的软件安装(pojie好) a&g ...

  8. Pytorch学习之源码理解:pytorch/examples/mnists

    Pytorch学习之源码理解:pytorch/examples/mnists from __future__ import print_function import argparse import ...

  9. [源码解析] PyTorch 分布式(4)------分布式应用基础概念

    [源码解析] PyTorch 分布式(4)------分布式应用基础概念 目录 [源码解析] PyTorch 分布式(4)------分布式应用基础概念 0x00 摘要 0x01 基本概念 0x02 ...

随机推荐

  1. 【Android自己定义控件】圆圈交替,仿progress效果

    还是我们自定View的那几个步骤: 1.自己定义View的属性 2.在View的构造方法中获得我们自己定义的属性 3.重写onMesure (不是必须) 4.重写onDraw 自己定义View的属性 ...

  2. 【RS】Collaborative Memory Network for Recommendation Systems - 基于协同记忆网络的推荐系统

    [论文标题]Collaborative Memory Network for Recommendation Systems    (SIGIR'18) [论文作者]—Travis Ebesu (San ...

  3. 解决windows 下 mysql命令行导入备份文件 查询时乱码的问题

    Mysql导入乱码,一般在命令行会遇到.下面说的是命令行的情况下解决乱码问题: 方法一: 通过增加参数 –default-character-set = utf8 解决乱码问题 mysql -uroo ...

  4. 城市经纬度 json 理解SignalR Main(string[] args)之args传递的几种方式 串口编程之端口 多线程详细介绍 递归一个List<T>,可自己根据需要改造为通用型。 Sql 优化解决方案

    城市经纬度 json https://www.cnblogs.com/innershare/p/10723968.html 理解SignalR ASP .NET SignalR 是一个ASP .NET ...

  5. React Native 从入门到原理一

    React Native 从入门到原理一 React Native 是最近非常火的一个话题,介绍如何利用 React Native 进行开发的文章和书籍多如牛毛,但面向入门水平并介绍它工作原理的文章却 ...

  6. 8款基于Jquery的WEB前端动画特效

    1.超炫酷的30个jQuery按钮悬停动画 按钮插件是最常见的jQuery插件之一,因为它用途广泛,而且配置起来最为方便.今天我们要分享的是30个超炫酷的jQuery悬停按钮动画,当我们将鼠标滑过按钮 ...

  7. r里面如何实现两列数据合并为一列

    library(dplyr) unite(mtcars, "vs_am", vs, am) Merging Data Adding Columns To merge two dat ...

  8. VirtualBox通过Host-Only网络连接方式实现宿主机与虚拟机通信

    适用情况 (1)没有联网, 不插网线 (2)宿主机直接连接宽带(无路由器) 情景: 宿主机 Windows 7 虚拟机 Windows XP 虚拟机安装了SQLServer2005,宿主机想连接使用虚 ...

  9. 使用 maven 自动将源码打包并发布

    1.maven-source-plugin 访问地址 在 pom.xml 中添加 下面的 内容,可以 使用 maven 生成 jar 的同时 生成 sources 包 <plugin> & ...

  10. Linux下 编译lib3ds库

    从网上下载的一个QT程序链接需要用到lib3ds.a静态库. lib3ds is an overall software library for managing 3D-Studio Release ...