FairScale

Description

FairScale is a PyTorch extension library for high performance and large scale training on one or multiple machines/nodes. This library extends basic PyTorch capabilities while adding new experimental ones.

FairScale supports:

  • Parallelism:
    • pipeline parallelism (fairscale.nn.Pipe)
  • Sharded training:
    • Optimizer state sharding (fairscale.optim.oss)
    • Sharded grad scaler - automatic mixed precision
    • Sharded distributed data parallel
  • Optimization at scale:
    • AdaScale SGD (from fairscale.optim import AdaScale)

Requirements

  • PyTorch >= 1.5.1

Installation

Normal installation:

pip install fairscale

Development mode:

git clone https://github.com/facebookresearch/fairscale.git
cd fairscale
pip install -r requirements.txt
pip install -e .

If either of the above fails, add --no-build-isolation to the pip install command (build isolation in recent versions of pip can cause the build to fail).
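
For example, for the normal installation (the same flag can be appended to the development-mode commands):

pip install --no-build-isolation fairscale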

Getting Started

The full documentation (https://fairscale.readthedocs.io/) contains instructions for getting started and extending fairscale.

Examples

Pipe

Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.

import torch

import fairscale

# a, b, c and d are the model's four layers; torch.nn.Linear layers are used here
# only as placeholders.
a, b, c, d = [torch.nn.Linear(1024, 1024) for _ in range(4)]

model = torch.nn.Sequential(a, b, c, d)
# balance=[2, 2]: two layers per partition; devices=[0, 1]: partitions on cuda:0 and cuda:1;
# chunks=8: each mini-batch is split into 8 micro-batches for pipelining.
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
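
As a usage sketch (the input tensor and loss below are illustrative assumptions): inputs must be placed on the device of the first partition, and the output is returned on the device of the last partition, so the loss has to be computed there.

x = torch.randn(32, 1024, device="cuda:0")  # inputs live on the first partition's device
out = model(x)                              # output is returned on the last partition's device (cuda:1)
loss = out.sum()                            # stand-in for a real loss, computed on cuda:1
loss.backward()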

Optimizer state sharding (ZeRO)

See the documentation for a more complete example; a minimal version could look like the following:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

def train(
    rank: int,
    world_size: int,
    epochs: int):

    # DDP init example
    dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)

    # Problem statement
    model = myAwesomeModel().to(rank)
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()
    base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS

    # Wrap the optimizer in its state sharding brethren
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
    model = ShardedDDP(model, optimizer)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for batch in dataloader:
            # Train
            model.zero_grad()
            outputs = model(batch["inputs"])
            loss = loss_fn(outputs, batch["label"])
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
    mp.spawn(
        train,
        args=(
            WORLD_SIZE,
            EPOCHS,
        ),
        nprocs=WORLD_SIZE,
        join=True,
    )
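
The feature list above also mentions a sharded grad scaler for automatic mixed precision. A rough sketch of how it could slot into the training loop above (assumptions: the import path fairscale.optim.grad_scaler.ShardedGradScaler, and torch.cuda.amp, which requires PyTorch 1.6+):

from torch.cuda.amp import autocast
from fairscale.optim.grad_scaler import ShardedGradScaler  # assumed import path

scaler = ShardedGradScaler()

for batch in dataloader:
    model.zero_grad()
    with autocast():  # forward pass in mixed precision
        outputs = model(batch["inputs"])
        loss = loss_fn(outputs, batch["label"])
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then steps the sharded optimizer
    scaler.update()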

AdaScale SGD

AdaScale can wrap an SGD optimizer for use in DDP (Distributed Data Parallel) training, or in non-DDP training with gradient accumulation. The benefit is that the learning rate schedule from a baseline batch size can be re-used when the effective batch size is bigger.

Note that AdaScale does not help increase per-GPU batch size.

from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR  # or your scheduler
from fairscale.optim import AdaScale

...
optim = AdaScale(SGD(model.parameters(), lr=0.1))
scheduler = LambdaLR(optim, ...)
...
# Note: the training loop should either run under DDP or use gradient accumulation.
last_epoch = 0
step = 0
done = False
while not done:
    for sample in dataset:
        ...
        step += optim.gain()
        optim.step()
        epoch = step // len(dataset)
        if last_epoch != epoch:
            scheduler.step()
            last_epoch = epoch
        if epoch > max_epoch:
            done = True
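
For the non-DDP, gradient-accumulation case mentioned above, a rough sketch (the num_gradients_to_accumulate argument and the accumulation factor of 4 are assumptions; treat them as illustrative):

# num_gradients_to_accumulate is an assumed AdaScale argument: 4 backward passes
# are accumulated per optimizer step, emulating a 4x larger effective batch size.
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=4)

step = 0
for i, sample in enumerate(dataset):
    loss = model(sample).sum()   # stand-in for a real forward pass and loss
    loss.backward()              # AdaScale gathers gradient statistics here
    if (i + 1) % 4 == 0:         # step only after 4 accumulated backward passes
        step += optim.gain()
        optim.step()
        model.zero_grad()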

The primary goal is to allow scaling to bigger batch sizes without losing model accuracy. (However, training time might be longer compared to training without AdaScale.)

At a high level, we want ML researchers to:

  • go parallel more easily (i.e. no need to find new learning rate schedules)
  • not worry about losing accuracy
  • potentially achieve higher GPU efficiency (fewer steps, less networking overhead, etc.)

Testing

We use CircleCI to test against PyTorch versions 1.5.1, 1.6.0, and 1.7.0 with CUDA version 10.1. Please create an issue if you are having trouble with installation.

Contributors

See the CONTRIBUTING file for how to help out.

License

fairscale is licensed under the BSD-3-Clause License.

fairscale.nn.pipe is forked from torchgpipe, Copyright 2019, Kakao Brain, licensed under the Apache License.

fairscale.nn.model_parallel is forked from Megatron-LM, Copyright 2020, NVIDIA CORPORATION, licensed under the Apache License.

fairscale.optim.adascale is forked from AdaptDL, Copyright 2020, Petuum, Inc., licensed under the Apache License.

References

Here is a list of all authors on relevant research papers this work is based on:

  • torchgpipe: Chiheon Kim, Heungsub Lee, Myungryong Jeong, Woonhyuk Baek, Boogeon Yoon, Ildoo Kim, Sungbin Lim, Sungwoong Kim. [Paper] [Code]
  • ZeRO: Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. [Paper] [Code]
  • Megatron-LM: Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, Bryan Catanzaro. [Paper] [Code]
  • AdaScale SGD: Tyler B. Johnson, Pulkit Agrawal, Haijie Gu, Carlos Guestrin. [Paper]
