2018年google推出了bert模型,这个模型的性能要远超于以前所使用的模型,总的来说就是很牛。但是训练bert模型是异常昂贵的,对于一般人来说并不需要自己单独训练bert,只需要加载预训练模型,就可以完成相应的任务。下面我将以情感分类为例,介绍使用bert的方法。这里与我们之前调用API写代码有所区别,已经有大神将bert封装成.py文件,我们只需要简单修改一下,就可以直接调用这些.py文件了。

官方文档

  1. tensorflow版:点击传送门
  2. pytorch版(注意这是一个第三方团队实现的):点击传送门
  3. 论文:点击传送门

    一切以官方论文为准,如果有什么疑问,请仔细阅读官方文档

具体实现

我这里使用的是pytorch版本。

前置需要

  1. 安装pytorch和tensorflow。
  2. 安装PyTorch pretrained bert。(pip install pytorch-pretrained-bert)
  3. 将pytorch-pretrained-BERT提供的文件,整个下载。
  4. 选择并且下载预训练模型。地址:请点击

    注意这里的model是tensorflow版本的,需要进行相应的转换才能在pytorch中使用

无论是tf版还是pytorch版本,预训练模型都需要三个文件(或者功能类似的)

  1. 预训练模型文件,里面保存的是模型参数。
  2. config文件,用来加载预训练模型。
  3. vocabulary文件,用于后续分词。

模型转换

文档里提供了convert_tf_checkpoint_to_pytorch.py 这个脚本来进行模型转换。使用方法如下:

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
$BERT_BASE_DIR/bert_model.ckpt \
$BERT_BASE_DIR/bert_config.json \
$BERT_BASE_DIR/pytorch_model.bin

修改源码

这里是需要实现情感分类。只需要用到run_classifier_dataset_utils.py和run_classifier.py这两个文件。run_classifier_dataset_utils.py是用来处理文本的输入,我们只需要添加一个类用来处理输入即可。

class MyProcessor(DataProcessor):
'''Processor for the sentiment classification data set''' def get_train_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") def get_labels(self):
"""See base class."""
return ["-1", "1"] def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples

train.tsv和dev.tsv分别表示训练集和测试集。记得要在下面的代码加上之前定义的类。

def compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mrpc":
return acc_and_f1(preds, labels)
elif task_name == "sts-b":
return pearson_and_spearman(preds, labels)
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "my":
return acc_and_f1(preds, labels)
else:
raise KeyError(task_name) processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
"sts-b": StsbProcessor,
"qqp": QqpProcessor,
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
"my": MyProcessor
} output_modes = {
"cola": "classification",
"mnli": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
"qqp": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
"my": "classification"
}

运行bert

编辑shell脚本:

#!/bin/bash
export TASK_NAME=my python run_classifier.py \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--do_lower_case \
--data_dir /home/garvey/Yuqinfenxi/ \
--bert_model /home/garvey/uncased_L-12_H-768_A-12 \
--max_seq_length 410 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /home/garvey/bertmodel

运行即可。这里要注意max_seq_length和train_batch_size这两个参数,设置过大是很容易爆掉显存的,一般来说运行bert需要11G左右的显存。

备注

max_seq_length是指词的数量而不是指字符的数量。参考代码中的注释:

The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded.

对于sequence的理解,网上很多博客都把这个翻译为句子,我个人认为是不准确的,序列是可以包含多个句子的,而不只是单独一个句子。

注意

Bert开源的代码中,只提供了train和dev数据,也就是训练集和验证集。对于评测论文标准数据集的时候,只需要把训练集和测试集送进去就可以得到结果,这一过程是没有调参的(没有验证集),都是使用默认参数。但是如果用Bert来打比赛,注意这个时候的测试集是没有标签的,这就需要在源码中加上一个处理test数据集的部分,并且通过验证集来选择参数。

补充

在大的预训练模型例如像bert-large在对小的训练集进行精细调整的时候,往往会导致性能退化:模型要么运行良好,要么根本不起作用,在我们用bert-large对一些小数据集进行微调,直接使用默认参数的话二分类的准确率只有0.5,也就是一点作用也没有,这个时候需要对学习率和迭代次数进行一个调整才会有一个正常的结果,这个问题暂时还没有得到解决。

使用bert进行情感分类的更多相关文章

  1. 使用BERT进行情感分类预测及代码实例

    文章目录 0. BERT介绍 1. BERT配置 1.1. clone BERT 代码 1.2. 数据处理 1.2.1预训练模型 1.2.2数据集 训练集 测试集 开发集 2. 修改代码 2.1 加入 ...

  2. Bert实战---情感分类

    1.情感分析语料预处理 使用酒店评论语料,正面评论和负面评论各5000条,用BERT参数这么大的模型, 训练会产生严重过拟合,,泛化能力差的情况, 这也是我们下面需要解决的问题; 2.sigmoid二 ...

  3. 基于Bert的文本情感分类

    详细代码已上传到github: click me Abstract:    Sentiment classification is the process of analyzing and reaso ...

  4. 关于情感分类(Sentiment Classification)的文献整理

    最近对NLP中情感分类子方向的研究有些兴趣,在此整理下个人阅读的笔记(持续更新中): 1. Thumbs up? Sentiment classification using machine lear ...

  5. kaggle——Bag of Words Meets Bags of Popcorn(IMDB电影评论情感分类实践)

    kaggle链接:https://www.kaggle.com/c/word2vec-nlp-tutorial/overview 简介:给出 50,000 IMDB movie reviews,进行0 ...

  6. NLP文本情感分类传统模型+深度学习(demo)

    文本情感分类: 文本情感分类(一):传统模型 摘自:http://spaces.ac.cn/index.php/archives/3360/ 测试句子:工信处女干事每月经过下属科室都要亲口交代24口交 ...

  7. kaggle之电影评论文本情感分类

    电影文本情感分类 Github地址 Kaggle地址 这个任务主要是对电影评论文本进行情感分类,主要分为正面评论和负面评论,所以是一个二分类问题,二分类模型我们可以选取一些常见的模型比如贝叶斯.逻辑回 ...

  8. PaddlePaddle︱开发文档中学习情感分类(CNN、LSTM、双向LSTM)、语义角色标注

    PaddlePaddle出教程啦,教程一部分写的很详细,值得学习. 一期涉及新手入门.识别数字.图像分类.词向量.情感分析.语义角色标注.机器翻译.个性化推荐. 二期会有更多的图像内容. 随便,帮国产 ...

  9. [DeeplearningAI笔记]序列模型2.9情感分类

    5.2自然语言处理 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.9 Sentiment classification 情感分类 情感分类任务简单来说是看一段文本,然后分辨这个人是否喜欢 ...

随机推荐

  1. 原生ajax解析&封装原生ajax函数

    前沿:对于此篇随笔,完是简要写了几个重要的地方,具体实现细节完在提供的源码做了笔记 <一>ajax基本要点介绍--更好的介绍ajax 1. ajax对象中new XMLHttpReques ...

  2. java 获取最近7天 最近今天的日期

    private static Date getDateAdd(int days){ SimpleDateFormat sf = new SimpleDateFormat("yyyy-MM-d ...

  3. Rendering in UE4

    Intro Thinking performance. Identify the target framerate, aim your approach on hitting that target ...

  4. CentOS7.5下SVN服务器备份与恢复

    可以先查看 svnadmin 命令的使用说明 svnadmin --help 1.完全备份和增量备份 查看 svnadmin dump 命令的使用说明 svnadmin dump --help svn ...

  5. Django REST framework —— 权限组件源码分析

    在上一篇文章中我们已经分析了认证组件源码,我们再来看看权限组件的源码,权限组件相对容易,因为只需要返回True 和False即可 代码 class ShoppingCarView(ViewSetMix ...

  6. 浏览器报400-Bad Request异常

    今天在使用ie浏览器在测试程序的时候,报这个错误,后台日志打印出来显示的是:连接一个远程主机失败 解决Invalid character found in the request target. Th ...

  7. 关于原生js的节点兼容性

    关于节点的兼容性: 1:获取元素的子节点 a: childNodes:获取元素的子节点,空文本,非空文本,注释,获取的比较全面, 如果只是想获取元素的子节点,请用(children) b:     c ...

  8. 学习Spring-Data-Jpa(十六)---@Version与@Lock

    1.问题场景 以用户账户为例,如果允许同时对某个用户的账户进行修改的话,会导致某些修改被覆盖,使最后的结果不正确. 如:1.1.张三的账户中有100元. 1.2.张三的账户消费了50元. 1.3.张三 ...

  9. Mscordacwks.dll/SOS.dll 调试归档

    找到个好东西 为什么要归档 此存档提供帮助,并可能提供对以下问题的答案 是否可以使WinDBG在符号存储中找到mscordacwks.dll?, Windbg需要不同版本的mscordacwks.dl ...

  10. windows系统的快速失败机制---fastfail

    windows系统的快速失败机制---fastfail,是一种用于“快速失败”请求的机制 — 一种潜在破坏进程请求立即终止进程的方法. 无法使用常规异常处理设施处理可能已破坏程序状态和堆栈至无法恢复的 ...