聊聊HuggingFace Transformer
概述
项目组件
一个完整的transformer模型主要包含三部分:Config、Tokenizer、Model。
Config
用于配置模型的名称、最终输出的样式、隐藏层宽度和深度、激活函数的类别等。
示例:
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.0.dev0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
Tokenizer
将纯文本转换为编码的过程(注意:该过程并不会生成词向量)。由于模型(Model)并不能识别(或很好的识别)文本数据,因此对于输入的文本需要做一层编码。在这个过程中,首先会将输入文本分词而后添加某些特殊标记([MASK]标记、[SEP]、[CLS]标记),比如断句等,最后就是转换为数字类型的ID(也可以理解为是字典索引)。
示例:
pt_batch = tokenizer(
["We are very happy to show you the Transformers library.",
"We hope you don't hate it."],
padding=True,
truncation=True,
max_length=5,
return_tensors="pt"
)
## 其中,当使用list作为batch进行输入时,使用到的参数注解如下:
## padding:填充,是否将所有句子pad到同一个长度。
## truncation:截断,当遇到超过max_length的句子时是否直接截断到max_length。
## return_tensors:张量返回值,"pt"表示返回pytorch类型的tensor,"tf"表示返回TensorFlow类型的tensor,"np"表示Numpy数组。
Model
AI模型(指代基于各种算法模型,比如预训练模型、深度学习算法、强化学习算法等的实现)的抽象概念。
除了初始的Bert、GPT等基本模型,针对下游任务,还定义了诸如BertForQuestionAnswering等下游任务模型。
Transformer使用
pipeline的使用
transformer库中最基本的对象是pipeline()函数。它将模型与其必要的预处理和后处理步骤连接起来,使我们能够直接输入任何文本并获得可理解的答案:
from transformers import pipeline
classifier = pipeline("sentiment-analysis")
classifier("I've been waiting for a HuggingFace course my whole life.")
[{'label': 'POSITIVE', 'score': 0.9598047137260437}]
默认情况下,该pipeline函数选择一个特定的预训练模型,该模型已经过英语情感分析的微调。当创建classifier对象时,将下载并缓存模型。如果重新运行该命令,则将使用缓存的模型,并且不需要再次下载模型。
调用pipeline函数指定预训练模型,有三个主要步骤:
- 输入的文本被预处理成模型(Model)可以理解的格式的数据(就是上述中Tokenizer组件的处理过程)。
- 预处理后的数据作为输入参数传递给模型(Model)。
- 模型的预测结果(输出内容)是经过后处理的,可供理解。
目前可用的pipelines如下:
- feature-extraction(特征提取)
- fill-mask
- ner(命名实体识别)
- question-answering(自动问答)
- sentiment-analysis(情感分析)
- summarization(摘要)
- text-generation(文本生成)
- translation(翻译)
- zero-shot-classification(文本分类)
完整说明可参考:pipelines示例说明
pipeline的原理
如上所述,pipeline将三个步骤组合在一起:预处理、通过模型传递输入以及后处理:

Tokenizer的预处理
与其他神经网络一样,Transformer 模型无法直接处理原始文本,因此pipeline的第一步是将文本输入转换为模型可以理解的数字。为此,我们使用分词器,它将负责:
- 将输入的文本分词,即拆分为单词、子单词或符号(如标点符号),这些被称为tokens(标记)。
- 将每个token映射到一个整数。
- 添加可能对模型有用的额外输入(微调)。
预训练模型完成后,所有的预处理需要完全相同的方式完成,因此我们首先需要从Model Hub下载该信息。 为此,我们使用 AutoTokenizer 类及其 from_pretrained() 方法。 使用模型的checkpoint,它将自动获取与模型的标记生成器关联的数据并缓存它。
由于情感分析pipeline的checkpoint是 distilbert-base-uncased-finetuned-sst-2-english ,因此我们运行以下命令:
from transformers import AutoTokenizer
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
如此便得到tokenizer对象后,后续只需将文本参数输入即可,便完成了分词-编码-转换工作。
使用Transformers框架不需要担心使用哪个后端 ML 框架(PyTorch、TensorFlow、Flax)。Transformer 模型只接受tensors(张量)作为输入参数。
注:NumPy 数组可以是标量 (0D)、向量 (1D)、矩阵 (2D) 或具有更多维度。它实际上是一个张量。
tokenizer中的return_tensors 参数定了返回的张量类型(PyTorch、TensorFlow 或普通 NumPy)
raw_inputs = [
"I've been waiting for a HuggingFace course my whole life.",
"I hate this so much!",
]
inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")
print(inputs)
以下是tokenizer返回的PyTorch张量的结果:
{
'input_ids': tensor([
[101, 1045, 1005, 2310, 2042, 3403, 2005, 1037, 17662, 12172, 2607, 2026, 2878, 2166, 1012, 102],
[101, 1045, 5223, 2023, 2061, 2172, 999, 102, 0, 0, 0, 0, 0, 0, 0, 0]
]),
'attention_mask': tensor([
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
])
}
tokenizer的返回值参数说明如下:
- 输出input_ids:经过编码后的数字(即前面所说的张量数据)。
- 输出token_type_ids:因为编码的是两个句子,这个list用于表明编码结果中哪些位置是第1个句子,哪些位置是第2个句子。具体表现为,第2个句子的位置是1,其他位置是0。
- 输出special_tokens_mask:用于表明编码结果中哪些位置是特殊符号,具体表现为,特殊符号的位置是1,其他位置是0。
- 输出attention_mask:用于表明编码结果中哪些位置是PAD。具体表现为,PAD的位置是0,其他位置是1。
- 输出length:表明编码后句子的长度。
Model层的处理
我们可以像使用tokenizer一样下载预训练模型。 Transformers 提供了一个 AutoModel 类,它也有一个 from_pretrained() 方法:
from transformers import AutoModel
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModel.from_pretrained(checkpoint)
## inputs的参数值是前面tokenizer的输出
outputs = model(**inputs)
与初始化tokenizer一样,将相同的checkpoint作为参数,初始化一个Model;而后将tokenizer的输出数据——张量数据作为参数输入到Model中。
模型的处理架构流程图,如下:
Transformer network模块有两层:嵌入层(Embeddings)、后续层(Layers)。嵌入层将标记化输入中的每个输入 ID 转换为表示关联标记的向量。 随后的层使用注意力机制操纵这些向量来产生句子的最终表示。
Transformer的输出,作为Hidden States,也可以理解为是Feature(特征数据)。
而这些特征数据,将作为模型的另一些部分的输入,比如Head层;最终由Head层输出模型的结果。
参考:
https://huggingface.co/learn/nlp-course/chapter2/1?fw=pt
聊聊HuggingFace Transformer的更多相关文章
- 昇思MindSpore全场景AI框架 1.6版本,更高的开发效率,更好地服务开发者
摘要:本文带大家快速浏览昇思MindSpore全场景AI框架1.6版本的关键特性. 全新的昇思MindSpore全场景AI框架1.6版本已发布,此版本中昇思MindSpore全场景AI框架易用性不断改 ...
- 利用Hugging Face中的模型进行句子相似性实践
Hugging Face是什么?它作为一个GitHub史上增长最快的AI项目,创始人将它的成功归功于弥补了科学与生产之间的鸿沟.什么意思呢?因为现在很多AI研究者写了大量的论文和开源了大量的代码, ...
- 大规模 Transformer 模型 8 比特矩阵乘简介 - 基于 Hugging Face Transformers、Accelerate 以及 bitsandbytes
引言 语言模型一直在变大.截至撰写本文时,PaLM 有 5400 亿参数,OPT.GPT-3 和 BLOOM 有大约 1760 亿参数,而且我们仍在继续朝着更大的模型发展.下图总结了最近的一些语言模型 ...
- 聊聊Unity项目管理的那些事:Git-flow和Unity
0x00 前言 目前所在的团队实行敏捷开发已经有了一段时间了.敏捷开发中重要的一个话题便是如何对项目进行恰当的版本管理.项目从最初使用svn到之后的Git One Track策略再到现在的GitFlo ...
- Mono为何能跨平台?聊聊CIL(MSIL)
前言: 其实小匹夫在U3D的开发中一直对U3D的跨平台能力很好奇.到底是什么原理使得U3D可以跨平台呢?后来发现了Mono的作用,并进一步了解到了CIL的存在.所以,作为一个对Unity3D跨平台能力 ...
- fir.im Weekly - 聊聊 Google 开发者大会
中国互联网的三大错觉:索尼倒闭,诺基亚崛起,谷歌重返中国.12月8日,2016 Google 开发者大会正式发布了Google Developers 中国网站 ,包含了Android Develope ...
- 聊聊asp.net中Web Api的使用
扯淡 随着app应用的崛起,后端服务开发的也越来越多,除了很多优秀的nodejs框架之外,微软当然也会在这个方面提供更便捷的开发方式.这是微软一贯的作风,如果从开发的便捷性来说的话微软是当之无愧的老大 ...
- 没有神话,聊聊decimal的“障眼法”
0x00 前言 在上一篇文章<妥协与取舍,解构C#中的小数运算>的留言区域有很多朋友都不约而同的说道了C#中的decimal类型.事实上之前的那篇文章的立意主要在于聊聊使用二进制的计算机是 ...
- 聊聊 C 语言中的 sizeof 运算
聊聊 sizeof 运算 在这两次的课上,同学们已经学到了数组了.下面几节课,应该就会学习到指针.这个速度的确是很快的. 对于同学们来说,暂时应该也有些概念理解起来可能会比较的吃力. 先说一个概念叫内 ...
- 聊聊 Apache 开源协议
摘要 用一句话概括 Apache License 就是,你可以用这代码,但是如果开源你必须保留我写的声明:你可以改我的代码,但是如果开源你必须写清楚你改了哪些:你可以加新的协议要求,但不能与我所 公布 ...
随机推荐
- 2021-09-09:企鹅厂活动发文化衫,文化衫有很多种,企鹅们都穿文化衫。采访中,企鹅会说还有多少企鹅跟他穿一种文化衫。有些企鹅没被采访到,将这些回答放在answers数组里,返回活动中企鹅的最少数
2021-09-09:企鹅厂活动发文化衫,文化衫有很多种,企鹅们都穿文化衫.采访中,企鹅会说还有多少企鹅跟他穿一种文化衫.有些企鹅没被采访到,将这些回答放在answers数组里,返回活动中企鹅的最少数 ...
- Django4全栈进阶之路9 STATIC静态文件路径设置、MEDIA媒体文件路径设置
在 Django 4 中,可以在 settings.py 文件中设置 STATICFILES_DIRS 来指定应用程序静态文件所在的文件夹路径,设置 STATIC_ROOT 来指定收集所有应用程序静态 ...
- C# decimal double 获取一组数字 小数点后最多有几位
有一组数字,想判断一组数字中最多的有几位小数,乘以10的指定幂,转为整数,此处教大家一个高级的写法,拒接无脑for循环 decimal: decimal[] numbers = new decimal ...
- cv学习总结(11.6-11.13)
两层全连接神经网络的内容要比想象中的多很多,代码量也很多,在cs231n只用了15分钟讲解的东西我用了一周半的时间才完全的消化理解,这周终于完成了全连接神经网络博客的书写https://www.cnb ...
- 某表格常用api
这是一个神奇的网站,可作为免费的数据存储平台,已白嫖多年 通过调用接口可以方便的实现增删改查.修改www前缀为vip,还能嫖vip服务器 我常常用来写入程序的日志记录,记录/更新一些关键key 不需要 ...
- Mybatis的ResultMap对column和property的理解
Mybatis的ResultMap对column和property的理解 首先,先看看这张图,看能不能一下看明白: select元素有很多属性(这里说用的比较多的): id:命名空间唯一标识,可以被用 ...
- 在 Istio 服务网格内连接外部 MySQL 数据库
为了方便理解,以 Istio 官方提供的 Bookinfo 应用示例为例,利用 ratings 服务外部 MySQL 数据库. Bookinfo应用的架构图如下: 其中,包含四个单独的微服务: pro ...
- FPGA加速技术:如何提高系统的性能和安全性
目录 1. 引言 2. 技术原理及概念 2.1 基本概念解释 2.2 技术原理介绍 3. 实现步骤与流程 3.1 准备工作:环境配置与依赖安装 3.2 核心模块实现 3.3 集成与测试 4. 应用示例 ...
- HLS AES加密
HLS AES加密 HLS AES加密介绍 HLS AES加密是一种用于保护HLS流内容安全的加密技术.它通过将HLS媒体文件进行分段,并使用AES加密算法对每个片段进行加密,从而防止未经授权的访问和 ...
- 保护数据隐私:深入探索Golang中的SM4加密解密算法
前言 最近做的项目对安全性要求比较高,特别强调:系统不能涉及MD5.SHA1.RSA1024.DES高风险算法. 那用什么嘞?甲方:建议用国产密码算法SM4. 擅长敏捷开发(CV大法)的我,先去Git ...