Code Monkey home page Code Monkey logo

maml-nlp's Introduction

中文小样本事件抽取

这份文档主要分两个部分,第一部分是代码说明,主要是对代码实现的相关细节进行了具体介绍。第二部分是代码运行指南,主要是介绍如何运行模型。

一、代码说明

1. 文件结构

fewevent  
    |--models/    
    |    |--proto.py # ProtoNN模型实现  
    |--util/
    |    |--dataloder.py # 读取数据  
    |    |--framework.py # 训练测试  
    |    |--sentence_encoder.py #对数据用bert进行embedding
    |--pretrain/ #存放预训练模型的文件夹
    |--checkpoint/ #保存模型的文件,训练的时候会将模型保存到这个文件中,测试时调用
    |--train_demo.py # 代码入口  
    |--readme.md   
    |--requirements.txt  

2. 命令行参数详解

  • train / val / test:设定训练集,验证集以及测试集,默认值是上面提到的数据集。
  • N: 每个任务的样本类别。
  • K: 每个类别的支持集样本数量。
  • Q: 每个类别的query集的样本数量。
  • Q_test: 测试或者验证的时候每个类别的query集的样本数量。
  • batchsize:训练时,读取任务的数量,默认值是4
  • num_class:集成测试中,我们关注每个类的准确性,所以在计算单个类的准确性时,希望找到是这个类准确性最高的类别组合,num_class指的就是针对单个类要对比的类别组合数量
  • train_iter:总的训练迭代次数,默认值是10000,可根据样本大小自行修改
  • val_iter:每次验证的迭代次数,默认值是1000,可根据样本大小自行修改
  • val_step:每val_step个迭代次数验证一次,默认值是1000,也就是训练迭代1000次在验证集上进行一次验证
  • bert_type: 使用的bert类别,默认值是 fin,即FinBERT。ch代表谷歌的中文bert,en代表谷歌的英文bert。
  • dis:原型网络计算类中心点和测试样本所用的距离度量,ou代表欧氏距离,cos代表预先距离,dot代表点乘,默认是欧式距离,目前表现最好的也是欧式距离
  • max_length / lr / dropout/grad_iter/hidden_size:模型训练的一些超参可以自行调整
  • use_sgd_for_bert:模型使用的优化器,默认是sgd,设为false话会使用admw
  • only_test:仅在调用保存模型测试的时候使用
  • load_ckpt:调用模型的名字,--only_test --load_ckpt {CHECKPOINT_PATH}用于模型测试,训练时不必加这两个参数
  • loss:设定损失函数,margin_loss/cross表示margin loss或者交叉熵损失,margin loss用来测试未知类别,默认为cross,margin loss目前结果不是很理想还在优化中
  • M:使用margin_loss用来表示设定的margin大小,默认值为-10
  • lamda:使用margin_loss的缩放系数,主要用于测试的时候对margin放宽限制,默认值为0.1
  • unuse_label:不使用标签信息,如果在命令行加这个就表示不适用prompt

3. 模型架构

模型架构 ProtoNN的模型如上图所示,将支持集和询问集分别送到BERT中,得到相应的表示。然后针对每一类的支持集表示做一个平均(当然也可以采用其他的策略)作为类的原型,最后计算询问集和每一类原型之间的距离,选取距离最小的作为目标类别。

对比学习

为了提高鲁棒性,我们在学习过程中加入了对比学习,用来优化Bert表示。对比学习最重要的就是构造正负样本,这里借鉴了SimCSE。SimCSE 引入dropout给输入加噪声,假设加噪后的输入仍与原始输入在语义空间距离相近。其正负例的构造方式如下:

  • 正例:给定输入,用预训练语言模型编码两次得到的两个向量对作为正例对。
  • 负例:随机采样一个另一个类别输入作为的负例。

对比损失: $\ell_{i}=-\log \frac{e^{\operatorname{sim}\left(\mathbf{h}{i}^{z{i}}, \mathbf{h}{i}^{z{i}^{\prime}}\right) / \tau}}{\sum_{j=1}^{N} e^{\operatorname{sim}\left(\mathbf{h}{i}^{z{i}}, \mathbf{h}_{j}^{\prime}\right) / \tau}}$

通过SimCSE来优化句子表示,从而增强少样本技术的鲁棒性。

集成测试

模型在最后测试数据集上的测试与训练不同,加入了集成测试。

普通原型网络每次采样N-way-K-shot的任务然后将这N类剩余的数据作为询问集(query data)进行测试,记录每个类别的准确性。单个类别可能会出现在不同的任务中(N-way-K-shot task),最后对于所有的类别取一个平均,作为该类别的准确率。但是这样有一个问题,在实际衡量中我们希望得到的并不是准确性指标而是希望对于无标签数据能够打个标签,那选用哪个出现该类的任务是一个问题。

针对实际应用以及鲁棒性,我们在测试环节加入了集成测试。我们最后关心的是单个类别的准确性,所以在对单个类别进行测试时,我们采样出现该类别的任务,在每个任务中会对该类别的所有数据进行分类,最后综合所有任务,我们进行一个投票,选择票数多的作为每个样本的类标签,这样通过投票得到了该类所有样本的最终标签。

举个实际的例子,比如测试类别有10种,我们采用5-way-5-shot的设定,对于0号类别,(0,1,2,3,4),(0,2,4,5,6),(0,5,3,6,9)……可能会有126个任务出现它,我们随机取10个任务,来进行测试,假设0号类别有90条数据,我们选取的这10个任务,会分别对90条数据打标签,那每个数据可能会有10个标签,我们对此进行投票,选择票数多的作为最终标签。

经过实验,集成测试确实增加了不同类别的准确性,一定程度上提高了在不同样本的鲁棒性。防止该类别在某些任务里表现较差从而影响最终的准确性。

prompt

为了提高模型对个别类别的鲁棒性,我们引入了目前大火的prompt。来增强模型对类别的先验知识。具体做法如下:

  1. 对于support集合,我们知道单个样本的标签,所以将样本和对应的标签拼在一起送到bert中得到句子的表示,然后根据句子表示计算得到类的原型。
  2. 对于query集合,因为我们不知道单个样本的label,但是在N way K shot的设置下,我们知道有N个候选项,那我们可以将单个样本分别与这N个label进行拼接送到bert中得到相应的表示,然后再分别与对应的类原型计算距离,选择距离近的作为目标类别。

针对query集的处理,为了方便理解,简单举个例子, 现在有5类:行政责令,股价创新高,解除冻结,欠息,入股。模型已经分别计算得到这5类的原型。来一个句子,<这是一个测试句子>,模型分别与这5类的标签进行拼接得到:

  • <这是一个测试句子>【SEP】这是一条行政责令的新闻,
  • <这是一个测试句子>【SEP】这是一条股价创新高的新闻,
  • <这是一个测试句子>【SEP】这是一条解除冻结的新闻,
  • <这是一个测试句子>【SEP】这是一条欠息的新闻
  • <这是一个测试句子>【SEP】这是一条入股的新闻。

然后,将这5个句子送到bert中得到相应句子表示。然后与相应的原型计算距离,比如 <这是一个测试句子>【SEP】这是一条行政责令的新闻. 就与 行政责令 的原型计算距离,<这是一个测试句子>【SEP】这是一条股价创新高的新闻 就与 股价创新高 的原型计算距离……选择距离小的作为目标类别。

未知类别

在实际应用场景中,进来一个样本,很有可能不是N-way中的任意一个,所以模型应当具有识别未知类别的能力。原来的模型采用的欧氏距离+交叉熵损失,现在为了识别未知类,我们采用margin loss,设定一个距离M,使正样本的距离小于M,负样本的距离大于M,所以在测试时如果一个样本与N个原型的距离都大于M,模型就认为这个样本是未知类别。不过由于训练我们并不可能让所有样本都完美拟合我们的margin M,所以在测试的时候可以适当放宽对margin的限制。 引入margin loss并不改变我们的训练过程,在训练时我们并不需要引入未知类别,只需要将样本与相对应原型的距离拉近到M以内,与其他原型距离拉远到M以外。在测试时引入未知类别的数据,在N-way-K-shot的测试中,如果与所有原型的距离都在M以上,那这个样本就会被划分为未知样本。

4. 模型训练验证以及测试的相关细节

模型训练和验证过程比较相似,默认迭代train_iter次,每次迭代采样batchsize个N-way-K-shot的任务,然后再Q个测试集上进行训练。因为迭代次数足够的多,相当于可以取遍所有的数据集。每次迭代会计算相应的准确率(accuarcy),精准率(precision),召回率(recall),F1分数。 比如当train_iterval_iterval_step都会取默认值是,就是总的训练迭代10000次,其中每1000次在验证集上测试一次,保存验证集上性能最优的模型,在验证集上进行测试是,会迭代1000次,将这1000次的准确率,精准率等取平均作为在验证集上的指标。

二、代码运行指南

1. 数据集

这里我们提供了两份:
一份是华泰金融数据集:few_shot_train.json,few_shot_dev.json
一份是英文的FewEvent数据集:Few_Shot_ED.json,自行按照80:10:10来划分成训练集,验证集和测试集

华泰数据集:

训练集:few_shot_train.json 验证集few_shot_dev.json
测试集few_shot_test.json

2. 预训练模型

我们的代码是用BERT编码的,中文的bert采用的是谷歌bert-chinese-base,英文的bert采用的bert-uncase-base。其中在华泰数据集上我们使用了专门的金融领域的FinBERT,点击 here下载,解压后放到pretrain文件夹下。

3. 运行

训练:

python train_demo.py 

测试:

python train_demo.py --only_test --load_ckpt {CHECKPOINT_PATH} {OTHER_ARGS}

如果想要使用多卡运行的话:

CUDA_VISIBLE_DEVICES=$gpu_id python train_demo.py $argu_list

maml-nlp's People

Contributors

gaoyi-byte avatar

Stargazers

 avatar

Watchers

 avatar

maml-nlp's Issues

小样本命名实体识别

作者您好,请教一下,这个模型可以改成小样本命名实体识别吗。另外我想问下,这个数据集的格式是什么呢?期待您的回复

数据集

可以提供一份华泰金融的数据集吗?

MAML如何预测

请问如何构建预测的数据集?

在test数据集中有标签数据,而在实际预测阶段中没有单个text 上没有标签数据,应该如何构建预测所用的数据集呢?看起来不太能做成接口的形式进行单个text的预测。劳烦解答疑惑,十分感谢🙏

训练时GPU使用情况

您好,我想问下您在训练的时候使用的显存占用大概是什么样的?

我尝试运行您的代码batch_size=1,bert_type=ch,N=4,K=5,Q=20,gpu是3090 24G,提示显存超出。

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.