microsoft / bang

BANG is a new pretraining model to Bridge the gap between Autoregressive (AR) and Non-autoregressive (NAR) Generation. AR and NAR generation can be uniformly regarded as to what extent previous tokens can be attended, and BANG bridges AR and NAR generation by designing a novel model structure for large-scale pretraining. The pretrained BANG model can simultaneously support AR, NAR and semi-NAR generation to meet different requirements.

License: MIT License


bang's Introduction

BANG

This repo provides the code for reproducing the experiments in BANG.
In the paper, we propose a new pre-trained language model called BANG for sequence-to-sequence learning, which considers autoregressive, non-autoregressive and semi-autoregressive generation as its pretraining tasks.
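
As a rough illustration of the idea that these generation modes differ in how many previous target tokens each position may attend to (as described in the repo summary above), here is a conceptual sketch; it is not code from this repo, and the semi-NAR window size of 2 is an arbitrary choice for illustration:

import numpy as np

# Visibility of previous target tokens in decoder self-attention for a
# target of length 5, under the three generation modes described above.
def visibility_mask(length, mode, window=2):
    mask = np.zeros((length, length), dtype=bool)
    for i in range(length):
        if mode == 'ar':            # AR: attend to all previous target tokens
            mask[i, :i] = True
        elif mode == 'semi_nar':    # semi-NAR: attend to a limited window of previous tokens
            mask[i, max(0, i - window):i] = True
        # 'nar': no previous target tokens are visible (mask stays all False)
    return mask

print(visibility_mask(5, 'ar').astype(int))
print(visibility_mask(5, 'nar').astype(int))
print(visibility_mask(5, 'semi_nar').astype(int))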

Pretrained Models:

BANG base
Pretrained on a 16GB English corpus (Wikipedia and BookCorpus).

Dependency

  • pip install torch==1.3.0
  • pip install fairseq==0.9.0
  • pip install tensorboardX==1.7

How to use

The procedure includes 1) tokenization, 2) binarization, 3) finetuning, and 4) inference.
BANG is implemented on top of Fairseq; refer to the Fairseq manual for details on the underlying commands.

Tokenize. Prepare train.src, train.tgt, and the corresponding valid and test files. The input and the output of each sample occupy one line in the .src and .tgt files, respectively.
Use the bert-base-uncased tokenizer to split your data into word pieces.

from transformers import BertTokenizer


def bert_uncased_tokenize(fin, fout):
    # Tokenize each line into BERT word pieces and write them space-separated.
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    with open(fin, 'r', encoding='utf-8') as f_in, \
         open(fout, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            word_pieces = tok.tokenize(line.strip())
            f_out.write('{}\n'.format(' '.join(word_pieces)))


bert_uncased_tokenize('train.src', 'tokenized_train.src')
bert_uncased_tokenize('train.tgt', 'tokenized_train.tgt')
bert_uncased_tokenize('valid.src', 'tokenized_valid.src')
bert_uncased_tokenize('valid.tgt', 'tokenized_valid.tgt')
bert_uncased_tokenize('test.src', 'tokenized_test.src')
bert_uncased_tokenize('test.tgt', 'tokenized_test.tgt')

Binarize the tokenized data with fairseq-preprocess.

fairseq-preprocess \
--user-dir ./bang/bang \
--task translation_bang \
--source-lang src --target-lang tgt \
--trainpref tokenized_train --validpref tokenized_valid --testpref tokenized_test \
--destdir processed_data --srcdict ./bang/vocab.txt --tgtdict ./bang/vocab.txt \
--workers 20

Finetune with fairseq-train.

Autoregressive Generation

Set these parameters:
  • --disable-ngram-loss: enable for AR finetuning
  • --ngram: set to 1 for AR finetuning
  • --nar-ratio: set to 0.0 for AR finetuning
  • --fp16: enable if your GPU supports it, to accelerate training

DATA_DIR=processed_data
ARCH=bang_ar_nar_mixed_base
CRITERION=ngram_language_loss_NAR_mixed
SAVE_DIR=models/model_ar
TENSORBOARD_LOGDIR=models/logs_ar
PRETRAINED_MODEL=checkpoint_base_9gram_ck35.pt
NAR_RATIO=0.0

fairseq-train $DATA_DIR \
--user-dir ./bang/bang  \
--task translation_bang --arch $ARCH \
--optimizer adam --adam-betas '(0.9, 0.999)' --clip-norm 0.1 \
--lr 0.0001 --min-lr 1e-09 --nar-ratio ${NAR_RATIO} --ngram 1 --disable-ngram-loss \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 1000 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--criterion $CRITERION --label-smoothing 0.1 \
--update-freq 1  --max-tokens 3072 \
--num-workers 8  \
--load-from-pretrained-model $PRETRAINED_MODEL \
--ddp-backend=no_c10d --max-epoch 10 \
--max-source-positions 512 --max-target-positions 512 \
--truncate-source \
--save-dir $SAVE_DIR \
--keep-last-epochs 10  --save-interval 1 \
--tensorboard-logdir $TENSORBOARD_LOGDIR

Inference with fairseq-generate to generate targets for the processed test files. Alternatively, you can use fairseq-interactive to generate answers for typed-in text (which should also be tokenized first).

BEAM=4
LENPEN=1.2
CHECK_POINT=models/model_ar/checkpoint8.pt
SUFFIX=_ar_pelt${LENPEN}_test_beam${BEAM}
OUTPUT_FILE=outputs/output$SUFFIX.txt

PYTHONIOENCODING=utf-8 fairseq-generate ./processed_data --path $CHECK_POINT --user-dir ./bang/bang --task translation_bang --batch-size 36 --gen-subset test --beam $BEAM --num-workers 4 --lenpen $LENPEN 2>&1 > $OUTPUT_FILE
grep ^H $OUTPUT_FILE | cut -c 3- | sort -n | cut -f3- | sed "s/ ##//g" > outputs/sort_hypo$SUFFIX.txt
grep ^H $OUTPUT_FILE | cut -c 3- | sort -n | cut -f3-  > outputs/sort_hypo$SUFFIX.txt.tokenized
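
For reference, here is a Python equivalent of the shell post-processing above (a sketch; the shell pipeline is what this repo actually uses). fairseq-generate prints hypotheses as "H-<sample id>", score, and tokens separated by tabs; the pipeline collects those lines, restores the original sample order, and merges BERT word pieces by deleting the " ##" continuation marker:

def collect_hypotheses(generate_output, detokenized_out, tokenized_out):
    # Gather fairseq-generate hypothesis lines and sort them by sample id.
    hyps = []
    with open(generate_output, encoding='utf-8') as f:
        for line in f:
            if line.startswith('H-'):
                sample_id, _score, tokens = line[2:].rstrip('\n').split('\t', 2)
                hyps.append((int(sample_id), tokens))
    hyps.sort()
    with open(tokenized_out, 'w', encoding='utf-8') as f_tok, \
         open(detokenized_out, 'w', encoding='utf-8') as f_detok:
        for _, tokens in hyps:
            f_tok.write(tokens + '\n')                       # word-piece tokenized
            f_detok.write(tokens.replace(' ##', '') + '\n')  # word pieces merged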

Non-autoregressive Generation

Set these parameters:

  • --nar-ratio: set to 1.0 for NAR finetuning
  • --fp16: enable if your GPU supports it, to accelerate training

DATA_DIR=processed_data
ARCH=bang_ar_nar_mixed_base
CRITERION=ngram_language_loss_NAR_mixed
SAVE_DIR=models/model_nar
TENSORBOARD_LOGDIR=models/logs_nar
PRETRAINED_MODEL=checkpoint_base_9gram_ck35.pt
NAR_RATIO=1.0

fairseq-train $DATA_DIR \
--user-dir ./bang/bang  \
--task translation_bang --arch $ARCH \
--optimizer adam --adam-betas '(0.9, 0.999)' --clip-norm 0.1 \
--lr 0.0001 --min-lr 1e-09 --nar-ratio $NAR_RATIO --ngram 1 --disable-ngram-loss \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 1000 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--criterion $CRITERION --label-smoothing 0.1 \
--update-freq 1  --max-tokens 3072 \
--num-workers 8  \
--load-from-pretrained-model $PRETRAINED_MODEL \
--ddp-backend=no_c10d --max-epoch 50 \
--max-source-positions 512 --max-target-positions 512 \
--truncate-source \
--save-dir $SAVE_DIR \
--keep-last-epochs 10  --save-interval 5 \
--tensorboard-logdir $TENSORBOARD_LOGDIR

Inference with fairseq-generate to generate targets for the processed test files. Alternatively, you can use fairseq-interactive to generate answers for typed-in text (which should also be tokenized first).

SUFFIX=_nar
CHECK_POINT=models/model_nar/checkpoint40.pt
OUTPUT_FILE=outputs/output${SUFFIX}.txt

PYTHONIOENCODING=utf8 fairseq-generate processed_data  --user-dir ./bang/bang --path ${CHECK_POINT} --truncate-source --max-source-positions 512 --task translation_bang_nar --batch-size 36 --beam 1 --gen-subset test  2>&1 > ${OUTPUT_FILE}

grep ^H $OUTPUT_FILE | cut -c 3- | sort -n | cut -f3- > outputs/sort_hypo${SUFFIX}.txt
python post_processed_nar.py outputs/sort_hypo${SUFFIX}.txt outputs/sort_hypo${SUFFIX}.txt.dedup
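
post_processed_nar.py ships with the repo and is not reproduced in this section. As a hypothetical sketch of the kind of clean-up it performs (an assumption: NAR decoding often emits consecutive repeated tokens, so immediate duplicates are dropped; the actual script may differ):

import sys

def dedup_line(line):
    # Drop consecutive duplicate tokens, a common NAR post-processing step.
    out = []
    for tok in line.split():
        if not out or tok != out[-1]:
            out.append(tok)
    return ' '.join(out)

if __name__ == '__main__':
    src, dst = sys.argv[1], sys.argv[2]
    with open(src, encoding='utf-8') as fin, open(dst, 'w', encoding='utf-8') as fout:
        for line in fin:
            fout.write(dedup_line(line.strip()) + '\n')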

TIPS:

1. Autoregressive finetuning needs fewer steps, while non-autoregressive finetuning needs much longer training to reach good performance.
2. We highly recommend sequence-level knowledge distillation before NAR finetuning (see the sketch after this list).
3. If you run into problems with fairseq-preprocess, fairseq-train, or other commands, or if you want to modify the workflow or inference pipeline, a good option is to download the Fairseq git repo, check out v0.9.0, and merge our code into it. Then modify its preprocess.py, train.py, or generate.py to run your new pipeline.
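
A minimal sketch of sequence-level knowledge distillation before NAR finetuning (an assumed workflow, not a script from this repo): finetune an AR teacher as in the "Autoregressive Generation" section, run the AR inference pipeline on the train subset (--gen-subset train) and keep the sorted, still word-piece-tokenized hypotheses, then use those hypotheses as the new training targets and re-run fairseq-preprocess before NAR finetuning:

import shutil

# 'outputs/sort_hypo_ar_train.txt.tokenized' is a hypothetical filename for the
# sorted teacher hypotheses produced by the AR inference pipeline on the train set.
shutil.copyfile('outputs/sort_hypo_ar_train.txt.tokenized', 'tokenized_train.tgt')
# ...then re-run fairseq-preprocess on the distilled data before NAR finetuning.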

Repo Reference

This repo builds on Fairseq v0.9.0 and ProphetNet.

How to Cite

If you extend or use this work, please cite the paper where it was introduced:

@inproceedings{qi2021bang,
  title={Bang: Bridging autoregressive and non-autoregressive generation with large scale pretraining},
  author={Qi, Weizhen and Gong, Yeyun and Jiao, Jian and Yan, Yu and Chen, Weizhu and Liu, Dayiheng and Tang, Kewen and Li, Houqiang and Chen, Jiusheng and Zhang, Ruofei and others},
  booktitle={International Conference on Machine Learning},
  pages={8630--8639},
  year={2021},
  organization={PMLR}
}

Microsoft Open Source Code of Conduct


bang's Issues

Reproduction on PersonaChat

@qiweizhen Hello, thanks for your great work!

I am reproducing the result on PersonaChat (in the NAT setting) but find it hard to reach the performance reported in your paper.

Specifically, I use the dataset from GLGE and then finetune BANG using the following command:

DATA_DIR=processed_data
ARCH=bang_ar_nar_mixed_base
CRITERION=ngram_language_loss_NAR_mixed
SAVE_DIR=models/model_nar
TENSORBOARD_LOGDIR=models/logs_nar
PRETRAINED_MODEL=checkpoint_base_9gram_ck35.pt
NAR_RATIO=1.0

MAX_TOKENS=12000
LR=5e-5

# on 8 GPUs
fairseq-train $DATA_DIR \
--user-dir ./bang/bang  \
--task translation_bang --arch $ARCH \
--optimizer adam --adam-betas '(0.9, 0.999)' --clip-norm 0.1 \
--lr $LR --min-lr 1e-09 --nar-ratio $NAR_RATIO --ngram 1 --disable-ngram-loss \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 1000 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--criterion $CRITERION --label-smoothing 0.1 \
--update-freq 1  --max-tokens $MAX_TOKENS \
--num-workers 8  \
--load-from-pretrained-model $PRETRAINED_MODEL \
--ddp-backend=no_c10d --max-epoch 50 \
--max-source-positions 512 --max-target-positions 512 \
--truncate-source \
--save-dir $SAVE_DIR \
--keep-last-epochs 10  --save-interval 5 \
--tensorboard-logdir $TENSORBOARD_LOGDIR

Then I use the evaluation script here, but only get:

                  B1      B2      D1     D2
My reproduction   30.69   19.27   2.1    23.0
Your result       39.82   30.72   1.9    14.2

Can you give some advice for the reproduction? For example, the training and evaluation script?

Dependency Version

Hi, I noticed that BANG requires PyTorch 1.3.0 as a dependency. However, there is no distribution of PyTorch 1.3.0 on the official site. Also, I found that some other versions of PyTorch (such as 1.2.0 or 1.4.0) result in an error. Any suggestions for solving this problem? Thanks!

How to get the ROUGE scores in the paper?

BANG is a great paper, but I have some problems when I try to get the scores reported in it.
First, the BLEU-4 and ROUGE-L for SQuAD question generation with the MASS pretrained model differ from the paper: with beam size 5 I get BLEU-4 = 22.43 and the following ROUGE scores:

1 ROUGE-1 Average_R: 0.48431 (95%-conf.int. 0.47986 - 0.48875)
1 ROUGE-1 Average_P: 0.54315 (95%-conf.int. 0.53853 - 0.54740)
1 ROUGE-1 Average_F: 0.49817 (95%-conf.int. 0.49411 - 0.50238)

1 ROUGE-2 Average_R: 0.26775 (95%-conf.int. 0.26311 - 0.27234)
1 ROUGE-2 Average_P: 0.29883 (95%-conf.int. 0.29365 - 0.30367)
1 ROUGE-2 Average_F: 0.27436 (95%-conf.int. 0.26965 - 0.27884)

1 ROUGE-L Average_R: 0.44690 (95%-conf.int. 0.44248 - 0.45166)
1 ROUGE-L Average_P: 0.49998 (95%-conf.int. 0.49532 - 0.50436)
1 ROUGE-L Average_F: 0.45929 (95%-conf.int. 0.45507 - 0.46371)
BLEU-4 is higher than in the paper, but ROUGE-L is lower.

Second, I use the BANG pretrained model and the code in this repo, but my ROUGE scores lag behind the scores in the paper, with a gap of about 2, and the generation quality is poor:
when dominicn choe was as to , singapore , he to .
northern ireland ' s euro 2016 qualifier ireland was after a crash .
prime may says she has faith ' in ' trident nuclear after afire a the bbc . tennis police is investigating by a williams a at . a coast has been that to oil coast a public bonfire park bonfire belfast has been up a of bonfire the bbc has strongly a claims thatguana iling iguana was ' .
