下载HuggingFace 模型

首先打开网址:https://huggingface.co/models 这个网址是huggingface/transformers支持的所有模型,目前大约一千多个。搜索gpt2(其他的模型类似,比如bert-base-uncased等),并点击进去。

进入之后,可以看到gpt2模型的说明页,点击页面中的list all files in model,可以看到模型的所有文件。

通常需要把保存的是三个文件以及一些额外的文件

  • 配置文件 -- config.json
  • 词典文件 -- vocab.json
  • 预训练模型文件

    pytorch -- pytorch_model.bin文件

    tensorflow 2 -- tf_model.h5文件

额外的文件,指的是merges.txtspecial_tokens_map.jsonadded_tokens.jsontokenizer_config.jsonsentencepiece.bpe.model等,这几类是tokenizer需要使用的文件,如果出现的话,也需要保存下来。没有的话,就不必在意。如果不确定哪些需要下,哪些不需要的话,可以把图1中类似的文件全部下载下来。

看下这几个文件都是什么:

  • config.json配置文件



    包含了模型的类型、激活函数等配置信息

  • vocab.json 词典文件

  • merges.txt

使用HuggingFace模型

将上述下载的模型存储在本地:

加载本地HuggingFace模型

  1. 导入依赖
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

导入PyTorch框架和HuggingFace Transformers库的GPT-2组件

  1. 初始化分词器
tokenizer = GPT2Tokenizer.from_pretrained("../../Models/gpt2/")
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text)
print(indexed_tokens) # [8241, 373, 5395, 367, 19069, 5633]
# 转换为torch Tensor
token_tensor = torch.tensor([indexed_tokens])
print(token_tensor) # tensor([[ 8241, 373, 5395, 367, 19069, 5633]])

tokenizer.encode(text)执行流程如下:

分词器处理:

首先将文本分词为子词(subwords),如:

"Who was Jim Henson ?" → ['Who', 'Ġwas', 'ĠJim', 'ĠHen', 'son', '?']

ID转换:

然后将每个子词转换为对应的整数ID(来自vcab.json),如:

['Who', 'Ġwas', 'ĠJim', 'ĠHen', 'son', '?'] -> [8241, 373, 5395, 367, 19069, 5633]

可以查看vcab.json文件:

返回的是 token ID 列表(整数列表),而非词向量

  1. 加载预训练模型并预测
# 加载预训练模型
model = GPT2LMHeadModel.from_pretrained("../../Models/gpt2/")
# print(model) model.eval() with torch.no_grad():
outputs = model(token_tensor)
predictions = outputs[0] # 我们需要预测下一个单词,所以是使用predictions第一个batch,最后一个词的logits去计算
# predicted_index = 582,通过计算最大得分的索引得到的
predicted_index = torch.argmax(predictions[0, -1, :]).item()
# 反向解码为我们需要的文本
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
# 解码后的文本:'Who was Jim Henson? Jim Henson was a man'
# 成功预测出单词 'man'
print(predicted_text)

输出结果:

使用HuggingFace 模型并预测的更多相关文章

  1. Python之逻辑回归模型来预测

    建立一个逻辑回归模型来预测一个学生是否被录取. import numpy as np import pandas as pd import matplotlib.pyplot as plt impor ...

  2. 迁移学习——使用Tensorflow和VGG16预训模型进行预测

    使用Tensorflow和VGG16预训模型进行预测 from:https://zhuanlan.zhihu.com/p/28997549   fast.ai的入门教程中使用了kaggle: dogs ...

  3. matlab(5) : 求得θ值后用模型来预测 / 计算模型的精度

    求得θ值后用模型来预测 / 计算模型的精度  ex2.m部分程序 %% ============== Part 4: Predict and Accuracies ==============% Af ...

  4. R语言利用ROCR评测模型的预测能力

    R语言利用ROCR评测模型的预测能力 说明 受试者工作特征曲线(ROC),这是一种常用的二元分类系统性能展示图形,在曲线上分别标注了不同切点的真正率与假正率.我们通常会基于ROC曲线计算处于曲线下方的 ...

  5. Keras 构建DNN 对用户名检测判断是否为非法用户名(从数据预处理到模型在线预测)

    一.  数据集的准备与预处理 1 . 收集dataset (大量用户名--包含正常用户名与非法用户名) 包含两个txt文件  legal_name.txt  ilegal_name.txt. 如下图所 ...

  6. 用交叉验证改善模型的预测表现-着重k重交叉验证

    机器学习技术在应用之前使用“训练+检验”的模式(通常被称作”交叉验证“). 预测模型为何无法保持稳定? 让我们通过以下几幅图来理解这个问题: 此处我们试图找到尺寸(size)和价格(price)的关系 ...

  7. 学习ML.NET(2): 使用模型进行预测

    训练模型 在上一篇文章中,我们已经通过LearningPipeline训练好了一个“鸢尾花瓣预测”模型, var model = pipeline.Train<IrisData, IrisPre ...

  8. Machine Learning for hackers读书笔记(五)回归模型:预测网页访问量

    线性回归函数 model<-lm(Weight~Height,data=?) coef(model):得到回归直线的截距 predict(model):预测 residuals(model):残 ...

  9. SVM模型进行分类预测时的参数调整技巧

    一:如何判断调参范围是否合理 正常来说,当我们参数在合理范围时,模型在训练集和测试集的准确率都比较高:当模型在训练集上准确率比较高,而测试集上的准确率比较低时,模型处于过拟合状态:当模型训练集和测试集 ...

  10. 深度学习利器:TensorFlow在智能终端中的应用——智能边缘计算,云端生成模型给移动端下载,然后用该模型进行预测

    前言 深度学习在图像处理.语音识别.自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算.如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有 ...

随机推荐

  1. Java获得当前日期是星期几

    第一种方法:   /**   * 获取当前日期是星期几<br>   *   * @param date   * @return 当前日期是星期几   */   public String ...

  2. Mybatis中的 switch

    我这遇到个问题,如果 type字段为null则查询type is null,否则查对应的值 询问 AI 得知,可以用choose-when-otherwise <select> selec ...

  3. Jenkins - [02] 安装部署

    题记部分 一.Jenkins是什么   Jenkins,原名Hudson,2011年改为现在的名字,它是一个开源的实现持续集成的软件工具. 官网:https://www.jenkins.io/ 官网: ...

  4. Go红队开发—并发编程

    目录 并发编程 go协程 chan通道 无缓冲通道 有缓冲通道 创建⽆缓冲和缓冲通道 等协程 sync.WaitGroup同步 Runtime包 Gosched() Goexit() 区别 同步变量 ...

  5. Twain Capabilities属性

    Asynchronous Device Events 异步设备事件 CAP_DEVICEEVENT MSG_SET选择应用程序希望Twain源报告的事件; MSG_RESET返回Twain源的首选设置 ...

  6. 基于Microsoft.Extensions.AI核心库实现RAG应用

    大家好,我是Edison. 之前我们了解 Microsoft.Extensions.AI 和 Microsoft.Extensions.VectorData 两个重要的AI应用核心库.基于对他们的了解 ...

  7. MySQL索引最左原则:从原理到实战的深度解析

    MySQL索引最左原则:从原理到实战的深度解析 一.什么是索引最左原则? 索引最左原则是MySQL复合索引使用的核心规则,简单来说: "当使用复合索引(多列索引)时,查询条件必须从索引的最左 ...

  8. mysql [Err] 1067 - Invalid default value for

    出错原因 mysql5.7版本引起的默认值不兼容的问题,同样的问题在mysql8.0可能也会出现. 出问题的值有: NO_ZERO_IN_DATE 在严格模式下,不允许日期和月份为零. NO_ZERO ...

  9. pandas(进阶操作)-- 政治献金项目数据分析

    博客地址:https://www.cnblogs.com/zylyehuo/ 开发环境 anaconda 集成环境:集成好了数据分析和机器学习中所需要的全部环境 安装目录不可以有中文和特殊符号 jup ...

  10. 【教程】Windows10系统激活

    Windows10系统激活 一.找一个激活码 到百度搜索,筛选发表日期在最近一个月或者一周之内的 二.以管理员身份打开cmd 按Win+R键,输入cmd打开命令行窗口 按Ctrl+Shift+Esc键 ...