[NLP] The Annotated Transformer 代码修正
1. RuntimeError: "exp" not implemented for 'torch.LongTensor'
class PositionalEncoding(nn.Module)
div_term = torch.exp(torch.arange(0., d_model, 2) *
-(math.log(10000.0) / d_model))
将 “0” 改为 “0.”
否则会报错:RuntimeError: "exp" not implemented for 'torch.LongTensor'
2. RuntimeError: expected type torch.FloatTensor but got torch.LongTensor
class PositionalEncoding(nn.Module)
position = torch.arange(0., max_len).unsqueeze(1)
将 “0” 改为 “0.”
否则会报错:
pe[:, 0::2] = torch.sin(position * div_term)
RuntimeError: expected type torch.FloatTensor but got torch.LongTensor
3. UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
def make_model
nn.init.xavier_uniform_(p)
将“nn.init.xavier_uniform(p)” 改为 “nn.init.xavier_uniform_(p)”
否则会提示:UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
4. UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
class LabelSmoothing
self.criterion = nn.KLDivLoss(reduction='sum')
将 “self.criterion = nn.KLDivLoss(size_average=False)” 改为 “self.criterion = nn.KLDivLoss(reduction='sum')”
否则会提示:UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
5. IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
class SimpleLossCompute
return loss.item() * norm
将 “loss.data[0]” 改为 loss.item(),
否则会报错:IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
6. floating point exception (core dumped)
直接运行“A First Example”会报错:floating point exception (core dumped)
参考的修改方法:https://github.com/harvardnlp/annotated-transformer/issues/26,该方法中,修改 run_epoch 函数,将计数值转换为numpy。方法:.detach().numpy() 或者直接 .numpy()
但是我试了仍有问题。最后需要将gpu上的先移到cpu中,再进行numpy转换。
以下是自己调整后的代码,是可以正确运行的:
def run_epoch(data_iter, model, loss_compute, epoch = 0):
"Standard Training and Logging Function"
start = time.time()
total_tokens = 0
total_loss = 0
tokens = 0
for i, batch in enumerate(data_iter):
out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
loss = loss_compute(out, batch.trg_y, batch.ntokens) total_loss += loss.detach().cpu().numpy()
total_tokens += batch.ntokens.cpu().numpy()
tokens += batch.ntokens.cpu().numpy()
if i % 50 == 1:
elapsed = time.time() - start
print("Epoch Step: %d Loss: %f Tokens per Sec: %f" % (i, loss.detach().cpu().numpy() / batch.ntokens.cpu().numpy(), tokens / elapsed))
start = time.time()
tokens = 0
return total_loss / total_tokens
7. loss 均为整数
class SimpleLossCompute
在运行“A First Example” 时, 结果显示的 loss 全部是整数,这就很奇怪了。测试后发现,是 class SimpleLossCompute中的返回值的问题,norm这个tensor是int型的,虽然loss.item()是浮点数,但是return loss.item() * norm的值仍是int型tensor.
修改方法:将norm转为float再进行乘法运算:
return loss.item() * norm.float()
[NLP] The Annotated Transformer 代码修正的更多相关文章
- 【转载/修改】ScrollLayout代码修正,追加模仿viewpager滚动速度
组件作用为类似ViewPager但直接插视图的横向滚动容器. 修改自:http://blog.csdn.net/yaoyeyzq/article/details/7571940 在该组件基础上修正了滚 ...
- [The Annotated Transformer] Iterators
Iterators 对torchtext的batch实现的修改算法原理 Batching matters a ton for speed. We want to have very evenly di ...
- [原创]PHP代码修正之CodeSniffer
目录 参考链接 介绍 安装 使用 命令行模式 PHPStorm 让编辑器使用PSR-2标准 集成phpcbf 参考链接 PHP开发规范之使用phpcbf脚本自动修正代码格式 在PhpStorm中使用P ...
- 对于国嵌上学期《一跃进入C大门》Mini2440的代码修正
摸索了几天,加了无数的群,病急乱投医式地问了好多个人,终于改对了代码. 下面先贴出给的范例代码 这是C语言代码,是没有错的. 那么出错的地方就在start.S部分 很明显,MPLLCON地址错误,正确 ...
- NLP之基于Transformer的句子翻译
Transformer 目录 Transformer 1.理论 1.1 Model Structure 1.2 Multi-Head Attention & Scaled Dot-Produc ...
- vscode vue 项目保存运行lint进行代码修正
{ "editor.tabSize": 2, "files.associations": { "*.vue": "vue" ...
- NLP整体流程的代码
import nltk import numpy as np import re from nltk.corpus import stopwords # 1 分词1 text = "Sent ...
- nlp英文的数据清洗代码
1.常用的清洗方式 #coding=utf-8 import jieba import unicodedata import sys,re,collections,nltk from nltk.ste ...
- Transformer各层网络结构详解!面试必备!(附代码实现)
1. 什么是Transformer <Attention Is All You Need>是一篇Google提出的将Attention思想发挥到极致的论文.这篇论文中提出一个全新的模型,叫 ...
随机推荐
- React 中的函数式思想
函数式编程简要概念 函数式编程中一个核心概念之一就是纯函数,如果一个函数满足一下几个条件,就可以认为这个函数是纯函数了: 它是一个函数(废话): 当给定相同的输入(函数的参数)的时候,总是有相同的输出 ...
- Threadlocal线程本地变量理解
转载:https://www.cnblogs.com/chengxiao/p/6152824.html 总结: 作用:ThreadLocal 线程本地变量,可用于分布式项目的日志追踪 用法:在切面中生 ...
- 1,全局变量;2,图形验证码;3,解决bug的毅力
通过这一整天的学习,主要解决了这三个: 1,全局变量 在函数外部定义: var gloabl: function test(){ global = " ": //不能写成va ...
- Http的简单介绍
之前写过一篇使用HttpListener作为简单的HTTP服务器,后面实际项目中就用到了,测试发现,在Win7下如果不是以管理员权限运行程序,使用HttpListener是会出错了. 所以就很好奇HT ...
- luogu题解 P2212 【浇地Watering the Fields】
题目链接: https://www.luogu.org/problemnew/show/P2212 思路: 一道最小生成树裸题(最近居然变得这么水了),但是因为我太蒻,搞了好久,不过借此加深了对最小生 ...
- 小程序存emoji表情 不改变数据库
1.小程序:提交前先编码 encodeURIComponent(data) 2.服务端解码(PHP) urldecode(data) 3.如果有空格字符串的,保存之前先对空格进行处理,不然空格在页面会 ...
- linux查看端口被占用情况
Linux 查看端口占用情况可以使用 lsof 和 netstat 命令. 如果linux中没有这两个命令,则yum安装一下 yum install -y lsof yum install -y ne ...
- 001-CentOS 7系统搭建Rsyslog+LogAnalyzer解决交换机日志收
日志功能对于操作系统是相当重要的,在使用中,无论是系统还是应用等等,出了任何问题,我们首先想到的便是分析日志,查找问题原因.自 CentOS 7 开始,我们的 CentOS 便开始使用 rsyslog ...
- kotlin高阶函数实战&DSL入门
传统函数演示: 这里以电视节目“非诚勿扰”为例,男人去从一大堆美女当中挑选出自己中意的对象,比如台上有24位妹子,其档案如下: 接下来第一个男嘉宾出场啦,如下: 下面用代码来实现一下,比较简单: 先定 ...
- @InitBinder的作用
由@InitBinder表示的方法,可以对WebDataBinder对象进行初始化.WebDataBinder是DataBinder的子类,用于完成由表单到JavaBean属性的绑定. @InitBi ...