pytorch-mnist神经网络训练
在net.py里面构造网络,网络的结构为输入为28*28,第一层隐藏层的输出为300, 第二层输出的输出为100, 最后一层的输出层为10,
net.py
import torch
from torch import nn class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer_1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
self.layer_2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
self.output = nn.Sequential(nn.Linear(n_hidden_2, out_dim)) def forward(self, x):
x = self.layer_1(x)
x = self.layer_2(x)
x = self.output(x)
return x
main.py 进行网络的训练
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms import net batch_size = 128 # 每一个batch_size的大小
learning_rate = 1e-2 # 学习率的大小
num_epoches = 20 # 迭代的epoch值
# 表示data将数据变成0, 1之间,0.5, 0.5表示减去均值处以标准差
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 表示均值和标准差
# 获得训练集的数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
# 获得测试集的数据
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)
# 获得训练集的可迭代队列
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 获得测试集的可迭代队列
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 构造模型的网络
model = net.Batch_Net(28*28, 300, 100, 10)
if torch.cuda.is_available(): # 如果有cuda就将模型放在GPU上
model.cuda() criterion = nn.CrossEntropyLoss() # 构造交叉损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 构造模型的优化器 for epoch in range(num_epoches): # 迭代的epoch
train_loss = 0 # 训练的损失值
test_loss = 0 # 测试的损失值
eval_acc = 0 # 测试集的准确率
for data in train_loader: # 获得一个batch的样本
img, label = data # 获得图片和标签
img = img.view(img.size(0), -1) # 将图片进行img的转换
if torch.cuda.is_available(): # 如果存在torch
img = Variable(img).cuda() # 将图片放在torch上
label = Variable(label).cuda() # 将标签放在torch上
else:
img = Variable(img) # 构造img的变量
label = Variable(label)
optimizer.zero_grad() # 消除optimizer的梯度
out = model.forward(img) # 进行前向传播
loss = criterion(out, label) # 计算损失值
loss.backward() # 进行损失值的后向传播
optimizer.step() # 进行优化器的优化
train_loss += loss.data #
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
if torch.cuda.is_available():
img = Variable(img, volatile=True).cuda()
label = Variable(label, volatile=True).cuda()
else:
img = Variable(img, volatile=True)
label = Variable(label, volatile=True)
out = model.forward(img)
loss = criterion(out, label)
test_loss += loss.data
top_p, top_class = out.topk(1, dim=1) # 获得输出的每一个样本的最大损失
equals = top_class == label.view(*top_class.shape) # 判断两组样本的标签是否相等
accuracy = torch.mean(equals.type(torch.FloatTensor)) # 计算准确率
eval_acc += accuracy
print('train_loss{:.6f}, test_loss{:.6f}, Acc:{:.6f}'.format(train_loss / len(train_loader), test_loss / len(test_loader), eval_acc / len(test_loader)))
pytorch-mnist神经网络训练的更多相关文章
- tensorflow中使用mnist数据集训练全连接神经网络-学习笔记
tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...
- Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)
Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...
- PyTorch Tutorials 4 训练一个分类器
%matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...
- Pytorch多GPU训练
Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...
- 使用pytorch构建神经网络的流程以及一些问题
使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助. PyTorch构建神经网络的一般过程 下面的程序是PyT ...
- Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)
基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html 摘要 在前面的博文中,我详细介绍了Caf ...
- 使用PyTorch构建神经网络以及反向传播计算
使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...
- 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像
摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...
- 神经网络训练中的Tricks之高效BP(反向传播算法)
神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...
- 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel
一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...
随机推荐
- sql server 语句书写注意事项
1 Between在某些时候比IN 2 在必要是对全局或者局部临时表创建索引,有时能够提高速度,但不是一定会这样,因为索引也耗费大量的资源.他的创建同是实际表一样 3 尽量少用视图,它的效率低.对视 ...
- 基于SAML2.0的SAP云产品Identity Authentication过程介绍
SAP官网的架构图 https://cloudplatform.sap.com/scenarios/usecases/authentication.html 上图介绍了用户访问SAP云平台时经历的Au ...
- C语言编译过程及相关文件
1,C程序编译步骤 C代码编译成可执行程序经过4步: 1)预处理:宏定义展开.头文件展开.条件编译等,同时将代码中的注释删除,这里并不会检查语法 2)编译:检查语法,将预处理后文件编译生成汇编文件 3 ...
- 剖析ajax
学过javascript和接触过后端PHP语言必然要用到ajax,这是必学的一门学科,AJAX指的是Asynchronous JavaScript and XML,它使用XMLHttpRequest对 ...
- Codeforces 920E-Connected Components? (set,补图,连通块)
Connected Components? CodeForces - 920E You are given an undirected graph consisting of n vertices a ...
- 关于网站子目录绑定二级域名的方法(php网站手机端)
最近帮客户做zencart网站手机模板用到了二级域名,通过判断手机访问来调用二级目录程序,http://afish.cnblogs.com/ 怎么说都比 http://www.cnblogs.com/ ...
- github仓库管理项目
一,建立本地git仓库 首先,git要求使用者必须提供自己的身份标识,为此我们需要在git bash中执行以下命令: git config --global user.name 'aa.Tessst ...
- Java&Selenium&TestNG&ZTestReport 自动化测试并生成HTML自动化测试报告
一.摘要 本篇博文将介绍如何借助ZTestReport和HTML模版,生成HTML测试报告的ZTestReport 源码Clone地址为 https://github.com/zhangfei1984 ...
- 生成器(generator) 详解
1. 生成器是什么? 利用迭代器,我们可以在每次迭代获取数据(通过next()方法)时按照特定的规律进行生成.但是我们在实现一个迭代器时,关于当前迭代到的状态需要我们自己记录,进而才能根据当前状态生成 ...
- Zabbix Web 中文字体显示问题