论文PDF地址:https://arxiv.org/pdf/2110.07602.pdf

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

P-Tuning v2

摘录自第三部分

桔色块指代可训练的prompt embedding;蓝色块是由固定(冻结)的预训练语言模型 存储或计算的embedding。

Deep Prompt Tuning

continuous prompts(连续提示) 仅仅能够插入到input embedding序列层。如此,有两个问题:首先由于序列长度的约束限制,可调参数的数量有限。其次,输入的embedding对模型预测有间接的影响。

为了解决这些问题,P-Tuning v2使用deep prompt tuning的方案。正如上图的b部分,prompt作为prefix token插入到不同的层中。一方面,p-tuning v2有更多可调的特定任务参数(从 0.01% 到 0.1%~3%),扩大了任务的容量也提高了参数效率;另一方面,添加到更深层的prompt对模型的预测会有更多直接的影响。

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

理解

在P-tuning v2的方案中,从图直观来看,有两个关键的点:

  1. prompts会加在序列的前端,而不仅仅是插入到input embedding
  2. 每一层都会插入prompts

v2版本主要基于p-tuning和prefix-tuning技术。prompt 向量是在模型的 embedding 层与其他输入 token 的 embedding 相拼接的,且通过在预训练模型的每一层引入可训练的 prompt 向量来提高模型对特定任务的适应性。

p-tuning主要是利用一个prompt encoder,将prompt先encoder再与input embedding进行拼接。

prefix-tuning是在Transformer的Encoder和Decoder的网络中都加了一些特定的前缀。

而基于这两种技术的v2版本,则是将两者结合。在embedding与transformer模块都做了prompt向量的插入。

ChatGLM中,首先要对prompt做encode,作为前缀prefix拼接插入到input embedding与transformer模型中。


# 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/ class PrefixEncoder(torch.nn.Module):
"""
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
""" def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values

在ChatGLMModel中调用并插入到每一个transformer模型层中。

class ChatGLMModel(ChatGLMPreTrainedModel):
'''
省略其它....
'''
def __init__(self, config: ChatGLMConfig, empty_init=True):
if self.pre_seq_len is not None:
for param in self.parameters():
param.requires_grad = False
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
# encode prompt
self.prefix_encoder = PrefixEncoder(config)
self.dropout = torch.nn.Dropout(0.1) # 调用prompt
def get_prompt(self, batch_size, device, dtype=torch.half):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
# 调用prompt并返回
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
past_key_values = past_key_values.view(
batch_size,
self.pre_seq_len,
self.num_layers * 2,
self.num_attention_heads,
self.hidden_size // self.num_attention_heads
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
# past_key_values = [(v[0], v[1]) for v in past_key_values]
return past_key_values # 返回transformer模型
def get_layer(layer_id):
return GLMBlock(
self.hidden_size,
self.num_attention_heads,
self.layernorm_epsilon,
layer_id,
inner_hidden_size=self.inner_hidden_size,
hidden_size_per_attention_head=self.hidden_size_per_attention_head,
layernorm=LayerNorm,
use_bias=True,
params_dtype=self.params_dtype,
position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init
) def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
# 其它代码
if past_key_values is None:
if self.pre_seq_len is not None:
# 调用prompt
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
dtype=inputs_embeds.dtype)
else:
past_key_values = tuple([None] * len(self.layers)) if attention_mask is None:
attention_mask = self.get_masks(
input_ids,
device=input_ids.device
)
# 其它代码
for i, layer in enumerate(self.layers): if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# 准备参数传递到layer
layer_past = past_key_values[i]
# 每个layer 是一个GLMBlock即transformer模型层
if self.gradient_checkpointing and self.training:
# 将prompt传递到每个层中
layer_ret = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
position_ids,
attention_mask,
torch.tensor(i),
layer_past,
use_cache,
output_attentions
)
else:
layer_ret = layer(
hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
layer_id=torch.tensor(i),
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions
)
# 其它代码

参考

大模型微调之P-tuning方法解析

通俗解读大模型微调(Fine Tuning)

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

聊聊ChatGLM中P-tuning v2的应用的更多相关文章

  1. 简单聊聊java中的final关键字

    简单聊聊java中的final关键字 日常代码中,final关键字也算常用的.其主要应用在三个方面: 1)修饰类(暂时见过,但是还没用过); 2)修饰方法(见过,没写过); 3)修饰数据. 那么,我们 ...

  2. Ubuntu_ROS中应用kinect v2笔记

    Ubuntu_ROS中应用kinect v2笔记 个人觉得最重要的资料如下: 1. Microsoft Kinect v2 Driver Released http://www.ros.org/new ...

  3. 聊聊iOS中网络编程长连接的那些事

    1.长连接在iOS开发中的应用 常见的短连接应用场景:一般的App的网络请求都是基于Http1.0进行的,使用的是NSURLConnection.NSURLSession或者是AFNetworking ...

  4. ACM中使用 JAVA v2. 1

    ACM中使用JAVA v2.1 严明超 (Blog:mingchaoyan.blogbus.com Email:mingchaoyan@gmail.com) 0.前 言 文前声明:本文只谈java用于 ...

  5. 【小家Spring】聊聊Spring中的数据绑定 --- BeanWrapper以及内省Introspector和PropertyDescriptor

    #### 每篇一句 > 千古以来要饭的没有要早饭的,知道为什么吗? #### 相关阅读 [[小家Spring]聊聊Spring中的数据转换:Converter.ConversionService ...

  6. 【小家Spring】聊聊Spring中的数据绑定 --- DataBinder本尊(源码分析)

    每篇一句 唯有热爱和坚持,才能让你在程序人生中屹立不倒,切忌跟风什么语言或就学什么去~ 相关阅读 [小家Spring]聊聊Spring中的数据绑定 --- 属性访问器PropertyAccessor和 ...

  7. 聊聊 Vue 中 axios 的封装

    聊聊 Vue 中 axios 的封装 axios 是 Vue 官方推荐的一个 HTTP 库,用 axios 官方简介来介绍它,就是: Axios 是一个基于 promise 的 HTTP 库,可以用在 ...

  8. 聊聊 Vue 中 provide/inject 的应用

    众所周知,在组件式开发中,最大的痛点就在于组件之间的通信.在 Vue 中,Vue 提供了各种各样的组件通信方式,从基础的 props/$emit 到用于兄弟组件通信的 EventBus,再到用于全局数 ...

  9. 聊聊GIS中的坐标系|再版

    本文约6500字,建议阅读时间15分钟. 作者:博客园/B站/知乎/csdn/小专栏 @秋意正寒 版权:转载请告知,并在转载文上附上转载声明与原文链接(https://www.cnblogs.com/ ...

  10. 简单聊聊CSS中的3D技术之“立方体”

    简单聊聊CSS中的3D技术之“立方体” 大家好,我是今天的男一号,我叫小博主. 今天来聊一下我在前端“逆战班”学习中遇到的颇为有趣的3D知识.前端学习3周,见识稀疏,在下面的分享中如有不对的地方请大家 ...

随机推荐

  1. Python基础——二分法、面向过程编程思想、有名函数、lambda、max、_min的应用、sorted排序、map的应用、filter的应用、reduce的应用

    文章目录 内容回顾 二分法 伪代码模板 面向过程编程思想 函数式 def用于定义有名函数 lambda用于定义匿名函数 调用匿名函数 匿名函数作用 匿名函数的示范 max的应用 min的应用 sort ...

  2. Arduino 麦克风声音传感器指南

    麦克风声音传感器 麦克风声音传感器,顾名思义,检测声音.它可以测量声音的响度. 这些传感器的种类繁多.  在下图中,您可以看到 Arduino 最常用的. 最左边是KY-038,右边是LM393麦克风 ...

  3. 【动画进阶】神奇的背景,生化危机4日食 Loading 动画还原

    最近,在 Steam 玩一款老游戏(生化危机 4 重置版),其中,每当游戏转场的过程中,都有这么一个有趣的 Loading 动画: 整个效果有点类似于日食效果,中间一圈黑色,向外散发着太阳般的光芒. ...

  4. 文心一言 VS 讯飞星火 VS chatgpt (110)-- 算法导论10.2 1题

    一.用go语言,单链表上的动态集合操作 INSERT 能否在 O(1)时间内实现?DELETE 操作呢? 文心一言: 在单链表中,INSERT 和 DELETE 操作的时间复杂度通常是 O(n),其中 ...

  5. slice简介

    简介 Go语言中的切片(slice)是一种灵活的数据结构,它构建在数组之上并提供了方便的方式来操作数组的一部分.切片的底层实现涉及到数组和一些元数据.以下是Golang切片的底层实现的详细介绍: 底层 ...

  6. 新手面对安卓6.0以上的版本时出现一个关于文件权限检测的问题,报错为:“无法解析符号 'checkSelfPermission'”,解决办法

    [[注意]:这只是笔者在遇到这个问题时的解决方法,如果对您毫无帮助,请自寻他法!!!] 面对新手:在简单做一个音乐播放程序时,如果面对安卓6.0以上的版本,就会出现一个关于文件权限检测的问题,报错为: ...

  7. QT(5)-QHeaderView

    @ 目录 1 说明 2 函数 2.1 级联调整大小 2.2 默认对齐方式 2.3 count() 2.4 表头默认单元格大小 2.5 hiddenSectionCount() 2.6 分区显示和隐藏 ...

  8. html笔记重点

    第五周-周二 一.视频和音频 <video src="路径" controls="controls"></video> 1.加contr ...

  9. Java开发中的工作流程和步骤

    前言 随着环境的变迁,大家总会更换工作,有裁员的,有跳槽的,除了进进出出的老人,还有源源不断入坑的新人. 很多人入职之后还不知道怎么快速适应工作,对我而言,除去寥寥可数的同事感情,对我而言,更换工作更 ...

  10. URL, URI 和 URN 之间的区别

    英文原文:What's the difference between a URI and a URL?  URI 标识一个事物 , URL 定位一个事物:然而,位置同样可以标识一个事物,所以,每个 U ...