PyTorch学习笔记之CBOW模型实践
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F CONTEXT_SIZE = 2 # 2 words to the left, 2 to the right
raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells.".split(' ') vocab = set(raw_text)
word_to_idx = {word: i for i, word in enumerate(vocab)} data = []
for i in range(CONTEXT_SIZE, len(raw_text)-CONTEXT_SIZE):
context = [raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2]]
target = raw_text[i]
data.append((context, target)) class CBOW(nn.Module):
def __init__(self, n_word, n_dim, context_size):
super(CBOW, self).__init__()
self.embedding = nn.Embedding(n_word, n_dim)
self.linear1 = nn.Linear(2*context_size*n_dim, 128)
self.linear2 = nn.Linear(128, n_word) def forward(self, x):
x = self.embedding(x)
x = x.view(1, -1)
x = self.linear1(x)
x = F.relu(x, inplace=True)
x = self.linear2(x)
x = F.log_softmax(x)
return x model = CBOW(len(word_to_idx), 100, CONTEXT_SIZE)
if torch.cuda.is_available():
model = model.cuda() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3) for epoch in range(100):
print('epoch {}'.format(epoch))
print('*'*10)
running_loss = 0
for word in data:
context, target = word
context = Variable(torch.LongTensor([word_to_idx[i] for i in context]))
target = Variable(torch.LongTensor([word_to_idx[target]]))
if torch.cuda.is_available():
context = context.cuda()
target = target.cuda()
# forward
out = model(context)
loss = criterion(out, target)
running_loss += loss.data[0]
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('loss: {:.6f}'.format(running_loss / len(data)))
PyTorch学习笔记之CBOW模型实践的更多相关文章
- [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- PyTorch学习笔记之n-gram模型实现
import torch import torch.nn as nn from torch.autograd import Variable import torch.nn.functional as ...
- 操作系统学习笔记----进程/线程模型----Coursera课程笔记
操作系统学习笔记----进程/线程模型----Coursera课程笔记 进程/线程模型 0. 概述 0.1 进程模型 多道程序设计 进程的概念.进程控制块 进程状态及转换.进程队列 进程控制----进 ...
- V-rep学习笔记:机器人模型创建3—搭建动力学模型
接着之前写的V-rep学习笔记:机器人模型创建2—添加关节继续机器人创建流程.如果已经添加好关节,那么就可以进入流程的最后一步:搭建层次结构模型和模型定义(build the model hierar ...
- V-rep学习笔记:机器人模型创建2—添加关节
下面接着之前经过简化并调整好视觉效果的模型继续工作流,为了使模型能受控制运动起来必须在合适的位置上添加相应的运动副/关节.一般情况下我们可以查阅手册或根据设计图纸获得这些关节的准确位置和姿态,知道这些 ...
- ArcGIS模型构建器案例学习笔记-字段处理模型集
ArcGIS模型构建器案例学习笔记-字段处理模型集 联系方式:谢老师,135-4855-4328,xiexiaokui@qq.com 由四个子模型组成 子模型1:判断字段是否存在 方法:python工 ...
- springmvc学习笔记--Interceptor机制和实践
前言: Spring的AOP理念, 以及j2ee中责任链(过滤器链)的设计模式, 确实深入人心, 处处可以看到它的身影. 这次借项目空闲, 来总结一下SpringMVC的Interceptor机制, ...
- java之jvm学习笔记六-十二(实践写自己的安全管理器)(jar包的代码认证和签名) (实践对jar包的代码签名) (策略文件)(策略和保护域) (访问控制器) (访问控制器的栈校验机制) (jvm基本结构)
java之jvm学习笔记六(实践写自己的安全管理器) 安全管理器SecurityManager里设计的内容实在是非常的庞大,它的核心方法就是checkPerssiom这个方法里又调用 AccessCo ...
随机推荐
- 牛客练习赛29 B
炎热的早上,gal男神们被迫再操场上列队,gal男神们本来想排列成x∗x的正方形,可是因为操场太小了(也可能是gal男神太大了),校长安排gal男神们站成多个4∗4的正方形(gal男神们可以正好分成n ...
- debian使用ibus
$ sudo apt-get install ibus ibus-pinyin 点击右上角的键盘图标,设置拼音输入法
- 光学字符识别OCR-4
经过第一部分,我们已经较好地提取了图像的文本特征,下面进行文字定位. 主要过程分两步: 1.邻近搜索,目的是圈出单行文字: 2.文本切割,目的是将单行文本切割为单字. ...
- jmeter进行dubbo接口测试
最近工作中接到一个需求,需要对一个MQ消息队列进行性能测试,测试其消费能力,开发提供了一个dubbo服务来供我调用发送消息. 这篇博客,介绍下如何利用jmeter来测试dubbo接口,并进行性能测试. ...
- 零基础学 JavaScript 全彩版 明日科技 编著
第1篇 基础知识 第1章 JavaScript简介 1.1 JavaScript简述 1.2 WebStorm的下载与安装 1.3 JavaScript在HTML中的使用 1.3.1 在页面中直接嵌入 ...
- Leetcode 447.回旋镖的数量
回旋镖的数量 给定平面上 n 对不同的点,"回旋镖" 是由点表示的元组 (i, j, k) ,其中 i 和 j 之间的距离和 i 和 k 之间的距离相等(需要考虑元组的顺序). 找 ...
- [python工具][1]sublime安装与配置
http://www.cnblogs.com/wind128/p/4409422.html 1 官网下载版本 http://www.sublimetext.com/3 选择 Windows - al ...
- POJ 2355 Railway tickets
Railway tickets Time Limit: 1000MS Memory Limit: 65536K Total Submissions: 2472 Accepted: 865 De ...
- RHEL 7.3修改网卡命名规则为ethX
RHEL 7网卡默认命名规则:以太网卡(Ethernet)为enX,无线网卡(WLAN)为wlX,修改网卡命名规则为ethX如下: 1.修改/etc/sysconfig/grub文件,添加net.if ...
- C遇到的编译错误整理
1: Permission denied collect2.exe: error: ld returned exit status c:/mingw/bin/../lib/gcc/mingw32/6. ...