Note | PyTorch
PyTorch随手记
Note:
官方书籍:Deep learning with PyTorch。
1. 模型操作
假设我们有一个用self.arcnn = nn.Sequential(...)定义并训练好的ARCNN模型。我们想迁移过来,冻结前几层再训练。分两步:
print(model.state_dict())查看名称,如'arcnn.12.bias', 'arcnn.12.weight'等。model.arcnn[0].weight.requires_grad = False,model.arcnn[0].bias.requires_grad = False,让第一层冻结。
2. 网络设计
卷积图示
填充(padding)
PyTorch和TensorFlow的填充规则是不同的。因此必须查阅官方文档。
如果y = F.pad(x, (1,2,3,4)),意思是:在\(x\)的最后一个维度上(一般是W),左边填一圈零,右边填两圈0(默认为0);在\(x\)的倒数第二个维度上(一般是H),上面填3圈零,下面填4圈零。
升采样
其中有一个参数align_corners。例子参见官方教程里的Example。
这里有一个图例:

全连接层
假设我们经过多层卷积,得到了\((128, 32, 4, 4)\)的通道,即batch size为128,32张特征图,通道尺寸为\(4 \times 4\)。我们希望基于此得到2分类。那么可以如下操作:
self.l1 = nn.Linear(32 * 4 * 4, 128)
self.l2 = nn.Linear(128, 32)
self.l3 = nn.Linear(32, 2)
x = x.view(-1, 32 * 4 * 4)
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)
关于交叉熵和softmax,参见损失函数。
3. 损失函数
交叉熵
loss_func = F.cross_entropy
batch_pred_t = model(batch_cmp_t)
batch_pred = batch_pred_t.detach().cpu()
acc = cal_acc(batch_pred, batch_label)
def cal_acc(batch_pred, batch_label):
batch_pred = [torch.argmax(batch_pred[ite_patch]) for ite_patch in range(batch_size)]
acc = 0
for ite_patch in range(batch_size):
if pred[ite_patch] == batch_label[ite_patch]:
acc += 1
acc /= batch_size
return acc
注意:
cross_entropy函数结合了nn.LogSoftmax()和nn.NLLLoss()。第二个参数是
target。假设batch size是32,那么就是一个32维向量(张量),值为从0开始的正确标签。第一个参数是
input,可以没有被softmax归一化。假设batch size是32,一共有5个分类,那么就是一个\(32 \times 5\)的张量。
4. 系统或环境交互
模型加载
自动搜索空余显存最多的GPU,然后将模型加载到该GPU上:
os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
memory_gpu=[int(x.split()[2]) for x in open('tmp','r').readlines()]
dev = torch.device("cuda:" + str(np.argmax(memory_gpu)))
print(dev)
model.load_state_dict(torch.load(os.path.join(dir_model, "model_" + str(index_model) + ".pt"), map_location=dev))
model.to(dev)
5. 犯过的错误
损失异常
- CNN最后一层使用了非线性激活函数ReLU,导致输出在0附近浮动。
测试显存过大
在测试程序中指定了torch.no_grad(),然而显存还是过大。后来改成with torch.no_grad():包裹测试程序,成功了。
Note | PyTorch的更多相关文章
- Note | PyTorch官方教程学习笔记
目录 1. 快速入门PYTORCH 1.1. 什么是PyTorch 1.1.1. 基础概念 1.1.2. 与NumPy之间的桥梁 1.2. Autograd: Automatic Differenti ...
- 理解PyTorch的自动微分机制
参考Getting Started with PyTorch Part 1: Understanding how Automatic Differentiation works 非常好的文章,讲解的非 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- pytorch对可变长度序列的处理
主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils. ...
- pytorch .detach() .detach_() 和 .data用于切断反向传播
参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource 当我们再训 ...
- 一文看懂Transformer内部原理(含PyTorch实现)
Transformer注解及PyTorch实现 原文:http://nlp.seas.harvard.edu/2018/04/03/attention.html 作者:Alexander Rush 转 ...
- [转] 理解CheckPoint及其在Tensorflow & Keras & Pytorch中的使用
作者用游戏的暂停与继续聊明白了checkpoint的作用,在三种主流框架中演示实际使用场景,手动点赞. 转自:https://blog.floydhub.com/checkpointing-tutor ...
- pytorch做seq2seq注意力模型的翻译
以下是对pytorch 1.0版本 的seq2seq+注意力模型做法语--英语翻译的理解(这个代码在pytorch0.4上也可以正常跑): # -*- coding: utf-8 -*- " ...
随机推荐
- Paper | Non-local Neural Networks
目录 1. 动机 2. 相关工作 3. Non-local神经网络 3.1 Formulation 3.2 具体实现形式 3.3 Non-local块 4. 视频分类模型 4.1 2D ConvNet ...
- IT兄弟连 Java语法教程 Java语法基础 经典面试题
1.Java语言中有几种基本类型?分别是什么?请详细说明每种类型的范围以及所占的空间大小? Java语言中有8中基本类型,分别是代表整形的byte.short.int和long,代表浮点型的float ...
- 简单探讨一下.NET Core 3.0使用AspectCore的新姿势
前言 这几天在对EasyCaching做支持.net core 3.0的调整.期间遇到下面这个错误. System.NotSupportedException:"ConfigureServi ...
- c++11 C++14 C++17
Since C++11, WG21, the ISO designation for the C++ standard, try to shipped the standard every 3 ye ...
- jdk api 1.6,1.7,1.8,1.9版本(中文)
有需要的朋友,请自行到百度云下载 链接:https://pan.baidu.com/s/18WgEZ1WpBz5YexbbgikJcA 提取码:xry4
- FCC---Use CSS Animation to Change the Hover State of a Button---鼠标移过,背景色变色,用0.5s的动画制作
You can use CSS @keyframes to change the color of a button in its hover state. Here's an example of ...
- IDEA构建spring项目
这两天使用IDEA从零构建一个spring项目,之所以说从零,是因为,我这个小白呢,之前IDEA没有碰过,spring也只是知道个名字. 因为没有文档,遇到了好些坑,把这些记录一下吧. 构建的第一步, ...
- 转.HTML中img标签的src属性绝对路径问题解决办法,完全解决!
HTML中img标签的src属性绝对路径问题解决办法,完全解决 需求:有时候自己的项目img的src路径需要用到本地某文件夹下的图片,而不是直接使用项目根目录下的图片. 场景:eclipse,to ...
- 渗透测试之wep无线网络破解
WEP 无线网络破解 WEP(无线等效协议)于 1999 年诞生,并且是用于无线网络的最古老的安全标准.在 2003 年,WEP 被 WPA 以及之后被 WPA2 取代.由于可以使用更加安全的协议,W ...
- 大话IdentityServer4之使用 IdentityServer4 保护 ASP.NET Core 应用
这几天一直在研究IdentityServer4在asp.net core3.0中的应用,下面说说我的理解: 我们每一个.net core 项目大家可以理解为我新建了一个动物园或者植物园等,注册用户想要 ...