一、nn.Embedding.weight初始化分布

nn.Embedding.weight随机初始化方式是标准正态分布  ,即均值$\mu=0$,方差$\sigma=1$的正态分布。

论据1——查看源代码

## class Embedding具体实现(在此只展示部分代码)
import torch
from torch.nn.parameter import Parameter from .module import Module
from .. import functional as F class Embedding(Module):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=None):
if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
'Shape of weight does not match num_embeddings and embedding_dim'
self.weight = Parameter(_weight) def reset_parameters(self):
self.weight.data.normal_(0, 1)
if self.padding_idx is not None:
self.weight.data[self.padding_idx].fill_(0)

Embedding这个类有个属性weight,它是torch.nn.parameter.Parameter类型的,作用就是存储真正的word embeddings。如果不给weight赋值,Embedding类会自动给他初始化,看上述代码第6~8行,如果属性weight没有手动赋值,则会定义一个torch.nn.parameter.Parameter对象,然后对该对象进行reset_parameters(),看第21行,对self.weight先转为Tensor在对其进行normal_(0, 1)(调整为$N(0, 1)$正态分布)。所以nn.Embeddig.weight默认初始化方式就是N(0, 1)分布,即均值$\mu=0$,方差$\sigma=1$的标准正态分布。

论据2——简单验证nn.Embeddig.weight的分布

下面将做的是验证nn.Embeddig.weight某一行词向量的均值和方差,以便验证是否为标准正态分布。
注意:验证一行数字的均值为0,方差为1,显然不能说明该分布就是标准正态分布,只能是其必要条件,而不是充分条件,要想真正检测这行数字是不是正态分布,在概率论上有专门的较为复杂的方法,请查看概率论之假设检验。

import torch.nn as nn

# dim越大,均值、方差越接近0和1
dim = 800000
# 定义了一个(5, dim)的二维embdding
# 对于NLP来说,相当于是5个词,每个词的词向量维数是dim
# 每个词向量初始化为正态分布 N(0,1)(待验证)
embd = nn.Embedding(5, dim)
# type(embd.weight) is Parameter
# type(embd.weight.data) is Tensor
# embd.weight.data[0]是指(5, dim)的word embeddings中取第1个词的词向量,是dim维行向量
weight = embd.weight.data[0].numpy()
print("weight: {}".format(weight)) weight_sum = 0
for w in weight:
weight_sum += w
mean = weight_sum / dim
print("均值: {}".format(mean)) square_sum = 0
for w in weight:
square_sum += (mean - w) ** 2
print("方差: {}".format(square_sum / dim))

代码输出:

weight: [-0.65507996  0.11627434 -1.6705967  ...  0.78397447  ...  -0.13477565]
均值: 0.0006973597864689242
方差: 1.0019535550544454

可见,均值接近0,方差接近1,从这里也可以反映出nn.Embeddig.weight是标准正态分布$N(0, 1)$。

二、torch.Tensortorch.tensortorch.randn初始化分布

1、torch.rand

返回$[0,1)$上的均匀分布(uniform distribution)。

2、torch.randn

返回$N(0, 1)$,即标准正态分布(standard normal distribution)。

3、torch.Tensor

torch.Tensor是Tensor class,torch.Tensor(2, 3)是调用Tensor的构造函数,构造了$2\times3$矩阵,但是没有分配空间,未初始化。
不推荐使用torch.Tensor创建Tensor,应使用torch.tenstortorch.onestorch.zerostorch.randtorch.randn等,原因:

t = torch.Tensor(2,3)
# 容易出现下述错误,因为t中的值取决当前内存中的随机值
# 如果当前内存中随机值特别大会溢出
RuntimeError: Overflow when unpacking long


Pytorch的默认初始化分布 nn.Embedding.weight初始化分布的更多相关文章

  1. pytorch nn.Embedding

    pytorch nn.Embeddingclass torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_no ...

  2. 『PyTorch』第十三弹_torch.nn.init参数初始化

    初始化参数的方法 nn.Module模块对于参数进行了内置的较为合理的初始化方式,当我们使用nn.Parameter时,初始化就很重要,而且我们也可以指定代替内置初始化的方式对nn.Module模块进 ...

  3. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

  4. torch.nn.Embedding理解

    Pytorch官网的解释是:一个保存了固定字典和大小的简单查找表.这个模块常用来保存词嵌入和用下标检索它们.模块的输入是一个下标的列表,输出是对应的词嵌入. torch.nn.Embedding(nu ...

  5. torch.nn.Embedding

    自然语言中的常用的构建词向量方法,将id化后的语料库,映射到低维稠密的向量空间中,pytorch 中的使用如下: import torch import torch.utils.data as Dat ...

  6. C++中默认构造函数中数据成员的初始化

    构造函数的任务是初始化数据成员的,在类中,如果没有显示定义任何构造函数,编译器将为我们创建一个构造函数,称为合成的默认构造函数,合成的默认构造函数使用与变量初始化相同的规则来初始化成员.即当类中的数据 ...

  7. C#语法糖之第二篇: 参数默认值和命名参数 对象初始化器与集合初始化器

    今天继续写上一篇文章C#4.0语法糖之第二篇,在开始今天的文章之前感谢各位园友的支持,通过昨天写的文章,今天有很多园友们也提出了文章中的一些不足,再次感谢这些关心我的园友,在以后些文章的过程中不断的完 ...

  8. 伯努利分布、二项分布、Beta分布、多项分布和Dirichlet分布与他们之间的关系,以及在LDA中的应用

    在看LDA的时候,遇到的数学公式分布有些多,因此在这里总结一下思路. 一.伯努利试验.伯努利过程与伯努利分布 先说一下什么是伯努利试验: 维基百科伯努利试验中: 伯努利试验(Bernoulli tri ...

  9. Java类的初始化与实例对象的初始化

    Java对象初始化详解 2013/04/10 · 开发 · 1 评论· java 分享到:43 与<YII框架>不得不说的故事—扩展篇 sass进阶篇 Spring事务管理 Android ...

随机推荐

  1. fastjson 对象和json互转

    list转json List<Openid> openids = od.getAll(session); String json = JSONObject.toJSONString(ope ...

  2. MySQL数据库起步 关于数据库的基本操作(更新中...)

    mysql的基本操作 连接指定的服务器(需要服务器开启3306端口) mysql -h ip地址 -P 端口号 -u 账号 -p 密码 删除游客模式 mysql -h ip地址 -P 端口号 -u 账 ...

  3. 2018-8-10-如何移动-nuget-缓存文件夹

    title author date CreateTime categories 如何移动 nuget 缓存文件夹 lindexi 2018-08-10 19:16:51 +0800 2018-2-13 ...

  4. web前端学习(二)html学习笔记部分(2)-- 改良的元素(input元素等等)

    1.2.5  HTML5 改良的 input 元素的种类 1.2.5.1  新增的input元素种类中的改良与增加 input 元素的种类 (1) 新增的input元素种类中的url类型.email类 ...

  5. Linux SSH远程链接 短时间内断开

    Linux SSH远程链接 短时间内断开 操作系统:RedHat 7.5 问题描述: 在进行SSH链接后,时不时的就断开了 解决方案: 修改 /etc/ssh/sshd_config 文件,找到 Cl ...

  6. [Vue CLI 3] 配置解析之 parallel

    官方文档中介绍过在 vue.config.js 文件中可以配置 parallel,作用如下: 是否为 Babel 或 TypeScript 使用 thread-loader. 该选项在系统的 CPU ...

  7. SQL优化系列(一)- 优化SQL

     优化SQL SQL开发人员从源代码中发现一条跑得很慢的SQL, 如何优化? DBA从AWR报告中发现一条跑得很慢的SQL,没有源代码或者不想修改源代码怎么办? SQL自动优化工具SQL Tuning ...

  8. vs2015卸载、vs2008安装Visual Assist x

    卸载2015 参考博文 1. 手动卸载VS2015的主要部分: win10系统: 控制面板---程序和功能--Microsoft Visual Studio 2015---更改卸载 2. 下载工具并解 ...

  9. 如何在Data Lake Analytics中使用临时表

    前言 Data Lake Analytics (后文简称DLA)是阿里云重磅推出的一款用于大数据分析的产品,可以对存储在OSS,OTS上的数据进行查询分析.相较于传统的数据分析产品,用户无需将数据重新 ...

  10. LeetCode136 Single Number, LeetCode137 Single Number II, LeetCode260 Single Number III

    136. Single Number Given an array of integers, every element appears twice except for one. Find that ...