Vision Transformer (ViT) model using PyTorch/XLA FSDP

This repo implements sharded training of a Vision Transformer (ViT) model at the 10-billion parameter scale using the FSDP algorithm in PyTorch/XLA. The XLA FSDP interface is officially supported as of the PyTorch/XLA 1.12 release.
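
At its core, the training script wraps each transformer block individually in PyTorch/XLA's XlaFullyShardedDataParallel class, and then wraps the whole model again, so that parameters, gradients, and optimizer states stay sharded across TPU cores and a block's full weights are only materialized around its own forward/backward pass. Below is a minimal sketch of this pattern, assuming a timm-style ViT whose blocks live in model.blocks; build_vit is a hypothetical stand-in for the model construction in run_vit_training.py.

import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = build_vit().to(device)  # build_vit: hypothetical constructor (see run_vit_training.py)

# Wrap each transformer block so its parameters, gradients, and optimizer
# states are sharded across all TPU cores; the block's full weights are
# all-gathered only around its own forward/backward pass.
for i, block in enumerate(model.blocks):
    model.blocks[i] = FSDP(block)

# Wrap the whole model to shard the remaining parameters
# (patch embedding, positional embeddings, classifier head).
model = FSDP(model)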


Installation

  1. Allocate a v3-32 TPU VM pod (e.g. with name sfr-b-pang-tpu-32-us-east1-1 in zone us-east1-d) from the tpu-vm-pt-1.12 environment, following the TPU VM instructions. You can also try out larger TPU pods such as v3-128, v3-256, or v3-512.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
ZONE=us-east1-d  # change to your TPU zone
ACCELERATOR_TYPE=v3-32  # you can also try out larger TPU pods
RUNTIME_VERSION=tpu-vm-pt-1.12  # the XLA FSDP interface is supported in PyTorch/XLA 1.12

gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --version ${RUNTIME_VERSION}
  2. Install timm as a dependency (used to create the vision transformer layers) and clone this repository to all TPU VM nodes as follows.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
ZONE=us-east1-d  # change to your TPU zone

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} \
  --worker all \
  --command "
# ViT dependency
sudo pip3 install timm==0.4.12

# clone this repo (ViT FSDP example)
cd ~ && rm -rf vit_10b_fsdp_example && git clone https://github.com/bpucla/vit_10b_fsdp vit_10b_fsdp_example
"
  3. Download ImageNet-1k to a shared directory (e.g. /datasets/imagenet-1k) that can be accessed from all nodes. It should have the following structure, with the validation images moved into labeled subfolders as in the PyTorch ImageNet example.
/datasets/imagenet-1k
|_ train
|  |_ <n0......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-N-name>.JPEG
|  |_ ...
|  |_ <n1......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-M-name>.JPEG
|_ val
|  |_ <n0......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-N-name>.JPEG
|  |_ ...
|  |_ <n1......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-M-name>.JPEG

You can use a Persistent Disk or a Filestore NFS on GCP to store the ImageNet-1k dataset.

Alternatively, you can pass --fake_data to run on a fake dataset (dummy images filled with zeros) as a way to test the model without downloading ImageNet.
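
With this layout in place, the dataset can be read by a standard torchvision ImageFolder, which expects exactly one subfolder per class; the sketch below is a quick sanity check you can run before launching a full training job (illustrative only; the actual data pipeline lives in run_vit_training.py).

from torchvision import datasets, transforms

# ImageFolder expects one subfolder per class under train/ and val/.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder("/datasets/imagenet-1k/train", transform)
val_set = datasets.ImageFolder("/datasets/imagenet-1k/val", transform)
print(len(train_set.classes), "classes")  # should print 1000 for ImageNet-1k
print(len(train_set), "train /", len(val_set), "val images")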

Running the experiments

  1. Now log into your TPU VM.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
ZONE=us-east1-d  # change to your TPU zone

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --worker 0
  2. Before running any experiments, first set up the gcloud ssh configuration on your TPU VM as follows (you only need to do this once):
cd ${HOME} && gcloud compute config-ssh --quiet
  3. Now we can run the experiments. For example, to train a ViT model with 10 billion parameters (5120 embed dim, 32 attention heads, 32 layers, and an MLP ratio of 4.0, which gives a feed-forward MLP dim of 20480 = 5120 * 4.0), launch the following in a tmux session.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
SAVE_DIR=~/vit_10b_fsdp_example_ckpts  # this can be any directory (it doesn't need to be a shared one across nodes)

mkdir -p ${SAVE_DIR}
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod-server --env PYTHONUNBUFFERED=1 -- \
python3 -u ~/vit_10b_fsdp_example/run_vit_training.py \
  --data_dir /datasets/imagenet-1k \
  --ckpt_dir ${SAVE_DIR} \
  --image_size 224 \
  --patch_size 14 \
  --embed_dim 5120 \
  --mlp_ratio 4.0 \
  --num_heads 32 \
  --num_blocks 32 \
  --batch_size 1024 \
  --num_epochs 300 \
  --lr 1e-3 \
  --weight_decay 0.1 \
  --clip_grad_norm 1.0 \
  --warmup_steps 10000 \
  --log_step_interval 20 \
  2>&1 | tee ${SAVE_DIR}/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

Note that these hyperparameters (e.g. the learning rate) are not necessarily optimal, and you may need to tune them to get the best performance. You can also use --fake_data to run on a fake dataset (dummy images filled with zeros). For comparison, you can pass --run_without_fsdp to launch training without FSDP, which only fits much smaller models.
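
As a back-of-the-envelope check on the 10-billion figure: each transformer block holds roughly 4 * embed_dim^2 attention parameters (q/k/v/output projections) plus 2 * embed_dim * (mlp_ratio * embed_dim) MLP parameters, so 32 such blocks at embed_dim 5120 come to about 10.1B; the patch embedding, positional embeddings, and classifier head add comparatively little.

embed_dim, num_blocks, mlp_ratio = 5120, 32, 4.0

# Per block: 4 * d^2 for the q/k/v/output projections, plus
# 2 * d * (mlp_ratio * d) for the two MLP layers (biases and
# layer norms are negligible at this scale).
per_block = 4 * embed_dim**2 + 2 * embed_dim * int(mlp_ratio * embed_dim)
print(f"~{per_block * num_blocks / 1e9:.1f}B parameters")  # ~10.1B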

You can also try running models larger than the 10-billion size above; in general, you will need more TPU cores to fit more parameters. Don't worry if you see messages like tcmalloc: large alloc 1677729792 bytes == 0x181ff4000 when running this codebase on even larger models (e.g. 60B parameters) -- this message is not an error. You can get rid of it by passing --env TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=4294967296 to torch_xla.distributed.xla_dist, which raises the tcmalloc report threshold to 4 GB.

Contributors

bpucla, ronghanghu

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.