使用HuggingFace 模型并预测
下载HuggingFace 模型
首先打开网址:https://huggingface.co/models 这个网址是huggingface/transformers支持的所有模型,目前大约一千多个。搜索gpt2(其他的模型类似,比如bert-base-uncased等),并点击进去。
进入之后,可以看到gpt2模型的说明页,点击页面中的list all files in model,可以看到模型的所有文件。

通常需要把保存的是三个文件以及一些额外的文件
- 配置文件 -- config.json
- 词典文件 -- vocab.json
- 预训练模型文件
pytorch -- pytorch_model.bin文件
tensorflow 2 -- tf_model.h5文件
额外的文件,指的是merges.txt、special_tokens_map.json、added_tokens.json、tokenizer_config.json、sentencepiece.bpe.model等,这几类是tokenizer需要使用的文件,如果出现的话,也需要保存下来。没有的话,就不必在意。如果不确定哪些需要下,哪些不需要的话,可以把图1中类似的文件全部下载下来。

看下这几个文件都是什么:
config.json配置文件

包含了模型的类型、激活函数等配置信息vocab.json 词典文件

merges.txt

使用HuggingFace模型
将上述下载的模型存储在本地:

加载本地HuggingFace模型
- 导入依赖
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
导入PyTorch框架和HuggingFace Transformers库的GPT-2组件
- 初始化分词器
tokenizer = GPT2Tokenizer.from_pretrained("../../Models/gpt2/")
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text)
print(indexed_tokens) # [8241, 373, 5395, 367, 19069, 5633]
# 转换为torch Tensor
token_tensor = torch.tensor([indexed_tokens])
print(token_tensor) # tensor([[ 8241, 373, 5395, 367, 19069, 5633]])
tokenizer.encode(text)执行流程如下:
分词器处理:
首先将文本分词为子词(subwords),如:
"Who was Jim Henson ?" → ['Who', 'Ġwas', 'ĠJim', 'ĠHen', 'son', '?']
ID转换:
然后将每个子词转换为对应的整数ID(来自vcab.json),如:
['Who', 'Ġwas', 'ĠJim', 'ĠHen', 'son', '?'] -> [8241, 373, 5395, 367, 19069, 5633]
可以查看vcab.json文件:

返回的是 token ID 列表(整数列表),而非词向量
- 加载预训练模型并预测
# 加载预训练模型
model = GPT2LMHeadModel.from_pretrained("../../Models/gpt2/")
# print(model)
model.eval()
with torch.no_grad():
outputs = model(token_tensor)
predictions = outputs[0]
# 我们需要预测下一个单词,所以是使用predictions第一个batch,最后一个词的logits去计算
# predicted_index = 582,通过计算最大得分的索引得到的
predicted_index = torch.argmax(predictions[0, -1, :]).item()
# 反向解码为我们需要的文本
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
# 解码后的文本:'Who was Jim Henson? Jim Henson was a man'
# 成功预测出单词 'man'
print(predicted_text)
输出结果:

使用HuggingFace 模型并预测的更多相关文章
- Python之逻辑回归模型来预测
建立一个逻辑回归模型来预测一个学生是否被录取. import numpy as np import pandas as pd import matplotlib.pyplot as plt impor ...
- 迁移学习——使用Tensorflow和VGG16预训模型进行预测
使用Tensorflow和VGG16预训模型进行预测 from:https://zhuanlan.zhihu.com/p/28997549 fast.ai的入门教程中使用了kaggle: dogs ...
- matlab(5) : 求得θ值后用模型来预测 / 计算模型的精度
求得θ值后用模型来预测 / 计算模型的精度 ex2.m部分程序 %% ============== Part 4: Predict and Accuracies ==============% Af ...
- R语言利用ROCR评测模型的预测能力
R语言利用ROCR评测模型的预测能力 说明 受试者工作特征曲线(ROC),这是一种常用的二元分类系统性能展示图形,在曲线上分别标注了不同切点的真正率与假正率.我们通常会基于ROC曲线计算处于曲线下方的 ...
- Keras 构建DNN 对用户名检测判断是否为非法用户名(从数据预处理到模型在线预测)
一. 数据集的准备与预处理 1 . 收集dataset (大量用户名--包含正常用户名与非法用户名) 包含两个txt文件 legal_name.txt ilegal_name.txt. 如下图所 ...
- 用交叉验证改善模型的预测表现-着重k重交叉验证
机器学习技术在应用之前使用“训练+检验”的模式(通常被称作”交叉验证“). 预测模型为何无法保持稳定? 让我们通过以下几幅图来理解这个问题: 此处我们试图找到尺寸(size)和价格(price)的关系 ...
- 学习ML.NET(2): 使用模型进行预测
训练模型 在上一篇文章中,我们已经通过LearningPipeline训练好了一个“鸢尾花瓣预测”模型, var model = pipeline.Train<IrisData, IrisPre ...
- Machine Learning for hackers读书笔记(五)回归模型:预测网页访问量
线性回归函数 model<-lm(Weight~Height,data=?) coef(model):得到回归直线的截距 predict(model):预测 residuals(model):残 ...
- SVM模型进行分类预测时的参数调整技巧
一:如何判断调参范围是否合理 正常来说,当我们参数在合理范围时,模型在训练集和测试集的准确率都比较高:当模型在训练集上准确率比较高,而测试集上的准确率比较低时,模型处于过拟合状态:当模型训练集和测试集 ...
- 深度学习利器:TensorFlow在智能终端中的应用——智能边缘计算,云端生成模型给移动端下载,然后用该模型进行预测
前言 深度学习在图像处理.语音识别.自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算.如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有 ...
随机推荐
- Mybatis之Select Count(*)的获取 返回int 的值
本文将介绍,SSM中mybatis 框架如何获取Select Count(*)返回int 的值.1. Service 代码: public boolean queryByunitclass(Strin ...
- VSCode ESLint规则警告屏蔽方法
举例:要屏蔽"Missing trailing comma"或"comma-dangle"警告,你可以使用ESLint的配置选项来设置规则.下面是一些方法,你可 ...
- 【Azure Storage Account】利用App Service作为反向代理, 并使用.NET Storage Account SDK实现上传/下载操作
问题描述 在使用Azure上的存储服务 Storage Account 的时候,有时需要代替 它原本提供的域名进行访问,比如默认的域名为:mystorageaccount.blob.core.chin ...
- 【ABAQUS脚本】后处理快速出图
效果图: # -*- coding: utf-8 -*- # Do not delete the following import lines from abaqus import * from ab ...
- go 定时任务库 cron
简介 在Linux中,Cron是计划任务管理系统,通过crontab命令使任务在约定的时间执行已经计划好的工作,例如定时备份系统数据.周期性清理缓存.定时重启服务等. 本文介绍的cron库是一个用于管 ...
- IvorySQL 增量备份与合并增量备份功能解析
1. 概述 IvorySQL v4 引入了块级增量备份和增量备份合并功能,旨在优化数据库备份与恢复流程.通过 pg_basebackup 工具支持增量备份,显著降低了存储需求和备份时间.同时,pg_c ...
- SQL INSERT批量插入方式
1.常规INSERT写法 INSERT INTO ... VALUES (...); INSERT INTO 表名( `字段1`, `字段2`) VALUES ('字段1的值', '字段2的值') ...
- 大模型提示词(Prompt)模板推荐
只有提示词写得好,与大模型的互动才能更高效.提示词不仅仅是与AI对话的起点,更是驱动模型产生高质量输出的关键因素.本文将介绍大模型提示词的概念.意义,并分享一些实用的提示词模板,帮助AI玩家更好地利用 ...
- Centos系统云主机中nvme盘不可用解决方法
本文分享自天翼云开发者社区<Centos系统云主机中nvme盘不可用解决方法>,作者:P****n 问题描述 Linux系统的云主机使用NVMe盘后,出现非预期的慢IO读写,导致系统或者应 ...
- X86-64位简易系统开发 - 从BIOS阶段开始
最近回顾之前写的代码的时候, 发现了以前本科时还开发过一个64位的操作系统, 不过最终也只是开发到进程切换部分 这是一个涉及到汇编和C语言的一个偏底层偏硬核的项目, 而且为了能够学到更多东西, 使用的 ...