Code Monkey home page Code Monkey logo

bert_implement's Introduction

BERT_implement

使用BERT模型进行文本分类,相似句子判断,以及中文命名实体识别(序列标注任务)

说明

  1. 谷歌提供的run_classify.py本身就是针对句子配对与分类的,所以,文本分类任务和句子配对任务对 代码修改不多,只用重写接口,就可以达到先进的效果
  2. 序列标注不仅要重写接口,由于中文序列标注数据集格式问题,还要重写读取数据的方式,同时,原本run_classify.py在create_model函数中提供的是提取[CLS]编码的方式用来分类,序列标注要返回最后一层所有隐层值而不是仅返回[cls]的编码
  3. 序列标注任务中还要注意当mode==eval时,对评价函数的修改

使用

下载预训练模型

wget https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip  

将下载好的预训练模型存放在目录中

准备好数据

参数说明

参数 说明

data_dir 训练数据的地址

task_name processor的名字

vocab_file 字典地址,用默认提供的就可以了,当然也可以自定义

bert_config_file 配置文件

output_dir 模型的输出

文本分类

export BERT_BASE_DIR=你的模型保存目录
export DATA_DIR=数据保存目录
export OUTPUT_DIR=结果保存目录

# 使用官方提供的参数进行训练
python BERT_implement.py \
  --task_name=text_classify \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR/ \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=$OUTPUT_DIR

结果

句子配对

export BERT_BASE_DIR=你的模型保存目录
export DATA_DIR=数据保存目录
export OUTPUT_DIR=结果保存目录

# 使用官方提供的参数进行训练
python BERT_implement.py \
  --task_name=pair_sentence \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR/ \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=$OUTPUT_DIR

结果

tips:

  1. 最大序列长度(max sequence length)对模型的效果影响比较大。随着最大序列长度增加,效果有所提升,但模型的训练时间也相应增加。当最大序列长度变小后(如截取信息),模型的准确率下降
  2. 批次大小(batch size)对模型的效果影响也比较,如从64下降到16后,模型的准确率下降幅度较大。
  3. fine-tuning模式下略微提高训练轮次(epoch) ,效果可进一步提高。

命名实体识别(序列标注)

export BERT_BASE_DIR=你的模型保存目录
export DATA_DIR=数据保存目录
export OUTPUT_DIR=结果保存目录

# 使用官方提供的参数进行训练
python BERT_implement.py \
  --task_name=ner \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR/ \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --output_dir=$OUTPUT_DIR

结果

在线预测

在输出文件中可以得到BERT经过fine tunning 后的检查点ckpt,但是可以看到检查点大小与谷歌原本检查点大小相比非常大,是因为权重文件中包含了动量和方差,可以使用提供的脚本进行模型精简

python compress_ckpt.py  --input_file YOUR_CKPT  --output_file OUTPUT_FILE

其他NLP任务

待补充。。。

bert_implement's People

Contributors

fennudetudou avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.