一、nn.Embedding.weight初始化分布

nn.Embedding.weight随机初始化方式是标准正态分布  ,即均值$\mu=0$,方差$\sigma=1$的正态分布。

论据1——查看源代码

## class Embedding具体实现(在此只展示部分代码)
import torch
from torch.nn.parameter import Parameter from .module import Module
from .. import functional as F class Embedding(Module):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=None):
if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
'Shape of weight does not match num_embeddings and embedding_dim'
self.weight = Parameter(_weight) def reset_parameters(self):
self.weight.data.normal_(0, 1)
if self.padding_idx is not None:
self.weight.data[self.padding_idx].fill_(0)

Embedding这个类有个属性weight,它是torch.nn.parameter.Parameter类型的,作用就是存储真正的word embeddings。如果不给weight赋值,Embedding类会自动给他初始化,看上述代码第6~8行,如果属性weight没有手动赋值,则会定义一个torch.nn.parameter.Parameter对象,然后对该对象进行reset_parameters(),看第21行,对self.weight先转为Tensor在对其进行normal_(0, 1)(调整为$N(0, 1)$正态分布)。所以nn.Embeddig.weight默认初始化方式就是N(0, 1)分布,即均值$\mu=0$,方差$\sigma=1$的标准正态分布。

论据2——简单验证nn.Embeddig.weight的分布

下面将做的是验证nn.Embeddig.weight某一行词向量的均值和方差,以便验证是否为标准正态分布。
注意:验证一行数字的均值为0,方差为1,显然不能说明该分布就是标准正态分布,只能是其必要条件,而不是充分条件,要想真正检测这行数字是不是正态分布,在概率论上有专门的较为复杂的方法,请查看概率论之假设检验。

import torch.nn as nn

# dim越大,均值、方差越接近0和1
dim = 800000
# 定义了一个(5, dim)的二维embdding
# 对于NLP来说,相当于是5个词,每个词的词向量维数是dim
# 每个词向量初始化为正态分布 N(0,1)(待验证)
embd = nn.Embedding(5, dim)
# type(embd.weight) is Parameter
# type(embd.weight.data) is Tensor
# embd.weight.data[0]是指(5, dim)的word embeddings中取第1个词的词向量,是dim维行向量
weight = embd.weight.data[0].numpy()
print("weight: {}".format(weight)) weight_sum = 0
for w in weight:
weight_sum += w
mean = weight_sum / dim
print("均值: {}".format(mean)) square_sum = 0
for w in weight:
square_sum += (mean - w) ** 2
print("方差: {}".format(square_sum / dim))

代码输出:

weight: [-0.65507996  0.11627434 -1.6705967  ...  0.78397447  ...  -0.13477565]
均值: 0.0006973597864689242
方差: 1.0019535550544454

可见,均值接近0,方差接近1,从这里也可以反映出nn.Embeddig.weight是标准正态分布$N(0, 1)$。

二、torch.Tensortorch.tensortorch.randn初始化分布

1、torch.rand

返回$[0,1)$上的均匀分布(uniform distribution)。

2、torch.randn

返回$N(0, 1)$,即标准正态分布(standard normal distribution)。

3、torch.Tensor

torch.Tensor是Tensor class,torch.Tensor(2, 3)是调用Tensor的构造函数,构造了$2\times3$矩阵,但是没有分配空间,未初始化。
不推荐使用torch.Tensor创建Tensor,应使用torch.tenstortorch.onestorch.zerostorch.randtorch.randn等,原因:

t = torch.Tensor(2,3)
# 容易出现下述错误,因为t中的值取决当前内存中的随机值
# 如果当前内存中随机值特别大会溢出
RuntimeError: Overflow when unpacking long


Pytorch的默认初始化分布 nn.Embedding.weight初始化分布的更多相关文章

  1. pytorch nn.Embedding

    pytorch nn.Embeddingclass torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_no ...

  2. 『PyTorch』第十三弹_torch.nn.init参数初始化

    初始化参数的方法 nn.Module模块对于参数进行了内置的较为合理的初始化方式,当我们使用nn.Parameter时,初始化就很重要,而且我们也可以指定代替内置初始化的方式对nn.Module模块进 ...

  3. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

  4. torch.nn.Embedding理解

    Pytorch官网的解释是:一个保存了固定字典和大小的简单查找表.这个模块常用来保存词嵌入和用下标检索它们.模块的输入是一个下标的列表,输出是对应的词嵌入. torch.nn.Embedding(nu ...

  5. torch.nn.Embedding

    自然语言中的常用的构建词向量方法,将id化后的语料库,映射到低维稠密的向量空间中,pytorch 中的使用如下: import torch import torch.utils.data as Dat ...

  6. C++中默认构造函数中数据成员的初始化

    构造函数的任务是初始化数据成员的,在类中,如果没有显示定义任何构造函数,编译器将为我们创建一个构造函数,称为合成的默认构造函数,合成的默认构造函数使用与变量初始化相同的规则来初始化成员.即当类中的数据 ...

  7. C#语法糖之第二篇: 参数默认值和命名参数 对象初始化器与集合初始化器

    今天继续写上一篇文章C#4.0语法糖之第二篇,在开始今天的文章之前感谢各位园友的支持,通过昨天写的文章,今天有很多园友们也提出了文章中的一些不足,再次感谢这些关心我的园友,在以后些文章的过程中不断的完 ...

  8. 伯努利分布、二项分布、Beta分布、多项分布和Dirichlet分布与他们之间的关系,以及在LDA中的应用

    在看LDA的时候,遇到的数学公式分布有些多,因此在这里总结一下思路. 一.伯努利试验.伯努利过程与伯努利分布 先说一下什么是伯努利试验: 维基百科伯努利试验中: 伯努利试验(Bernoulli tri ...

  9. Java类的初始化与实例对象的初始化

    Java对象初始化详解 2013/04/10 · 开发 · 1 评论· java 分享到:43 与<YII框架>不得不说的故事—扩展篇 sass进阶篇 Spring事务管理 Android ...

随机推荐

  1. 个人站长建议直接封掉的IP地址列表

    <Valve className="org.apache.catalina.valves.RemoteAddrValve" deny="164.100.196.21 ...

  2. 洛谷P3459 [POI2007]MEG-Megalopolis [2017年6月计划 树上问题02]

    [POI2007]MEG-Megalopolis 题目描述 Byteotia has been eventually touched by globalisation, and so has Byte ...

  3. VMware ESXi 6.7服务器设置开机自动启动虚拟机

    VMware ESXi 6.7服务器设置开机自动启动虚拟机,具体操作步骤如下 1.登陆到VMware ESXi 6.7  web 界面 2.导航器-->主机-->管理  将自动启动修改为 ...

  4. 云原生交付加速!容器镜像服务企业版支持 Helm Chart

    2018 年 6 月,Helm 正式加入了 CNCF 孵化项目:2018 年 8 月,据 CNCF 的调研表明,有百分之六十八的开发者选择了 Helm 作为其应用包装方案:2019 年 6 月,阿里云 ...

  5. C++函数部分总结

    目录 为什么要使用函数 为什么要用函数重载 C++传参方式 特殊的函数--递归函数 为什么要使用函数 使用函数可以将一个比较复杂的程序系统的分为若干块简洁的模块,使程序更加清晰明了 比如,我们想要模拟 ...

  6. 【笔记】http协议笔记

    本文是本人在复习http协议时,手动整理的资料,以备后续查阅. http(hypertext transfer protocol):超文本协议.是万维网(world wide web,www,也简称为 ...

  7. .NET EasyUI datebox添加清空功能

    前言,前段时间的项目使用EasyUI框架搭建,使用了其自带的一系列组件.但对于datebox,其功能别的不多说,令人蛋疼的是它居然没有清空功能,这让在搜索区域中摆了日期条件的咋整啊,没办法,既然用了这 ...

  8. 常见的php攻击(6种攻击详解)

    1.SQL注入 SQL注入是一种恶意攻击,用户利用在表单字段输入SQL语句的方式来影响正常的SQL执行.还有一种是通过system()或exec()命令注入的,它具有相同的SQL注入机制,但只针对sh ...

  9. MVC开发模式与web经典三层框架

    MVC:Model(模型)-View(视图)-Controller(控制器) ----是一种软件架构模式,一般把软件系统拆分为这三个层次. 视图View层:前端交互界面或者后端系统界面,它从模型中获取 ...

  10. No.6 Verilog 其他论题

    (1)任务  **任务类似于一段程序,可以提供一种能力,使设计者可以从设计描述的不同位置执行共同的代码段.任务可以包含时序控制, 可以调用其它任务和函数.  任务的定义格式: task[automat ...