使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助。

PyTorch构建神经网络的一般过程

下面的程序是PyTorch官网60分钟教程上面构建神经网络的例子,版本0.4.1:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim # 第一步:准备数据
# Compose是将两个转换的过程组合起来,ToTensor将numpy等数据类型转换为Tensor,将值变为0到1之间
# Normalize用公式(input-mean)/std 将值进行变换。这里mean=0.5,std=0.5,是将[0,1]区间转换为[-1,1]区间
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # trainloader 是一个将数据集和采样策略结合起来的,并提供在数据集上面迭代的方法
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0) classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 第二步:构建神经网络框架,继承nn.Module类
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net() # 第三步:进行训练
# 定义损失策略和优化方法
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 训练神经网络
for epoch in range(4):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data optimizer.zero_grad() # 训练过程1:前向过程,计算输入到输出的结果
outputs = net(inputs) # 训练过程2:由结果和label计算损失
loss = criterion(outputs, labels) # 训练过程3:在图的层次上面计算所有变量的梯度
# 每次计算梯度的时候,其实是有一个动态的图在里面的,求导数就是对图中的参数w进行求导的过程
# 每个参数计算的梯度值保存在w.grad.data上面,在参数更新时使用
loss.backward() # 训练过程4:进行参数的更新
# optimizer不计算梯度,它利用已经计算好的梯度值对参数进行更新
optimizer.step() running_loss += loss.item() # item 返回的是一个数字
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch+1, i+1, running_loss/2000))
running_loss = 0.0
print('Finished Training') # 第四步:在测试集上面进行测试
total = 0
correct = 0
with torch.no_grad():
for data in testloader:
images, label = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == label).sum().item() print("Accuracy of networkd on the 10000 test images: %d %%" % (100*correct/total))

这个例子说明了构建神经网络的四个步骤:1:准备数据集 。2:构建神经网络框架,实现神经网络的类。 3:在训练集上进行训练。 4:在测试集上面进行测试。

而在第三步的训练阶段,也可以分为四个步骤:1:前向过程,计算输入到输出的结果。2:由结果和labels计算损失。3:后向过程,由损失计算各个变量的梯度。4:优化器根据梯度进行参数的更新。

训练过程中第loss和optim是怎么联系在一起的

loss是训练阶段的第三步,计算参数的梯度。optim是训练阶段的第四步,对参数进行更新。在optimizer初始化的时候,optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9),获取了参数的指针,可以对参数进行修改。当loss计算好参数的梯度以后,把值放在参数w.grad.data上面,然后optimizer直接利用这个值对参数进行更新。

以SGD为例,它进行step的时候的基本操作是这样的: p.data.add_(-group['lr'], d_p),其中 d_p = p.grad.data

为什么要进行梯度清零

在backward每次计算梯度的时候,会将新的梯度值加到原来旧的梯度值上面,这叫做梯度累加。下面的程序可以说明什么是梯度累加:

import torch
x = torch.rand(2, requires_grad=True)
y = x.mean() # y = (x_1 + x_2) / 2 所以求梯度后应是0.5 y.backward()
print(x.grad.data) # 输出结果:tensor([0.5000, 0.5000]) y.backward()
print(x.grad.data) # 输出结果:tensor([1., 1.]) 说明进行了梯度累积

求解梯度过程和参数更新过程是分开的,这对于那些需要多次求导累计梯度,然后一次更新的神经网络可能是有帮助的,比如RNN,对于DNN和CNN不需要进行梯度累加,所以需要进行梯度清零。

如何使用GPU进行训练

旧版本:

use_cuda = True if torch.cuda.is_available() else False  # 是否使用cuda
if use_cuda:
model = model.cuda() # 将模型的参数放入GPU
if use_cuda:
inputs, labels = inputs.cuda(), labels.cuda() # 将数据放入到GPU

0.4版本以后推荐新方法 to(device),

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device) #将模型的参数放入GPU中
inputs, labels = inputs.to(device), labels.to(device) # 将数据放入到GPU中

参考:

Pytorch内部中optim和loss是如何交互的? - 罗若天的回答 - 知乎

pytorch学习笔记(二):gradient

使用pytorch构建神经网络的流程以及一些问题的更多相关文章

  1. 使用PyTorch构建神经网络以及反向传播计算

    使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...

  2. 使用PyTorch构建神经网络模型进行手写识别

    使用PyTorch构建神经网络模型进行手写识别 PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯 ...

  3. TFLearn构建神经网络

    TFLearn构建神经网络 Building the network TFLearn lets you build the network by defining the layers. Input ...

  4. 在IDEA中构建Tomcat项目流程

    在IDEA中构建Web项目流程 打开你的IDEA,跟着我走! 第一步:新建项目 第二步:找到Artifacts 点击绿色的+号,如图所示,点一下 这一步很关键,目的是设置输出格式为war包,如果你的项 ...

  5. pytorch构建自己的数据集

    现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...

  6. tensorflow之神经网络实现流程总结

    tensorflow之神经网络实现流程总结 1.数据预处理preprocess 2.前向传播的神经网络搭建(包括activation_function和层数) 3.指数下降的learning_rate ...

  7. Tensorflow BatchNormalization详解:2_使用tf.layers高级函数来构建神经网络

    Batch Normalization: 使用tf.layers高级函数来构建神经网络 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearningai课程 课程笔 ...

  8. GeneXus DevOps 自动化构建和部署流程

    以下视频详细介绍了GeneXus DevOps自动化构建和部署流程,包括通过MS Bulid来管理自动化流程,自动化的架构,以及在GeneXus Server上使用Jenkins做为自动化引擎. 视频 ...

  9. 使用 Visual Studio 2015 + Python3.6 + tensorflow 构建神经网络时报错:'utf-8' codec can't decode byte 0xcc in position 78: invalid continuation byte

    使用 Visual Studio 2015 + Python3.6 + tensorflow 构建神经网络时报错:'utf-8' codec can't decode byte 0xcc in pos ...

随机推荐

  1. C#_界面程序_数码游戏

    using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; usin ...

  2. shell 示例1 从1叠加到100

    从1叠加到100 echo $[$(echo +{..})] echo $[(+)*(/)] seq -s |bc

  3. redis初使用

    下载地址:https://redis.io/download Redis项目不正式支持Windows.但是,微软开放技术小组开发并维护了针对Win64的Windows端口 windows版下载地址:h ...

  4. 分模块开发创建service子模块——(八)

    1.右击父工程新建maven子模块

  5. dataframe 差集

    >>>data_a={'state':[1,1,2],'pop':['a','b','c']}>>>data_b={'state':[1,2,3],'pop':[' ...

  6. Go语言之Windows 10开发工具LiteIDE初步使用

    Intel Core i5-8250U,Windows 10家庭中文版,go version go1.11 windows/amd64,LiteIDE X34.1 在RUNOOB.COM的Go语言教程 ...

  7. jQuery-介绍

    一:什么是jQuery jQuery 是一个 JavaScript 库. 二:安装 http://jquery.com/download/ http://jquery.cuishifeng.cn/ j ...

  8. mybatis,genarate自动生成代码

    ---恢复内容开始--- generatorConfig.xml配置文件: <?xml version="1.0" encoding="UTF-8"?&g ...

  9. DNS的服务器和客户端的配置

    内网环境Linux发行版本均采用centos为主,centos下DNS服务端的搭建步骤如下: DNS master节点搭建步骤: 安装组件: yum install bind;      yum in ...

  10. DOM文档对象模型简介

    DOM简介     DOM是W3C(万维网联盟)的标准 "W3C文档对象模型DOM是中立于平台和语言的接口,它允许程序和脚本动态地访问和更新文档的内容.结构.样式".W3C DOM ...