论文地址:http://www.iro.umontreal.ca/~vincentp/Publications/lm_jmlr.pdf

论文给出了NNLM的框架图:

      

针对论文,实现代码如下(https://github.com/graykode/nlp-tutorial):

 # -*- coding: utf-8 -*-
# @time : 2019/10/26 12:20 import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable dtype = torch.FloatTensor sentences = [ "i like dog", "i love coffee", "i hate milk"] word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)} # {'i': 0, 'like': 1, 'love': 2, 'hate': 3, 'milk': 4, 'dog': 5, 'coffee': 6}}
number_dict = {i: w for i, w in enumerate(word_list)}
n_class = len(word_dict) # number of Vocabulary # NNLM Parameter
n_step = 2 # n-1 in paper ->3gram
n_hidden = 2 # h in paper ->number hidden unit
m = 2 # m in paper ->embedding size # make data batch (input,target)
# input: [[0,1],[0,2],[0,3]]
# target: [5,6,4]
def make_batch(sentences):
input_batch = []
target_batch = [] for sen in sentences:
word = sen.split()
input = [word_dict[n] for n in word[:-1]]
target = word_dict[word[-1]] input_batch.append(input)
target_batch.append(target) return input_batch, target_batch # Model
class NNLM(nn.Module):
def __init__(self):
super(NNLM, self).__init__()
self.C = nn.Embedding(n_class, m)
self.H = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))
self.W = nn.Parameter(torch.randn(n_step * m, n_class).type(dtype))
self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))
self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))
self.b = nn.Parameter(torch.randn(n_class).type(dtype)) def forward(self, X):
X = self.C(X)
X = X.view(-1, n_step * m) # [batch_size, n_step * m]
tanh = torch.tanh(self.d + torch.mm(X, self.H)) # [batch_size, n_hidden]
output = self.b + torch.mm(X, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]
return output model = NNLM() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) input_batch, target_batch = make_batch(sentences)
input_batch = Variable(torch.LongTensor(input_batch))
target_batch = Variable(torch.LongTensor(target_batch)) # Training
for epoch in range(5000): optimizer.zero_grad()
output = model(input_batch) # output : [batch_size, n_class], target_batch : [batch_size] (LongTensor, not one-hot)
loss = criterion(output, target_batch)
if (epoch + 1)%1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) loss.backward()
optimizer.step() # Predict [5,6,4] (equal with target)
predict = model(input_batch).data.max(1, keepdim=True)[1] # print to visual
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])

pytorch ---神经网络语言模型 NNLM 《A Neural Probabilistic Language Model》的更多相关文章

  1. A Neural Probabilistic Language Model

    A Neural Probabilistic Language Model,这篇论文是Begio等人在2003年发表的,可以说是词表示的鼻祖.在这里给出简要的译文 A Neural Probabili ...

  2. 从代码角度理解NNLM(A Neural Probabilistic Language Model)

    其框架结构如下所示: 可分为四 个部分: 词嵌入部分 输入 隐含层 输出层 我们要明确任务是通过一个文本序列(分词后的序列)去预测下一个字出现的概率,tensorflow代码如下: 参考:https: ...

  3. A Neural Probabilistic Language Model (2003)论文要点

    论文链接:http://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf 解决n-gram语言模型(比如tri-gram以上)的组合爆炸问题,引入 ...

  4. NLP问题特征表达基础 - 语言模型(Language Model)发展演化历程讨论

    1. NLP问题简介 0x1:NLP问题都包括哪些内涵 人们对真实世界的感知被成为感知世界,而人们用语言表达出自己的感知视为文本数据.那么反过来,NLP,或者更精确地表达为文本挖掘,则是从文本数据出发 ...

  5. CSC321 神经网络语言模型 RNN-LSTM

    主要两个方面 Probabilistic modeling 概率建模,神经网络模型尝试去预测一个概率分布 Cross-entropy作为误差函数使得我们可以对于观测到的数据 给予较高的概率值 同时可以 ...

  6. 用CNTK搞深度学习 (二) 训练基于RNN的自然语言模型 ( language model )

    前一篇文章  用 CNTK 搞深度学习 (一) 入门    介绍了用CNTK构建简单前向神经网络的例子.现在假设读者已经懂得了使用CNTK的基本方法.现在我们做一个稍微复杂一点,也是自然语言挖掘中很火 ...

  7. [DeeplearningAI笔记]序列模型1.5-1.6不同类型的循环神经网络/语言模型与序列生成

    5.1循环序列模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 1.5不同类型的循环神经网络 上节中介绍的是 具有相同长度输入序列和输出序列的循环神经网络,但是对于很多应用\(T_{x}和 ...

  8. PyTorch 神经网络

    PyTorch 神经网络 神经网络 神经网络可以通过 torch.nn 包来构建. 现在对于自动梯度(autograd)有一些了解,神经网络是基于自动梯度 (autograd)来定义一些模型.一个 n ...

  9. 使用Google-Colab训练PyTorch神经网络

    Colaboratory 是免费的 Jupyter 笔记本环境,不需要进行任何设置就可以使用,并且完全在云端运行.关键是还有免费的GPU可以使用!用Colab训练PyTorch神经网络步骤如下: 1: ...

随机推荐

  1. vnpy源码阅读学习(3):学习vnpy的界面的实现

    学习vnpy的界面的实现 通过简单的学习了PyQt5的一些代码以后,我们基本上可以理解PyQt的一些用法,下面让我们来先研究下vnpy的UI部分的代码. 首先回到上一节看到的run.py(/vnpy/ ...

  2. Spark 配置参数

    SparkConfiguration 这一章节来看看 Spark的相关配置. 并非仅仅能够应用于 SparkStreaming, 而是对于 Spark的各种类型都有支持. 各个不同. 其中中文参考链接 ...

  3. ASP.Net MVC 引用动态 js 脚本

    希望可以动态生成 js  发送给客户端使用. layout页引用: <script type="text/javascript" src="@Url.Action( ...

  4. 关于neo4j初入门(5)

    neo4j和Java Neo4j提供JAVA API以编程方式执行所有数据库操作. 它支持两种类型的API: Neo4j的原生的Java API Neo4j Cypher Java API Neo4j ...

  5. vue拦截器

    1.在路由添加 meta:{ requireAuth:true } 完整 { path: '/xx', name: 'xx', component: xx, meta:{ requireAuth:tr ...

  6. Qt Installer Framework翻译(0)

    本人主攻C++和Qt. 以前一直看人家的博客,找资料学习.今天我也终于开博客啦. 最近在研究Qt install framework(IFW)应用程序安装框架. google也没发现有正儿八经的官方文 ...

  7. 异数OS TCP协议栈测试(三)--长连接篇

    异数OS TCP协议栈测试(三)--长连接篇 本文来自异数OS社区 github:   异数OS-织梦师(消息中间件)群: 476260389 异数OS TCP长连接技术简介 说起长连接,则首先要谈对 ...

  8. [bzoj1041] [洛谷P2508] [HAOI2008] 圆上的整点

    Description 求一个给定的圆(x^2+y^2=r^2),在圆周上有多少个点的坐标是整数. Input 只有一个正整数n,n<=2000 000 000 Output 整点个数 Samp ...

  9. Windows 64 位 mysql 5.7以上版本包解压中没有data目录和my-default.ini和my.ini文件以及服务无法启动的解决办法以及修改初始密码的方法

    下载解压mysql文件之后,中间出现了一些问题,终于解决,希望能帮助到需要的朋友. mysql官网下载地址:https://dev.mysql.com/downloads/mysql/点击打开链接 以 ...

  10. chrome 安装

    Centos7 yum安装chrome浏览器   跟着这个教程安装的:Centos7安装chrome浏览器 (点击) 1. 配置yum源 在目录 /etc/yum.repos.d/ 下新建文件 goo ...