pytorch-mnist神经网络训练
在net.py里面构造网络,网络的结构为输入为28*28,第一层隐藏层的输出为300, 第二层输出的输出为100, 最后一层的输出层为10,
net.py
import torch
from torch import nn class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer_1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
self.layer_2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
self.output = nn.Sequential(nn.Linear(n_hidden_2, out_dim)) def forward(self, x):
x = self.layer_1(x)
x = self.layer_2(x)
x = self.output(x)
return x
main.py 进行网络的训练
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms import net batch_size = 128 # 每一个batch_size的大小
learning_rate = 1e-2 # 学习率的大小
num_epoches = 20 # 迭代的epoch值
# 表示data将数据变成0, 1之间,0.5, 0.5表示减去均值处以标准差
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 表示均值和标准差
# 获得训练集的数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
# 获得测试集的数据
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)
# 获得训练集的可迭代队列
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 获得测试集的可迭代队列
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 构造模型的网络
model = net.Batch_Net(28*28, 300, 100, 10)
if torch.cuda.is_available(): # 如果有cuda就将模型放在GPU上
model.cuda() criterion = nn.CrossEntropyLoss() # 构造交叉损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 构造模型的优化器 for epoch in range(num_epoches): # 迭代的epoch
train_loss = 0 # 训练的损失值
test_loss = 0 # 测试的损失值
eval_acc = 0 # 测试集的准确率
for data in train_loader: # 获得一个batch的样本
img, label = data # 获得图片和标签
img = img.view(img.size(0), -1) # 将图片进行img的转换
if torch.cuda.is_available(): # 如果存在torch
img = Variable(img).cuda() # 将图片放在torch上
label = Variable(label).cuda() # 将标签放在torch上
else:
img = Variable(img) # 构造img的变量
label = Variable(label)
optimizer.zero_grad() # 消除optimizer的梯度
out = model.forward(img) # 进行前向传播
loss = criterion(out, label) # 计算损失值
loss.backward() # 进行损失值的后向传播
optimizer.step() # 进行优化器的优化
train_loss += loss.data #
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
if torch.cuda.is_available():
img = Variable(img, volatile=True).cuda()
label = Variable(label, volatile=True).cuda()
else:
img = Variable(img, volatile=True)
label = Variable(label, volatile=True)
out = model.forward(img)
loss = criterion(out, label)
test_loss += loss.data
top_p, top_class = out.topk(1, dim=1) # 获得输出的每一个样本的最大损失
equals = top_class == label.view(*top_class.shape) # 判断两组样本的标签是否相等
accuracy = torch.mean(equals.type(torch.FloatTensor)) # 计算准确率
eval_acc += accuracy
print('train_loss{:.6f}, test_loss{:.6f}, Acc:{:.6f}'.format(train_loss / len(train_loader), test_loss / len(test_loader), eval_acc / len(test_loader)))
pytorch-mnist神经网络训练的更多相关文章
- tensorflow中使用mnist数据集训练全连接神经网络-学习笔记
tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...
- Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)
Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...
- PyTorch Tutorials 4 训练一个分类器
%matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...
- Pytorch多GPU训练
Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...
- 使用pytorch构建神经网络的流程以及一些问题
使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助. PyTorch构建神经网络的一般过程 下面的程序是PyT ...
- Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)
基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html 摘要 在前面的博文中,我详细介绍了Caf ...
- 使用PyTorch构建神经网络以及反向传播计算
使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...
- 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像
摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...
- 神经网络训练中的Tricks之高效BP(反向传播算法)
神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...
- 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel
一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...
随机推荐
- 1 java 笔记
第一java的版本: J2ME主要用于移动设备和信息家电 J2SE整个Java技术的核心 J2EE java技术应用最广泛的部分,主要应用与企业的开发 第二:基于java语言的开源框架 struts ...
- 4.AOP原理模拟
AOP Aspect-Oriented-Programming 面向切面编程 a)是对面向对象的思维方式的有力补充 好处:可以动态的添加和删除在切面上的逻辑而不影响原来的执行代码 a)Fil ...
- MySQL的sql_mode参数之NO_AUTO_VALUE_ON_ZERO对主键ID为0的记录影响
最近遇到一个不合理使用数据库进行项目开发最终导致项目进度受阻的一个问题,某天几位开发人员找到我并告知数据库中某张表数据无法写入,又告知某行记录被删除了,因为被删除的记录对开发框架影响很大,他们已尝试重 ...
- Delphi 类的方法
- 说说lock到底锁谁(I)?
写在前面 最近一个月一直在弄文件传输组件,其中用到多线程的技术,但有的地方确实需要只能有一个线程来操作,如何才能保证只有一个线程呢?首先想到的就是锁的概念,最近在我们项目组中听的最多的也是锁谁,如何锁 ...
- P1231 教辅的组成 拆点限流
如果只有两个物品的话 是一个裸的二分图匹配问题 现在变成了三个物品之间的匹配 则只要在中间加一层节点表示书 再把这层的每个点拆成两个点中间连一条边限制流量 使其只能用一次 #include<io ...
- python豆知识: for和while的else语句。
for语句,当可迭代对象耗尽后执行else语句. while循环,当条件为False后执行else. a = 1 while a != 10: a += 1 else: print(a)
- P1582 倒水 题解
来水一发水题.. 题目链接. 正解开始: 首先,我们根据题意,可以得知这是一个有关二进制的题目: 具体什么关系,怎么做,我们来具体分析: 对于每个n,我们尝试将其二进制分解,也就是100101之类的形 ...
- jQuery基础 (一)—样式篇
jQuery的优势 jQuery有很多特性和工具方法
- Python 读取txt文件,排序并写回文件
# 'C:\Users\SAM\Desktop\数据竞赛\个人征信_1108\个人征信\train\bank_detail_train.txt'# 反斜杠的写法会报编码错误f=open('C:/Use ...