staged-training

In our paper Staged Training for Transformer Language Models, we propose a staged training setup that begins with a small model and incrementally increases the amount of compute used for training by applying a "growth operator" to increase the model depth and width. By initializing each stage with the output of the previous one, the training process effectively re-uses the compute from prior stages and becomes more efficient.

This repository contains reproducible code for the growth operators and the evaluation scripts.

Setup

The scripts in this repository require Python 3.7 or newer. Once you have a suitable Python environment, first install PyTorch v1.9.0 according to the official instructions. Then run

pip install -r requirements.txt
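Before running the scripts, a small check like the following can confirm that the environment meets the two requirements stated above (Python 3.7+ and PyTorch 1.9.0). `check_environment` is a hypothetical helper, not part of the repository, and it does not verify anything from requirements.txt.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 7), torch_version="1.9.0"):
    """Return a list of problems found; an empty list means the checks passed.

    Hypothetical helper: it only checks the Python version and the installed
    PyTorch release, nothing listed in requirements.txt.
    """
    problems = []
    if sys.version_info[:2] < min_python:
        found = ".".join(str(part) for part in sys.version_info[:3])
        problems.append(
            f"Python {min_python[0]}.{min_python[1]}+ required, found {found}"
        )
    if importlib.util.find_spec("torch") is None:
        problems.append("PyTorch is not installed")
    else:
        import torch

        # Strip a local build suffix such as "+cu111" before comparing.
        installed = str(torch.__version__).split("+")[0]
        if installed != torch_version:
            problems.append(f"expected PyTorch {torch_version}, found {installed}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```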

Growth Operator

Our growth operators (width/depth) each take as input the entire training state (including model parameters, optimizer state, learning rate schedule, etc.) and output a new training state from which training continues.
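To make the idea of growing a model concrete, here is a toy, Net2Net-style width expansion of a two-layer feed-forward block in plain Python. It is an illustration only, not the operator implemented in this repository, and all names are hypothetical: each hidden unit is duplicated and the outgoing weights of both copies are halved, so the wider block initially computes the same function. A real growth operator also has to handle attention, embeddings, layer norm, and the matching optimizer state, as described above.

```python
import math
import random


def matvec(W, x):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(v):
    return [max(0.0, a) for a in v]


def mlp(W_in, W_out, x):
    """Feed-forward block: W_out @ relu(W_in @ x)."""
    return matvec(W_out, relu(matvec(W_in, x)))


def widen_mlp(W_in, W_out):
    """Double the hidden width of the block while preserving its output.

    W_in has shape (hidden, d_in) and W_out has shape (d_out, hidden).
    Hidden unit i and its copy i + hidden receive identical inputs, and each
    sends out half of the original outgoing weight, so their sum is unchanged.
    """
    W_in_wide = [list(row) for row in W_in] + [list(row) for row in W_in]
    W_out_wide = [[w / 2 for w in row] * 2 for row in W_out]
    return W_in_wide, W_out_wide


if __name__ == "__main__":
    random.seed(0)
    d_in, hidden, d_out = 3, 4, 2
    W_in = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(hidden)]
    W_out = [[random.uniform(-1, 1) for _ in range(hidden)] for _ in range(d_out)]
    x = [0.5, -1.0, 2.0]
    W_in2, W_out2 = widen_mlp(W_in, W_out)
    before, after = mlp(W_in, W_out, x), mlp(W_in2, W_out2, x)
    print(len(W_in), "->", len(W_in2), "hidden units")
    print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
              for a, b in zip(before, after)))
```

Exact duplication leaves the two copies of each unit symmetric, which is why Net2Net adds a little noise to break the tie; how the repository's width operator handles this is not shown here.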

See scripts/cheatsheet.txt for more examples of how to use the corresponding scripts.

For example, you can apply the width operator with:

CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \
  --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
  --gpu_count -1 \
  --model gpt2 \
  --tokenizer gpt2 \
  --batch_size 4 \
  --grad_accum 32 \
  --lr 0.002006911598778545 \
  --warmup_steps 3000 \
  --train_steps 250000 \
  --val_every 50 \
  --val_batches 50 \
  --fp16 \
  --seqlen 1024 \
  --log_rate 10 \
  --num_workers 4 \
  --size GPT2_large_div2_width \
  --random \
  --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-xxx.ckpt \
  --doubling weights
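The batch flags determine how many sequences feed each optimizer step. Assuming `--batch_size` is per GPU and `--gpu_count -1` uses all four GPUs made visible by `CUDA_VISIBLE_DEVICES` (both are assumptions about `gpt_pretrain.py`), the numbers reproduce the `bs512` in the run name, as this small calculation shows. The helper names are hypothetical.

```python
def effective_batch(num_gpus, batch_size, grad_accum):
    """Sequences contributing to one optimizer step (batch_size assumed per GPU)."""
    return num_gpus * batch_size * grad_accum


def tokens_per_step(num_gpus, batch_size, grad_accum, seqlen):
    """Tokens contributing to one optimizer step."""
    return effective_batch(num_gpus, batch_size, grad_accum) * seqlen


def grad_accum_for(target_batch, num_gpus, batch_size):
    """Accumulation steps needed to keep target_batch on a different GPU count."""
    per_pass = num_gpus * batch_size
    if target_batch % per_pass:
        raise ValueError(f"{target_batch} is not divisible by {per_pass}")
    return target_batch // per_pass


if __name__ == "__main__":
    # 4 GPUs, --batch_size 4, --grad_accum 32, --seqlen 1024, --train_steps 250000
    print(effective_batch(4, 4, 32))                  # 512
    print(tokens_per_step(4, 4, 32, 1024))            # 524288
    print(tokens_per_step(4, 4, 32, 1024) * 250_000)  # 131072000000
```

Under the same assumptions, running on two GPUs would need `--grad_accum 64` (`grad_accum_for(512, 2, 4)`) to keep an effective batch of 512.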

Or the depth operator with:

CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \
  --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
  --gpu_count -1 \
  --model gpt2  \
  --tokenizer gpt2 \
  --batch_size 4 \
  --grad_accum 32 \
  --lr 0.002006911598778545 \
  --warmup_steps 3000 \
  --train_steps 250000 \
  --val_every 50 \
  --val_batches 50 \
  --fp16 \
  --seqlen 1024 \
  --log_rate 10 \
  --num_workers 4 \
  --size GPT2_large_div2_depth \
  --random \
  --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt \
  --doubling layers
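Depth growth can be illustrated the same way. The toy below uses hypothetical names and is not the repository's `--doubling layers` implementation: residual blocks compute `x + scale * relu(W @ x)`, and the depth is doubled by following each block with a copy whose residual branch is scaled to zero. Such a copy computes the identity, so the deeper stack starts out computing the same function as the shallow one. Whether the repository zeroes a scale, copies weights unchanged, or uses another scheme is not shown here.

```python
import copy
import random


def matvec(W, x):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def residual_block(params, x):
    """Toy residual layer: x + scale * relu(W @ x)."""
    h = [max(0.0, a) for a in matvec(params["W"], x)]
    return [xi + params["scale"] * hi for xi, hi in zip(x, h)]


def forward(layers, x):
    for params in layers:
        x = residual_block(params, x)
    return x


def deepen(layers):
    """Double the depth by inserting an identity-initialized copy after each layer.

    The copy keeps the original weights but has its residual branch scaled to
    zero, so it returns its input unchanged until training changes `scale`.
    """
    grown = []
    for params in layers:
        grown.append(copy.deepcopy(params))
        twin = copy.deepcopy(params)
        twin["scale"] = 0.0
        grown.append(twin)
    return grown


if __name__ == "__main__":
    random.seed(0)
    d = 3
    layers = [
        {"W": [[random.uniform(-1, 1) for _ in range(d)] for _ in range(d)], "scale": 1.0}
        for _ in range(2)
    ]
    x = [0.5, -1.0, 2.0]
    deeper = deepen(layers)
    print(len(layers), "->", len(deeper), "layers")
    print(forward(layers, x) == forward(deeper, x))  # True: the copies are identities
```

One trade-off of this particular identity initialization: while `scale` is zero, the gradient reaching the copy's `W` is also zero, so only `scale` is updated at first.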

Evaluation

Use evaluation/eval_wikitext.py or evaluation/eval_lambada.py to evaluate GPT-2 on one of the supported datasets. For example:

python evaluation/eval_wikitext.py

Or using Docker:

docker build -t evaluation:latest .
docker run --rm --gpus all evaluation:latest evaluation/eval_wikitext.py
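For reference, the metrics conventionally reported on these benchmarks are sketched below: perplexity (the exponentiated mean negative log-likelihood, usual for WikiText) and last-word accuracy (usual for LAMBADA). This is a generic sketch with hypothetical helper names; check `eval_wikitext.py` and `eval_lambada.py` for exactly what they compute, for example whether perplexity is normalized per token or per word.

```python
import math


def perplexity(nlls, num_units=None):
    """exp(total NLL / num_units), with NLLs in nats.

    By default num_units is the number of scored tokens (token-level
    perplexity). Passing a word count instead gives word-level perplexity.
    """
    if not nlls:
        raise ValueError("need at least one scored token")
    count = len(nlls) if num_units is None else num_units
    return math.exp(sum(nlls) / count)


def last_word_accuracy(predictions, targets):
    """Fraction of examples whose predicted final word equals the target."""
    if not targets or len(predictions) != len(targets):
        raise ValueError("predictions and targets must be non-empty and aligned")
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


if __name__ == "__main__":
    # Three tokens, each assigned probability 1/4 by the model.
    print(perplexity([math.log(4)] * 3))  # approximately 4.0
    print(last_word_accuracy(["dog", "sat"], ["dog", "mat"]))  # 0.5
```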

Reference

If you use staged training in your research or wish to refer to the baseline results published here, please use the following BibTeX entry.

@misc{shen2022staged,
    title={Staged Training for Transformer Language Models},
    author={Sheng Shen and Pete Walsh and Kurt Keutzer and Jesse Dodge and Matthew Peters and Iz Beltagy},
    year={2022},
    eprint={2203.06211},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
