使用pytorch构建神经网络的流程以及一些问题
使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助。
PyTorch构建神经网络的一般过程
下面的程序是PyTorch官网60分钟教程上面构建神经网络的例子,版本0.4.1:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 第一步:准备数据
# Compose是将两个转换的过程组合起来,ToTensor将numpy等数据类型转换为Tensor,将值变为0到1之间
# Normalize用公式(input-mean)/std 将值进行变换。这里mean=0.5,std=0.5,是将[0,1]区间转换为[-1,1]区间
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# trainloader 是一个将数据集和采样策略结合起来的,并提供在数据集上面迭代的方法
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 第二步:构建神经网络框架,继承nn.Module类
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
# 第三步:进行训练
# 定义损失策略和优化方法
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练神经网络
for epoch in range(4):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
# 训练过程1:前向过程,计算输入到输出的结果
outputs = net(inputs)
# 训练过程2:由结果和label计算损失
loss = criterion(outputs, labels)
# 训练过程3:在图的层次上面计算所有变量的梯度
# 每次计算梯度的时候,其实是有一个动态的图在里面的,求导数就是对图中的参数w进行求导的过程
# 每个参数计算的梯度值保存在w.grad.data上面,在参数更新时使用
loss.backward()
# 训练过程4:进行参数的更新
# optimizer不计算梯度,它利用已经计算好的梯度值对参数进行更新
optimizer.step()
running_loss += loss.item() # item 返回的是一个数字
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch+1, i+1, running_loss/2000))
running_loss = 0.0
print('Finished Training')
# 第四步:在测试集上面进行测试
total = 0
correct = 0
with torch.no_grad():
for data in testloader:
images, label = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == label).sum().item()
print("Accuracy of networkd on the 10000 test images: %d %%" % (100*correct/total))
这个例子说明了构建神经网络的四个步骤:1:准备数据集 。2:构建神经网络框架,实现神经网络的类。 3:在训练集上进行训练。 4:在测试集上面进行测试。
而在第三步的训练阶段,也可以分为四个步骤:1:前向过程,计算输入到输出的结果。2:由结果和labels计算损失。3:后向过程,由损失计算各个变量的梯度。4:优化器根据梯度进行参数的更新。
训练过程中第loss和optim是怎么联系在一起的
loss是训练阶段的第三步,计算参数的梯度。optim是训练阶段的第四步,对参数进行更新。在optimizer初始化的时候,optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9),获取了参数的指针,可以对参数进行修改。当loss计算好参数的梯度以后,把值放在参数w.grad.data上面,然后optimizer直接利用这个值对参数进行更新。
以SGD为例,它进行step的时候的基本操作是这样的: p.data.add_(-group['lr'], d_p),其中 d_p = p.grad.data
为什么要进行梯度清零
在backward每次计算梯度的时候,会将新的梯度值加到原来旧的梯度值上面,这叫做梯度累加。下面的程序可以说明什么是梯度累加:
import torch
x = torch.rand(2, requires_grad=True)
y = x.mean() # y = (x_1 + x_2) / 2 所以求梯度后应是0.5
y.backward()
print(x.grad.data) # 输出结果:tensor([0.5000, 0.5000])
y.backward()
print(x.grad.data) # 输出结果:tensor([1., 1.]) 说明进行了梯度累积
求解梯度过程和参数更新过程是分开的,这对于那些需要多次求导累计梯度,然后一次更新的神经网络可能是有帮助的,比如RNN,对于DNN和CNN不需要进行梯度累加,所以需要进行梯度清零。
如何使用GPU进行训练
旧版本:
use_cuda = True if torch.cuda.is_available() else False # 是否使用cuda
if use_cuda:
model = model.cuda() # 将模型的参数放入GPU
if use_cuda:
inputs, labels = inputs.cuda(), labels.cuda() # 将数据放入到GPU
0.4版本以后推荐新方法 to(device),
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device) #将模型的参数放入GPU中
inputs, labels = inputs.to(device), labels.to(device) # 将数据放入到GPU中
参考:
Pytorch内部中optim和loss是如何交互的? - 罗若天的回答 - 知乎
pytorch学习笔记(二):gradient
使用pytorch构建神经网络的流程以及一些问题的更多相关文章
- 使用PyTorch构建神经网络以及反向传播计算
使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...
- 使用PyTorch构建神经网络模型进行手写识别
使用PyTorch构建神经网络模型进行手写识别 PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯 ...
- TFLearn构建神经网络
TFLearn构建神经网络 Building the network TFLearn lets you build the network by defining the layers. Input ...
- 在IDEA中构建Tomcat项目流程
在IDEA中构建Web项目流程 打开你的IDEA,跟着我走! 第一步:新建项目 第二步:找到Artifacts 点击绿色的+号,如图所示,点一下 这一步很关键,目的是设置输出格式为war包,如果你的项 ...
- pytorch构建自己的数据集
现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...
- tensorflow之神经网络实现流程总结
tensorflow之神经网络实现流程总结 1.数据预处理preprocess 2.前向传播的神经网络搭建(包括activation_function和层数) 3.指数下降的learning_rate ...
- Tensorflow BatchNormalization详解:2_使用tf.layers高级函数来构建神经网络
Batch Normalization: 使用tf.layers高级函数来构建神经网络 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearningai课程 课程笔 ...
- GeneXus DevOps 自动化构建和部署流程
以下视频详细介绍了GeneXus DevOps自动化构建和部署流程,包括通过MS Bulid来管理自动化流程,自动化的架构,以及在GeneXus Server上使用Jenkins做为自动化引擎. 视频 ...
- 使用 Visual Studio 2015 + Python3.6 + tensorflow 构建神经网络时报错:'utf-8' codec can't decode byte 0xcc in position 78: invalid continuation byte
使用 Visual Studio 2015 + Python3.6 + tensorflow 构建神经网络时报错:'utf-8' codec can't decode byte 0xcc in pos ...
随机推荐
- BFS简单题套路_Codevs 1215 迷宫
BFS 简单题套路 1. 遇到迷宫之类的简单题,有什么行走方向的,先写下面的 声明 ; struct Status { int r, c; Status(, ) : r(r), c(c) {} // ...
- Jquery自定义滚动条插件
下载地址:http://files.cnblogs.com/files/LoveOrHate/jquery.nicescroll.min.js <script src="jquery. ...
- [转载]Markdown——入门指南
http://www.jianshu.com/p/1e402922ee32/ 转载请注明原作者,如果你觉得这篇文章对你有帮助或启发,也可以来请我喝咖啡. 导语: Markdown 是一种轻量级的「标记 ...
- ODPS_ele—UDF Python API
自定义函数(UDF) UDF全称User Defined Function,即用户自定义函数.ODPS提供了很多内建函数来满足用户的计算需求,同时用户还可以通过创建自定义函数来满足不同的计算需求.UD ...
- Linux - ssh 连接问题
SSH 连接方式 ssh -p 22 user@192.168.1.209 # 从linux ssh登录另一台linux ssh -p 22 root@192.168.1.209 CMD # 利用ss ...
- checkbox判断不为空
checkbox不为空 html页面: @foreach($seals as $v) <input type="checkbox" name="seal_id[]& ...
- C# 解决VS2008在win7找不到输入序列号的地方
1.VS2008在Windows7 打开维护界面看不到可以输序列号的地方. 因为微软把他隐藏了. 2.我们可以借用工具把他显示出来 下载地址:http://www.zlsoft.com/techbbs ...
- 扩展欧几里得(E - The Balance POJ - 2142 )
题目链接:https://cn.vjudge.net/contest/276376#problem/E 题目大意:给你n,m,k,n,m代表当前由于无限个质量为n,m的砝码.然后当前有一个秤,你可以通 ...
- decimal模块
简介 decimal意思为十进制,这个模块提供了十进制浮点运算支持. 常用方法 1.可以传递给Decimal整型或者字符串参数,但不能是浮点数据,因为浮点数据本身就不准确. 2.要从浮点数据转换为De ...
- 在内部局域网内搭建HTTPs
在内部局域网内搭建HTTPs 配置环境 Windows版本:Windows Server 2008 R2 Standard Service Pack 1 系统类型: 64 位操作系统 内存 ...