How to Fine-Tune BERT for Text Classification?

This repository contains the code and resources for the paper How to Fine-Tune BERT for Text Classification?

In this paper, we conduct exhaustive experiments to investigate different fine-tuning methods of BERT on the text classification task and provide a general solution for BERT fine-tuning.

*********** update at Mar 14, 2020 *************

Our checkpoints can be loaded into BertEmbedding from the latest fastNLP package.

Link to fastNLP.embeddings.BertEmbedding
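
Below is a minimal sketch of loading a checkpoint through fastNLP's BertEmbedding; the argument names follow fastNLP 0.5.x and may differ in other versions, and the checkpoint path is the IMDb one used later in this README:

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

# Build (or load) the vocabulary of the target task first.
vocab = Vocabulary()
vocab.add_word_lst(["this", "movie", "is", "great"])

# Point model_dir_or_name at the directory holding the further
# pre-trained checkpoint (config, vocab and model weights).
embed = BertEmbedding(vocab,
                      model_dir_or_name="./uncased_L-12_H-768_A-12_IMDB_pretrain",
                      requires_grad=True)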

Requirements

For further pre-training, we borrow some code from Google BERT. Thus, we need:

  • tensorflow==1.1x
  • spacy
  • pandas
  • numpy

Note that you need Python 3.7 or earlier for compatibility with tensorflow 1.1x.

For fine-tuning, we borrow some code from the pytorch-pretrained-bert package (now known as transformers). Thus, we need:

  • torch>=0.4.1,<=1.2.0

Run the code

1) Prepare the data set:

Sogou News

We determine the category of each news article from its URL, e.g. “sports” corresponds to “http://sports.sohu.com”. We choose 6 categories – “sports”, “house”, “business”, “entertainment”, “women” and “technology”. For each class we select 9,000 training samples and 1,000 test samples.
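
For illustration, the URL-based labeling rule amounts to something like the sketch below (only the sports host name is given above; the other host names are assumed placeholders):

from urllib.parse import urlparse

# Map the sohu.com sub-domain to one of the six categories.
# Only sports.sohu.com is named in the text; the rest are assumptions.
URL_TO_CATEGORY = {
    "sports.sohu.com": "sports",
    "house.sohu.com": "house",
    "business.sohu.com": "business",
    "yule.sohu.com": "entertainment",
    "women.sohu.com": "women",
    "it.sohu.com": "technology",
}

def label_from_url(url):
    host = urlparse(url).netloc
    return URL_TO_CATEGORY.get(host)  # None for URLs outside the six classes

print(label_from_url("http://sports.sohu.com/some_article.shtml"))  # sports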

Data is available here.

The remaining data sets

The remaining data sets were built by Zhang et al. (2015). We download them from the URL created by Xiang Zhang.

2) Prepare Google BERT:

BERT-Base, Uncased

BERT-Base, Chinese

3) Further Pre-Training:

Generate Further Pre-Training Corpus

Here we use AG's News as an example:

python generate_corpus_agnews.py

The file agnews_corpus_test.txt can be found in the directory ./data.
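
The pre-training corpus follows Google BERT's expected input format: one sentence per line and a blank line between documents. A minimal sketch of producing that format with spaCy (assuming the en_core_web_sm model is installed; generate_corpus_agnews.py does the equivalent for AG's News):

import spacy

nlp = spacy.load("en_core_web_sm")

def write_corpus(documents, path):
    # One sentence per line; a blank line separates documents,
    # which is the input format expected by create_pretraining_data.py.
    with open(path, "w", encoding="utf-8") as f:
        for doc_text in documents:
            for sent in nlp(doc_text).sents:
                f.write(sent.text.strip() + "\n")
            f.write("\n")

write_corpus(["The market rallied today. Tech stocks led the gains."],
             "AGnews_corpus.txt")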

Run Further Pre-Training

python create_pretraining_data.py \
  --input_file=./AGnews_corpus.txt \
  --output_file=tmp/tf_AGnews.tfrecord \
  --vocab_file=./uncased_L-12_H-768_A-12/vocab.txt \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=5
  
python run_pretraining.py \
  --input_file=./tmp/tf_AGnews.tfrecord \
  --output_dir=./uncased_L-12_H-768_A-12_AGnews_pretrain \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=./uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=./uncased_L-12_H-768_A-12/bert_model.ckpt \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=100000 \
  --num_warmup_steps=10000 \
  --save_checkpoints_steps=10000 \
  --learning_rate=5e-5

4) Fine-Tuning

Convert Tensorflow checkpoint to PyTorch checkpoint

python convert_tf_checkpoint_to_pytorch.py \
  --tf_checkpoint_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/model.ckpt-100000 \
  --bert_config_file ./uncased_L-12_H-768_A-12_AGnews_pretrain/bert_config.json \
  --pytorch_dump_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/pytorch_model.bin

Fine-Tuning on downstream tasks

While fine-tuning on downstream tasks, we notice that different GPUs (e.g., 1080Ti vs. Titan Xp) may cause slight differences in experimental results even though we fix the initial random seed. Here we use 4 x 1080Ti GPUs as an example.

Take Exp-I (see Section 5.3) as an example:

export CUDA_VISIBLE_DEVICES=0,1,2,3
python run_classifier_single_layer.py \
  --task_name imdb \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir ./IMDB_data/ \
  --vocab_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/vocab.txt \
  --bert_config_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/bert_config.json \
  --init_checkpoint ./uncased_L-12_H-768_A-12_IMDB_pretrain/pytorch_model.bin \
  --max_seq_length 512 \
  --train_batch_size 24 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir ./imdb \
  --seed 42 \
  --layers 11 10 \
  --trunc_medium -1

where num_train_epochs can be 3.0, 4.0, or 6.0.

layers indicates which layers' hidden states are taken as features for classification: -2 means using the pooled output, -1 means concatenating all layers, and the command above concatenates layer-10 and layer-11 (the last two layers).
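
A rough sketch of what the feature extraction for --layers 11 10 amounts to, written against the pytorch-pretrained-bert API this repository builds on (variable names are ours):

import torch
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
input_ids = torch.zeros(1, 128, dtype=torch.long)  # placeholder batch

# output_all_encoded_layers=True returns the hidden states of all 12 layers.
encoded_layers, pooled_output = model(input_ids, output_all_encoded_layers=True)

# --layers 11 10: concatenate the [CLS] hidden states of the last two layers.
cls_features = torch.cat([encoded_layers[11][:, 0], encoded_layers[10][:, 0]], dim=-1)
# --layers -2 would use pooled_output instead; --layers -1 would concatenate all layers.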

trunc_medium indicates how long texts are truncated: -2 means head-only, -1 means tail-only, 0 means head-half + tail-half (e.g., head-256 + tail-256 for max_seq_length 512), and any other natural number k means head-k + tail-rest (i.e., head-k + tail-(512-k)).
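
For reference, the four options correspond roughly to the following truncation logic (a simplified sketch that ignores the [CLS]/[SEP] bookkeeping done in the actual script):

def truncate_tokens(tokens, max_len=512, trunc_medium=-1):
    if len(tokens) <= max_len:
        return tokens
    if trunc_medium == -2:           # head-only
        return tokens[:max_len]
    if trunc_medium == -1:           # tail-only
        return tokens[-max_len:]
    if trunc_medium == 0:            # head-half + tail-half
        half = max_len // 2
        return tokens[:half] + tokens[-half:]
    k = trunc_medium                 # head-k + tail-(max_len - k)
    return tokens[:k] + tokens[-(max_len - k):]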

There are also other arguments for fine-tuning:

pooling_type indicates which feature is used for classification: mean applies mean-pooling over the hidden states of the whole sequence, max applies max-pooling, and default takes the hidden state of the [CLS] token as the feature.
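
A rough sketch of what the three options compute, given the sequence of hidden states and the attention mask (function and variable names are ours):

import torch

def pool_features(last_hidden, attention_mask, pooling_type="default"):
    # last_hidden: [batch, seq_len, hidden]; attention_mask: [batch, seq_len]
    if pooling_type == "default":    # hidden state of the [CLS] token
        return last_hidden[:, 0]
    mask = attention_mask.unsqueeze(-1).float()
    if pooling_type == "mean":       # mean over non-padding positions
        return (last_hidden * mask).sum(1) / mask.sum(1)
    if pooling_type == "max":        # max over non-padding positions
        return last_hidden.masked_fill(mask == 0, -1e4).max(1)[0]
    raise ValueError(pooling_type)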

layer_learning_rate and layer_learning_rate_decay in run_classifier_discriminative.py control the layer-wise decreasing learning rate (see Section 5.3.4).
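
The idea is that the learning rate decreases from the top transformer layer downwards by a fixed decay factor, so lower layers are updated more conservatively. A hedged sketch of how such parameter groups could be built (assuming the model exposes its transformer blocks as model.bert.encoder.layer; the actual script handles this internally and also covers the embeddings, pooler and classifier):

import torch

def layerwise_lr_groups(model, base_lr=2e-5, decay=0.95, num_layers=12):
    # Layer l gets base_lr * decay ** (num_layers - 1 - l): the top layer
    # uses base_lr, and each lower layer is scaled down by `decay`.
    groups = []
    for l in range(num_layers):
        params = list(model.bert.encoder.layer[l].parameters())
        groups.append({"params": params, "lr": base_lr * decay ** (num_layers - 1 - l)})
    return groups

# optimizer = torch.optim.Adam(
#     layerwise_lr_groups(model) + [{"params": model.classifier.parameters(), "lr": 2e-5}])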

Further Pre-Trained Checkpoints

We upload the IMDb-based further pre-trained checkpoints here.

For other checkpoints, please contact us by e-mail.

How to cite our paper

@inproceedings{sun2019fine,
  title={How to fine-tune {BERT} for text classification?},
  author={Sun, Chi and Qiu, Xipeng and Xu, Yige and Huang, Xuanjing},
  booktitle={China National Conference on Chinese Computational Linguistics},
  pages={194--206},
  year={2019},
  organization={Springer}
}
