# 参考 https://blog.csdn.net/luoyexuge/article/details/84939755 小做改动

需要:

  github上下载bert的代码:https://github.com/google-research/bert

  下载google训练好的中文语料模型:https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

使用:

  使用bert,其实是使用几个checkpoint(ckpt)文件。上面下载的zip是google训练好的bert,我们可以在那个zip内的ckpt文件基础上继续训练,获得更贴近具体任务的ckpt文件。

如果是直接使用训练好的ckpt文件(就是bert模型),只需如下代码,定义model,获得model的值

from bert import modeling    
# 使用数据加载BertModel,获取对应的字embedding
model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings
)
# 获取对应的embedding 输入数据[batch_size, seq_length, embedding_size]
embedding = model.get_sequence_output()

这里的bert_config 是之前定义的bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file);输入是input_ids, input_mask, segment_ids三个向量;还有两个设置is_training(False), use_one_hot_embedding(False),这样的设置还有很多,这里只列举这两个。。

关于FLAGS,需要提到TensorFlow的flags,相当于配置运行变量,设置如下:

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS # 预训练的中文model路径和项目路径
bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'
root_path = '/home/xiangbo_wang/xiangbo/NER/BERT-BiLSTM-CRF-NER' # 设置bert_config_file
flags.DEFINE_string(
"bert_config_file", os.path.join(bert_path, 'bert_config.json'),
"The config json file corresponding to the pre-trained BERT model."
)

关于输入的三个向量,具体内容可以参照之前的博客https://www.cnblogs.com/rucwxb/p/10277217.html

input_ids, segment_ids 分别是 token embedding, segment embedding

position embedding会自动生成

input_mask 是input中需要mask的位置,本来是随机取一部分,这里的做法是把全部输入位置都mask住。

获得输入的这三个向量的方式如下:

# 获得三个向量的函数
def inputs(vectors,maxlen=10):
length=len(vectors)
if length>=maxlen:
return vectors[0:maxlen],[1]*maxlen,[0]*maxlen
else:
input=vectors+[0]*(maxlen-length)
mask=[1]*length+[0]*(maxlen-length)
segment=[0]*maxlen
return input,mask,segment # 测试的句子
text = request.args.get('text')
vectors = [di.get("[CLS]")] + [di.get(i) if i in di else di.get("[UNK]") for i in list(text)] + [di.get("[SEP]")] # 转成1*maxlen的向量
input, mask, segment = inputs(vectors)
input_ids = np.reshape(np.array(input), [1, -1])
input_mask = np.reshape(np.array(mask), [1, -1])
segment_ids = np.reshape(np.array(segment), [1, -1])

最后是将变量输入模型获得最终的bert向量:

# 定义输入向量形状
input_ids_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="input_ids_p")
input_mask_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="input_mask_p")
segment_ids_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="segment_ids_p") model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids_p,
input_mask=input_mask_p,
token_type_ids=segment_ids_p,
use_one_hot_embeddings=use_one_hot_embeddings
) # 载入预训练模型
restore_saver = tf.train.Saver()
restore_saver.restore(sess, init_checkpoint) # 一个[batch_size, seq_length, embedding_size]大小的向量
embedding = tf.squeeze(model.get_sequence_output())
# 运行结果
ret=sess.run(embedding,feed_dict={"input_ids_p:0":input_ids,"input_mask_p:0":input_mask,"segment_ids_p:0":segment_ids})

完整可运行代码如下:

import tensorflow as tf
from bert import modeling
import collections
import os
import numpy as np
import json flags = tf.flags
FLAGS = flags.FLAGS
bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/' flags.DEFINE_string(
'bert_config_file', os.path.join(bert_path, 'bert_config.json'),
'config json file corresponding to the pre-trained BERT model.'
)
flags.DEFINE_string(
'bert_vocab_file', os.path.join(bert_path,'vocab.txt'),
'the config vocab file',
)
flags.DEFINE_string(
'init_checkpoint', os.path.join(bert_path,'bert_model.ckpt'),
'from a pre-trained BERT get an initial checkpoint',
)
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") def convert2Uni(text):
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode('utf-8','ignore')
else:
print(type(text))
print('####################wrong################') def load_vocab(vocab_file):
vocab = collections.OrderedDict()
vocab.setdefault('blank', 2)
index = 0
with open(vocab_file) as reader:
# with tf.gfile.GFile(vocab_file, 'r') as reader:
while True:
tmp = reader.readline()
if not tmp:
break
token = convert2Uni(tmp)
token = token.strip()
vocab[token] = index
index+=1
return vocab def inputs(vectors, maxlen = 50):
length = len(vectors)
if length > maxlen:
return vectors[0:maxlen], [1]*maxlen, [0]*maxlen
else:
input = vectors+[0]*(maxlen-length)
mask = [1]*length + [0]*(maxlen-length)
segment = [0]*maxlen
return input, mask, segment def response_request(text):
vectors = [dictionary.get('[CLS]')] + [dictionary.get(i) if i in dictionary else dictionary.get('[UNK]') for i in list(text)] + [dictionary.get('[SEP]')]
input, mask, segment = inputs(vectors) input_ids = np.reshape(np.array(input), [1, -1])
input_mask = np.reshape(np.array(mask), [1, -1])
segment_ids = np.reshape(np.array(segment), [1, -1]) embedding = tf.squeeze(model.get_sequence_output())
rst = sess.run(embedding, feed_dict={'input_ids_p:0':input_ids, 'input_mask_p:0':input_mask, 'segment_ids_p:0':segment_ids}) return json.dumps(rst.tolist(), ensure_ascii=False) dictionary = load_vocab(FLAGS.bert_vocab_file)
init_checkpoint = FLAGS.init_checkpoint sess = tf.Session()
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) input_ids_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='input_ids_p')
input_mask_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='input_mask_p')
segment_ids_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='segment_ids_p') model = modeling.BertModel(
config = bert_config,
is_training = FLAGS.use_tpu,
input_ids = input_ids_p,
input_mask = input_mask_p,
token_type_ids = segment_ids_p,
use_one_hot_embeddings = FLAGS.use_tpu,
)
print('####################################')
restore_saver = tf.train.Saver()
restore_saver.restore(sess, init_checkpoint) print(response_request('我叫水奈樾。'))

【NLP】使用bert的更多相关文章

  1. NLP新秀 - Bert

    目录 什么是Bert Bert能干什么? Bert和TensorFlow的关系 BERT的原理 Bert相关工具和服务 Bert的局限性和对应的解决方案 沉舟侧畔千帆过, 病树前头万木春. 今天介绍的 ...

  2. 最强NLP模型-BERT

    简介: BERT,全称Bidirectional Encoder Representations from Transformers,是一个预训练的语言模型,可以通过它得到文本表示,然后用于下游任务, ...

  3. NLP采用Bert进行简单文本情感分类

    参照当Bert遇上Kerashttps://spaces.ac.cn/archives/6736此示例准确率达到95.5%+ https://github.com/CyberZHG/keras-ber ...

  4. 语言模型预训练方法(ELMo、GPT和BERT)——自然语言处理(NLP)

    1. 引言 在介绍论文之前,我将先简单介绍一些相关背景知识.首先是语言模型(Language Model),语言模型简单来说就是一串词序列的概率分布.具体来说,语言模型的作用是为一个长度为m的文本确定 ...

  5. 自然语言处理中的语言模型预训练方法(ELMo、GPT和BERT)

    自然语言处理中的语言模型预训练方法(ELMo.GPT和BERT) 最近,在自然语言处理(NLP)领域中,使用语言模型预训练方法在多项NLP任务上都获得了不错的提升,广泛受到了各界的关注.就此,我将最近 ...

  6. Paper: 《Bert》

    Bert: Bidirectional Encoder Representations from Transformers. 主要创新点:Masked LM 和 Next sentence predi ...

  7. BERT的几个可能的应用

      BERT是谷歌公司于2018年11月发布的一款新模型,它一种预训练语言表示的方法,在大量文本语料(维基百科)上训练了一个通用的"语言理解"模型,然后用这个模型去执行想做的NLP ...

  8. 基于Bert的文本情感分类

    详细代码已上传到github: click me Abstract:    Sentiment classification is the process of analyzing and reaso ...

  9. 学习AI之NLP后对预训练语言模型——心得体会总结

    一.学习NLP背景介绍:      从2019年4月份开始跟着华为云ModelArts实战营同学们一起进行了6期关于图像深度学习的学习,初步了解了关于图像标注.图像分类.物体检测,图像都目标物体检测等 ...

  10. 知识图谱辅助金融领域NLP任务

    从人工智能学科诞生之初起,自然语言处理(NLP)就是人工智能核心的研究问题之一.NLP的重要性是毋庸置疑的,它能够实现以自然语言交流为特征的高级人机交互,使机器能“阅读”所有以文字形式记录的人类知识, ...

随机推荐

  1. 【转】使用URL SCHEME启动天猫客户端并跳转到某个商品页面的方法

    在项目中遇到了这样一个需求:让用户在手机应用中,点击一个天猫的商品链接(知道商品在PC浏览器里的地址),直接启动天猫的客户端并显示这个商品.以前曾经实现过类似的功能,不过那次是淘宝的商品,天猫和淘宝的 ...

  2. 最长公共子序列&最长公共子串

    首先区别最长公共子串和最长公共子序列  LCS(计算机科学算法:最长公共子序列)_百度百科 最长公共子串,这个子串要求在原字符串中是连续的.而最长公共子序列则并不要求连续. 最长公共子序列: http ...

  3. java 装饰者模式

    一.概念 我们在使用以前既定的类或者使用别人使用的类的时候,如果该类的方法,不满足你的需求的时候,需要你进行额外附加功能的时候,往往我们想到的方法是继承实现, 但是继承会导致类的越来越庞大,有什么好的 ...

  4. VMware虚拟机更换根用户( su: Authentication failure问题)

    su命令不能切换root,提示su: Authentication failure,只要你sudo passwd root过一次之后,下次再su的时候只要输入密码就可以成功登录了.

  5. Xcode下载模拟器太慢?

    在Xcode里下载模拟器,速度实在是太慢了.点击下载,卡住十几分钟才开始下载,并且龟速进行. 解决方案:获取模拟器下载地址,自己选择下载器进行下载. 找到下载链接 打开 Console.app(苹果电 ...

  6. 使用Android studio搭建Android环境

    最近安装Android studio遇到了很多问题,现在总结一下安装过程 因为我的电脑是AMD的cpu,好像不能使用虚拟机(具体原因不知道),所以我使用 软件+手机  去开发APP 先说一下使用And ...

  7. JavaWeb基础—VerifyCode源码

    package com.jiangbei.verifycodeutils; import java.awt.BasicStroke; import java.awt.Color; import jav ...

  8. JavaEE笔记(一)

    Hibernate Hibernate是一个开放源代码的对象关系映射框架,它对JDBC进行了非常轻量级的对象封装,它将POJO与数据库表建立映射关系,是一个全自动的orm框架,hibernate可以自 ...

  9. Noip前的大抱佛脚----动态规划

    目录 动态规划 序列DP 背包问题 状态压缩以及拆分数 期望概率DP 马尔可夫过程 一类生成树计数问题 平方计数 动态规划 序列DP 有些问题: 求长度为\(l\)的上升子序列个数 形如一个值域的前缀 ...

  10. mfc 引用

    一.引用的概念 引用(reference)是另一标识符的别名,可以说是C++的一种新的变量类型,是对C的重要扩充.当建立引用时,程序用另一个变量或对象(目标)的名字初始化它(即它代表了标识符的左值), ...