seq2seq聊天模型(二)——Scheduled Sampling
使用典型seq2seq模型,得到的结果欠佳,怎么解决
结果欠佳原因在这里
- 在训练阶段的decoder,是将目标样本["吃","兰州","拉面"]作为输入下一个预测分词的输入。
- 而在预测阶段的decoder,是将上一个预测结果,作为下一个预测值的输入。(注意查看预测多的箭头)
这个差异导致了问题的产生,训练和预测的情景不同。
在预测的时候,如果上一个词语预测错误,还后面全部都会跟着错误,蝴蝶效应。

解决办法-Scheduled Sampling
修改训练时decoder的模型
基础模型只会使用真实lable数据作为输入, 现在,train-decoder不再一直都是真实的lable数据作为下一个时刻的输入。
train-decoder时以一个概率P选择模型自身的输出作为下一个预测的输入,以1-p选择真实标记作为下一个预测的输入。
Secheduled sampling(计划采样),即采样率P在训练的过程中是变化的。
一开始训练不充分,先让P小一些,尽量使用真实的label作为输入,随着训练的进行,将P增大,多采用自身的输出作为下一个预测的输入。
随着训练的进行,P越来越大大,train-decoder模型最终变来和inference-decoder预测模型一样,消除了train-decoder与inference-decoder之间的差异
总之:
通过这个scheduled-samping方案,抹平了训练decoder和预测decoder之间的差异!让预测结果和训练时的结果一样。
tensorflow
tensoflow已经完成了这个模型,直接调用,设定参数可以使用
training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
inputs=dec_emb_inputs,
sequence_length=self.dec_sequence_length + 2,
embedding=self.dec_Wemb,
sampling_probability=self.sampling_probability,
time_major=False,
name='training_helper')
self.sampling_probability = tf.placeholder(
tf.float32,
shape=[],
name='sampling_probability')
# 下面这个时feed_dic
# 随着epoch的增大,sampling_probability_list逐渐变为1,即全部采用自身输出作为下个输入,
sampling_probability_list = np.linspace(
start=0.0,
stop=1.0,
num=n_epoch,
dtype=np.float32)
实际结果
效果很好

seq2seq聊天模型(二)——Scheduled Sampling的更多相关文章
- seq2seq聊天模型(一)
原创文章,转载请注明出处 最近完成了sqe2seq聊天模型,磕磕碰碰的遇到不少问题,最终总算是做出来了,并符合自己的预期结果. 本文目的 利用流程图,从理论方面,回顾,总结seq2seq模型, seq ...
- seq2seq聊天模型(三)—— attention 模型
注意力seq2seq模型 大部分的seq2seq模型,对所有的输入,一视同仁,同等处理. 但实际上,输出是由输入的各个重点部分产生的. 比如: (举例使用,实际比重不是这样) 对于输出"晚上 ...
- 深度学习教程 | Seq2Seq序列模型和注意力机制
作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/35 本文地址:http://www.showmeai.tech/article-det ...
- django模型二
django模型二 常用模型字段类型 IntegerField → int CharField → varchar TextField → longtext DateFiel ...
- pytorch做seq2seq注意力模型的翻译
以下是对pytorch 1.0版本 的seq2seq+注意力模型做法语--英语翻译的理解(这个代码在pytorch0.4上也可以正常跑): # -*- coding: utf-8 -*- " ...
- socket实现聊天功能(二)
socket实现聊天功能(二) WebSocket协议是建立在HTTP协议之上,因此创建websocket服务时需要调用http模块的createServer方法.将生成的server作为参数传入so ...
- [Beego模型] 二、CRUD 操作
[Beego模型] 一.ORM 使用方法 [Beego模型] 二.CRUD 操作 [Beego模型] 三.高级查询 [Beego模型] 四.使用SQL语句进行查询 [Beego模型] 五.构造查询 [ ...
- {django模型层(二)多表操作}一 创建模型 二 添加表记录 三 基于对象的跨表查询 四 基于双下划线的跨表查询 五 聚合查询、分组查询、F查询和Q查询
Django基础五之django模型层(二)多表操作 本节目录 一 创建模型 二 添加表记录 三 基于对象的跨表查询 四 基于双下划线的跨表查询 五 聚合查询.分组查询.F查询和Q查询 六 xxx 七 ...
- {03--CSS布局设置} 盒模型 二 padding bode margin 标准文档流 块级元素和行内元素 浮动 margin的用法 文本属性和字体属性 超链接导航栏 background 定位 z-index
03--CSS布局设置 本节目录 一 盒模型 二 padding(内边距) 三 boder(边框) 四 简单认识一下margin(外边距) 五 标准文档流 六 块级元素和行内元素 七 浮动 八 mar ...
随机推荐
- Comet OJ - Contest #5 简要题解
好久没更博了,还是象征性地更一次. 依然延续了简要题解的风格. 题目链接 https://cometoj.com/contest/46 题解 A. 迫真字符串 记 \(s_i\) 表示数字 \(i\) ...
- 使用vue-cli创建vue工程
在Windows环境下,打开命令行窗口,跳转至想创建工程的路径. 如:D:\MyWork\22_Github\rexel-cn\rexel-jarvis 创建vue工程,命令:vue create r ...
- aliplay获取播放时长
<div id="player-con" class="frequency-pic"></div> <link rel=" ...
- (二)如何利用C# Roslyn编译器写一个简单的代码提示/错误检查?
上一篇我们讲了如何建立一个简单的Roslyn分析项目如分析检查我们的代码. 今天我们主要介绍各个项目中具体的作用以及可视化分析工具. 还是这种截图,可以看到解决方案下一共有三个项目. Analyzer ...
- CSS一些常用样式
限制行数溢出省略号 display: -webkit-box; -webkit-box-orient: vertical; -webkit-line-clamp: ; overflow: hidden ...
- XCode下在不同位置声明变量的用法(转)
XCode下在不同位置声明变量的用法 方式一:直接在.h文件@interface中的大括号中声明. @interface Test : NSObject { NSString *str; // 私有变 ...
- JavaScript--常用对象的属性及方法(1)
1.Number对象(基本数据类型) Number对象的方法大多是一些强制转换方法,如果转换失败返回NaN,以下举例中用number来代替具体数字: *console.log在控制台输出(键盘F12可 ...
- redis数据结构分析 (redisObject、SDS)
redis是一个key-value储存系统.和Memcached类似,它支持存储的value类型相对更多,包括string(字符串).list(链表).set(集合).zset(sorted set ...
- 安装jdk配置环境变量后jps command not found
配置Java环境变量的时候一般是 vi /etc/profile 然后按两个大写的G就会跑到最后一行去,然后配置写入下文: 这个时候你jps查看Java的进程会出现: 分析原因: 一般是配置之后,没有 ...
- 十年阿里顶级架构师教你怎么使用Java来搭建微服务
微服务背后的大理念是将大型.复杂且历时长久的应用在架构上设计为内聚的服务,这些服务能够随着时间的流逝而演化.本文主要介绍了利用 Java 生态系统构建微服务的多种方法,并分析了每种方法的利弊. 快速预 ...