CheckFreq: Frequent, Fine-Grained DNN Checkpointing

This repository contains the source code implementation of the FAST'21 paper "CheckFreq: Frequent, Fine-Grained DNN Checkpointing". This work was done as part of Microsoft Research's Project Fiddle. This source code is available under the MIT License.

CheckFreq is an automatic, fine-grained checkpointing framework that

  1. Algorithmically determines the checkpointing frequency at the granularity of iterations using systematic online profiling
  2. Dynamically tunes checkpointing frequency at runtime to bound the checkpointing overhead using adaptive rate tuning
  3. Maintains the training data invariant of using each item in the dataset exactly once per epoch by checkpointing data loader state using a light-weight resumable iterator
  4. Carefully pipelines checkpointing with computation to reduce the checkpoint cost by introducing two-phase checkpointing.
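The two-phase idea in item 4 can be sketched in plain Python. This is a minimal illustration of the general technique (copy the state in memory, then write the copy to disk in the background), not CheckFreq's implementation. The class name `TwoPhaseCheckpointer`, its methods, and the use of `pickle` are assumptions made for this example.

```python
import copy
import os
import pickle
import tempfile
import threading


class TwoPhaseCheckpointer:
    """Illustrative only: snapshot in memory, then persist in the background."""

    def __init__(self, path):
        self.path = path
        self._writer = None  # thread of the in-flight persist, if any

    def checkpoint(self, state):
        # Wait for the previous persist so at most one checkpoint is in flight.
        self.wait()
        # Phase 1 (snapshot): copy the state so training may keep mutating it.
        snapshot = copy.deepcopy(state)
        # Phase 2 (persist): write the copy to disk off the training thread.
        self._writer = threading.Thread(target=self._persist, args=(snapshot,))
        self._writer.start()

    def _persist(self, snapshot):
        # Write to a temporary file, then rename it into place, so a crash
        # never leaves a partially written checkpoint at self.path.
        directory = os.path.dirname(os.path.abspath(self.path))
        fd, tmp_path = tempfile.mkstemp(dir=directory)
        with os.fdopen(fd, "wb") as f:
            pickle.dump(snapshot, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, self.path)

    def wait(self):
        # Block until the in-flight persist (if any) has finished.
        if self._writer is not None:
            self._writer.join()
            self._writer = None


def demo():
    path = os.path.join(tempfile.mkdtemp(), "chk.pkl")
    ckpt = TwoPhaseCheckpointer(path)
    state = {"iteration": 10, "weights": [1.0, 2.0]}
    ckpt.checkpoint(state)
    # Training continues and mutates the live state while the write proceeds.
    state["weights"][0] = 99.0
    state["iteration"] = 11
    ckpt.wait()
    with open(path, "rb") as f:
        return pickle.load(f)
```

Because the snapshot is taken before training resumes, the file written in the second phase reflects the state at the moment `checkpoint` was called, even though the live state changes afterwards.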

[pdf] [slides]

Setup

CheckFreq is implemented as an extensible module for PyTorch. To run CheckFreq, you will need an NVIDIA GPU with CUDA 10.0, nvidia-docker2, and Python 3. We used the prebuilt NVIDIA docker container nvcr.io/nvidia/pytorch:19.05-py3 as the base image, which can be downloaded using:

docker pull nvcr.io/nvidia/pytorch:19.05-py3

CheckFreq's resumable data iterator is built as an extension to the state-of-the-art data loader CoorDL, which is built on top of NVIDIA DALI. To build a docker container based on the above base image with CheckFreq's resumable iterator, apply the patch to the master branch of the CoorDL repo using:

git apply resumable_iterator.patch

Then build the docker image with CheckFreq's iterator by following the instructions in its repo.

The final docker image is tagged nvidia/dali:py36_cu10.run and can be run using

    nvidia-docker run --ipc=host --mount src=/,target=/datadrive/,type=bind -it --rm --network=host --privileged nvidia/dali:py36_cu10.run

Using CheckFreq

CheckFreq can be used in the training script with a few changes.

  1. Import the CheckFreq checkpoint wrapper, manager, and iterator in the training script

     from cf_checkpoint import CFCheckpoint
     from cf_manager import CFManager, CFMode
     from cf_iterator import CFIterator
    
  2. Initialize a checkpoint wrapper that tracks state to be checkpointed.

     chk = CFCheckpoint(model=model, optimizer=optimizer)
    

We assume that each of the parameters to be tracked exposes a state_dict that is snapshotted during the checkpoint operation. Then create a CheckFreq manager by specifying the frequency estimation mode (MANUAL/AUTO), the checkpoint wrapper, and the path to store final checkpoints.

     cf_manager = CFManager(chk_prefix, chk, mode=CFMode.AUTO)

  3. Pass in the epoch and batch ID to resume the data loader (taken from the previous checkpoint if training is resumed, or 0 if starting from scratch)

     self.input = ops.FileReader(..., resume_index=resume_index, resume_epoch=resume_epoch, cf_det=cf_det)
    
  4. Wrap the DALIClassificationIterator in CFIterator and optionally pass in arguments for adaptive rate tuning (dynamic=True), or a checkpointing frequency if MANUAL mode is set (chk_freq=N)

     train_loader = DALIClassificationIterator(...)
     train_loader = CFIterator(train_loader, ...)
    
  5. On the main process (local rank 0), use the CheckFreq wrapper for optimizer.step

     cf_manager.weight_update()
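
The resume step above relies on the data loader being able to restart from a given epoch and batch. The following is a minimal, self-contained sketch of that idea in plain Python. It is not the DALI/CoorDL iterator or CFIterator; the class `ResumableLoader`, its fields, and the per-epoch seeded shuffle are assumptions chosen for illustration.

```python
import random


class ResumableLoader:
    """Illustrative only: an epoch-aware loader whose position can be saved."""

    def __init__(self, num_items, batch_size, seed=0):
        self.num_items = num_items
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0
        self.batch_index = 0  # batches already handed out in this epoch

    def _order(self):
        # Shuffle deterministically per epoch so a resumed run sees the
        # same item order as the interrupted one.
        order = list(range(self.num_items))
        random.Random(self.seed + self.epoch).shuffle(order)
        return order

    def __iter__(self):
        order = self._order()
        start = self.batch_index * self.batch_size
        for i in range(start, self.num_items, self.batch_size):
            # Advance the counter before yielding: state saved after the
            # consumer has processed this batch points at the next one.
            self.batch_index += 1
            yield order[i:i + self.batch_size]
        # Epoch finished: move on and reset the position.
        self.epoch += 1
        self.batch_index = 0

    def state_dict(self):
        return {"epoch": self.epoch, "batch_index": self.batch_index}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self.batch_index = state["batch_index"]
```

Saving `state_dict()` together with the model checkpoint and calling `load_state_dict()` on restart lets the resumed run skip the batches already consumed, so each item is still used exactly once in the interrupted epoch.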
    

A complete working example with changes to integrate CheckFreq in the training script for image classification is here

Example

We demonstrate an example of running CheckFreq for image classification using popular models like ResNets, VGGs, and Inception on the ImageNet ILSVRC 2012 dataset.

The source code for the training script with CheckFreq integration is here

To train ResNet18 across 8 GPUs on a server for 2 epochs with CheckFreq, use the following command:

  python -m torch.distributed.launch --nproc_per_node=8 models/image_classification/pytorch-imagenet-cf.py --dali -a resnet18 -b 256 --workers 3 --epochs 2  --deterministic --noeval --barrier --checkfreq --chk-prefix ./chk/ --cf_iterator --data <imagenet_data_directory> > stdout.out 2>&1

To run the same without CheckFreq, using epoch boundary checkpointing, use:

  python -m torch.distributed.launch --nproc_per_node=8 models/image_classification/pytorch-imagenet-cf.py --dali -a resnet18 -b 256  --workers 3 --epochs 2  --deterministic --noeval --barrier --checkfreq --chk-freq 0 --chk_mode_baseline --chk-prefix ./chk/ --cf_iterator --data $DATA_DIR > stdout.out 2>&1

A complete script to train different models with and without CheckFreq is here. You can run it using:

   ./run_all_256.sh <data-dir> <out-dir> <data_threads_per_GPU>

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ.

License

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT license.
