自定义layer

https://www.cnblogs.com/sdu20112013/p/12132786.html一文里说了怎么写自定义的模型.本篇说怎么自定义层.

分两种:

  • 不含模型参数的layer
  • 含模型参数的layer

核心都一样,自定义一个继承自nn.Module的类,在类的forward函数里实现该layer的计算,不同的是,带参数的layer需要用到nn.Parameter

不含模型参数的layer

直接继承nn.Module

import torch
from torch import nn class CenteredLayer(nn.Module):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs)
def forward(self, x):
return x - x.mean() layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)) net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
y.mean().item()

含模型参数的layer

  • Parameter
  • ParameterList
  • ParameterDict

Parameter类其实是Tensor的子类,如果一个TensorParameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter,除了直接定义成Parameter类外,还可以使用ParameterListParameterDict分别定义参数的列表和字典。

ParameterList用法和list类似

class MyDense(nn.Module):
def __init__(self):
super(MyDense,self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(4)])
self.params.append(nn.Parameter(torch.randn(4,1))) def forward(self,x):
for i in range(len(self.params)):
x = torch.mm(x,self.params[i])
return x net = MyDense()
print(net)

输出

MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x4]
(4): Parameter containing: [torch.FloatTensor of size 4x1]
)
)

ParameterDict用法和python dict类似.也可以用.keys(),.items()

class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增 def forward(self, x, choice='linear1'):
return torch.mm(x, self.params[choice]) net = MyDictDense()
print(net) print(net.params.keys(),net.params.items()) x = torch.ones(1, 4)
net(x, 'linear1')

输出

MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
odict_keys(['linear1', 'linear2', 'linear3']) odict_items([('linear1', Parameter containing:
tensor([[-0.2275, -1.0434, -1.6733, -1.8101],
[ 1.7530, 0.0729, -0.2314, -1.9430],
[-0.1399, 0.7093, -0.4628, -0.2244],
[-1.6363, 1.2004, 1.4415, -0.1364]], requires_grad=True)), ('linear2', Parameter containing:
tensor([[ 0.5035],
[-0.0171],
[-0.8580],
[-1.1064]], requires_grad=True)), ('linear3', Parameter containing:
tensor([[-1.2078, 0.4364],
[-0.8203, 1.7443],
[-1.7759, 2.1744],
[-0.8799, -0.1479]], requires_grad=True))])

使用自定义的layer构造模型

layer1 = MyDense()
layer2 = MyDictDense() net = nn.Sequential(layer2,layer1)
print(net)
print(net(x))

输出

Sequential(
(0): MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
(1): MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x4]
(4): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
)
tensor([[-4.7566]], grad_fn=<MmBackward>)

从头学pytorch(十一):自定义层的更多相关文章

  1. 从头学pytorch(一):数据操作

    跟着Dive-into-DL-PyTorch.pdf从头开始学pytorch,夯实基础. Tensor创建 创建未初始化的tensor import torch x = torch.empty(5,3 ...

  2. 从头学pytorch(三) 线性回归

    关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...

  3. 从头学pytorch(九):模型构造

    模型构造 nn.Module nn.Module是pytorch中提供的一个类,是所有神经网络模块的基类.我们自定义的模块要继承这个基类. import torch from torch import ...

  4. 从头学pytorch(六):权重衰减

    深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...

  5. 从头学pytorch(七):dropout防止过拟合

    上一篇讲了防止过拟合的一种方式,权重衰减,也即在loss上加上一部分\(\frac{\lambda}{2n} \|\boldsymbol{w}\|^2\),从而使得w不至于过大,即不过分偏向某个特征. ...

  6. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

  7. 从头学pytorch(十五):AlexNet

    AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义 ...

  8. 从头学pytorch(十九):批量归一化batch normalization

    批量归一化 论文地址:https://arxiv.org/abs/1502.03167 批量归一化基本上是现在模型的标配了. 说实在的,到今天我也没搞明白batch normalize能够使得模型训练 ...

  9. 从头学pytorch(二十):残差网络resnet

    残差网络ResNet resnet是何凯明大神在2015年提出的.并且获得了当年的ImageNet比赛的冠军. 残差网络具有里程碑的意义,为以后的网络设计提出了一个新的思路. googlenet的思路 ...

随机推荐

  1. SDUT-3344_数据结构实验之二叉树五:层序遍历

    数据结构实验之二叉树五:层序遍历 Time Limit: 1000 ms Memory Limit: 65536 KiB Problem Description 已知一个按先序输入的字符序列,如abd ...

  2. oracle函数 BFILENAME(dir,file)

    [功能]函数返回一个空的BFILE位置值指示符,函数用于初始化BFILE变量或者是BFILE列. [参数]dir是一个directory类型的对象,file为一文件名. insert into lob ...

  3. 弹性FLEX布局

    页面布局一直都是web应用样式设计的重点 我们传统的布局方式都是基于盒模型的 利用display.position.float来布局有一定局限性 比如说实现自适应垂直居中 随着响应式布局的流行,CSS ...

  4. ThinkPHP3.2版本安全更新

    近日我们收到了一个关于3.2版本的漏洞提醒,官方已经第一时间进行处理和更新.由于3.2版本已经过了官方的维护和安全更新周期,而且大量的开发者也进行了二次开发,因此不再发布新版,官方仅进行安全公告和修复 ...

  5. el-tree文本内容过多显示不完全问题(解决)

    布局: <span class="custom-tree-node" slot-scope="{ node, data }"> 外层span 树节点 ...

  6. SpringSecurity认证流程详解

    SpringSecurity基本原理 在之前的文章<SpringBoot + Spring Security 基本使用及个性化登录配置>中对SpringSecurity进行了简单的使用介绍 ...

  7. 如何通过命令行 msbuild 编译项目

    本文告诉大家如何通过 msbuild 编译一个项目,通过命令行编译可以输出更多的编译信息,可以用来调试自己写的编译相关方法,可以看到是哪个文件编译失败 在开始菜单可以找到 VisualStudio 的 ...

  8. P1019 聪聪理扑克

    题目描述 聪聪的两个小伙伴灵灵和豪豪喜欢打扑克,什么斗地主.德州.牛牛,他们都玩的有模有样. 但是每次玩好扑克他们都不整理一下,所以整理扑克的任务就交到了聪聪的手上. 已知现在桌面上有 n 张扑克牌, ...

  9. CSU 2005: Nearest Maintenance Point(Dijkstra + bitset)

    Description A county consists of n cities (labeled 1, 2, …, n) connected by some bidirectional roads ...

  10. Linux 创建和销毁 urb

    struct urb 结构在驱动中必须不被静态创建, 或者在另一个结构中, 因为这可能破坏 USB 核心给 urb 使用的引用计数方法. 它必须使用对 usb_alloc_urb 函数的调用而被创 建 ...