PyTorch学习笔记之n-gram模型实现
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim CONTEXT_SIZE = 2 # the same as window_size
EMBEDDING_DIM = 10
test_sentence = "When forty winters shall besiege thy brow,And dig deep trenches in thy beauty's field,Thy youth's proud livery so gazed on now,Will be a totter'd weed of small worth held:Then being asked, where all thy beauty lies,Where all the treasure of thy lusty days;To say, within thine own deep sunken eyes,Were an all-eating shame, and thriftless praise.How much more praise deserv'd thy beauty's use,If thou couldst answer 'This fair child of mineShall sum my count, and make my old excuse,'Proving his beauty by succession thine!This were to be new made when thou art old,And see thy blood warm when thou feel'st it cold.".split() vocb = set(test_sentence) # remove repeated words
word2id = {word: i for i, word in enumerate(vocb)}
id2word = {word2id[word]: word for word in word2id} # define model
class NgramModel(nn.Module):
def __init__(self, vocb_size, context_size, n_dim):
# super(NgramModel, self)._init_()
super().__init__()
self.n_word = vocb_size
self.embedding = nn.Embedding(self.n_word, n_dim)
self.linear1 = nn.Linear(context_size*n_dim, 128)
self.linear2 = nn.Linear(128, self.n_word) def forward(self, x):
# the first step: transmit words and achieve word embedding. eg. transmit two words, and then achieve (2, 100)
emb = self.embedding(x)
# the second step: word wmbedding unfold to (1,200)
emb = emb.view(1, -1)
# the third step: transmit to linear model, and then use relu, at last, transmit to linear model again
out = self.linear1(emb)
out = F.relu(out)
out = self.linear2(out)
# the output dim of last step is the number of words, wo can view as a classification problem
# if we want to predict the max probability of the words, finally we need use log softmax
log_prob = F.log_softmax(out)
return log_prob ngrammodel = NgramModel(len(word2id), CONTEXT_SIZE, 100)
criterion = nn.NLLLoss()
optimizer = optim.SGD(ngrammodel.parameters(), lr=1e-3) trigram = [((test_sentence[i], test_sentence[i+1]), test_sentence[i+2])
for i in range(len(test_sentence)-2)] for epoch in range(100):
print('epoch: {}'.format(epoch+1))
print('*'*10)
running_loss = 0
for data in trigram:
# we use 'word' to represent the two words forward the predict word, we use 'label' to represent the predict word
word, label = data # attention
word = Variable(torch.LongTensor([word2id[e] for e in word]))
label = Variable(torch.LongTensor([word2id[label]]))
# forward
out = ngrammodel(word)
loss = criterion(out, label)
running_loss += loss.data[0]
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('loss: {:.6f}'.format(running_loss/len(word2id))) # predict
word, label = trigram[3]
word = Variable(torch.LongTensor([word2id[i] for i in word]))
out = ngrammodel(word)
_, predict_label = torch.max(out, 1)
predict_word = id2word[predict_label.data[0][0]]
print('real word is {}, predict word is {}'.format(label, predict_word))
PyTorch学习笔记之n-gram模型实现的更多相关文章
- ArcGIS案例学习笔记-批量裁剪地理模型
ArcGIS案例学习笔记-批量裁剪地理模型 联系方式:谢老师,135-4855-4328,xiexiaokui#qq.com 功能:空间数据的批量裁剪 优点:1.批量裁剪:任意多个目标数据,去裁剪任意 ...
- Java学习笔记之---单例模型
Java学习笔记之---单例模型 单例模型分为:饿汉式,懒汉式 (一)要点 1.某个类只能有一个实例 2.必须自行创建实例 3.必须自行向整个系统提供这个实例 (二)实现 1.只提供私有的构造方法 2 ...
- WebGL three.js学习笔记 加载外部模型以及Tween.js动画
WebGL three.js学习笔记 加载外部模型以及Tween.js动画 本文的程序实现了加载外部stl格式的模型,以及学习了如何把加载的模型变为一个粒子系统,并使用Tween.js对该粒子系统进行 ...
- ARMV8 datasheet学习笔记5:异常模型
1.前言 2.异常类型描述 见 ARMV8 datasheet学习笔记4:AArch64系统级体系结构之编程模型(1)-EL/ET/ST 一文 3. 异常处理路由对比 AArch32.AArch64架 ...
- Javascript MVC 学习笔记(一) 模型和数据
写在前面 近期在看<MVC的Javascript富应用开发>一书.本来是抱着一口气读完的想法去看的.结果才看了一点就傻眼了:太多不懂的地方了. 仅仅好看一点查一点,一点一点往下看吧,进度虽 ...
- PowerDesigner 15学习笔记:十大模型及五大分类
个人认为PowerDesigner 最大的特点和优势就是1)提供了一整套的解决方案,面向了不同的人员提供不同的模型工具,比如有针对企业架构师的模型,有针对需求分析师的模型,有针对系统分析师和软件架构师 ...
- [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- PyTorch学习笔记之CBOW模型实践
import torch from torch import nn, optim from torch.autograd import Variable import torch.nn.functio ...
随机推荐
- linux学习-systemd-journald.service 简介
过去只有 rsyslogd 的年代中,由于 rsyslogd 必须要开机完成并且执行了 rsyslogd 这个 daemon 之 后,登录文件才会开始记录.所以,核心还得要自己产生一个 klogd 的 ...
- Mysql登陆、退出、更改环境编码
登录: mysql -h[数据库地址] -u[username] -p[password] -P[端口] //大写P表示端口,小写p表示密码 例如:mysql -hlocalhost -uroot ...
- hdu4864不是一般的贪心
题目表达的非常清楚,也不绕弯刚开始以为最大权匹配,仔细一想不对,这题的数据双循环建图都会爆,只能先贪心试一下,但一想贪心也要双循环啊,怎么搞? 想了好久没头绪,后来经学长提醒,可以把没用到的先记录下来 ...
- 在终端更改MAC的MySQL的root密码
- SPOJ 375 树链剖分 QTREE - Query on a tree
人生第一道树链剖分的题目,其实树链剖分并不是特别难. 思想就是把树剖成一些轻链和重链,轻链比较少可以直接修改,重链比较长,用线段树去维护. 貌似大家都是从这篇博客上学的. #include <c ...
- 装饰器与lambda
装饰器 实际上理解装饰器的作用很简单,在看core python相关章节的时候大概就是这种感觉.只是在实际应用的时候,发现自己很难靠直觉决定如何使用装饰器,特别是带参数的装饰器,于是摊开来思考了一番, ...
- 用virtualbox+模拟串口+CDT调试linux内核 TCP/IP协议栈-起步
经常有人问一台机器如何将hello经网络发送给另一台机器,我确实是不知道,只能看代码了. 说明:本人对内核的研究学习也是刚刚起步,有很多不了解的,所以文中可能会有一些"一本正经的胡扯&quo ...
- 零基础学 JavaScript 全彩版 明日科技 编著
第1篇 基础知识 第1章 JavaScript简介 1.1 JavaScript简述 1.2 WebStorm的下载与安装 1.3 JavaScript在HTML中的使用 1.3.1 在页面中直接嵌入 ...
- js常见的方法
Ajax请求 jquery ajax函数 我自己封装了一个ajax的函数,代码如下: var Ajax = function(url, type success, error) { $.ajax({ ...
- GBDT 与 XGBoost
GBDT & XGBoost ### 回归树 单棵回归树可以表示成如下的数学形式 \[ f(x) = \sum_j^Tw_j\mathbf{I}(x\in R_j) \] 其中\(T\)为叶节 ...