LowMemoryBP

This is the official repository of our paper "Reducing Fine-Tuning Memory Overhead by Approximate and Memory-Sharing Backpropagation" [paper]. The Approx-BP and MS-BP techniques from our paper significantly reduce GPU activation memory usage during the fine-tuning stage, without slowing down training throughput.

Fig 1: The effect of our method on activation memory usage. The separated parts are the reduced activation memory.

Fig 2: Comparison of our method with gradient checkpointing and Mesa on peak GPU memory usage.

So far, we have applied these techniques to two activation layers (GELU, SiLU) and two normalization layers (LayerNorm, RMSNorm). The derived low-memory counterparts are packaged in the torch-based lomem package.

Installation Guide for lomem

Because lomem is developed on top of PyTorch and CUDA, you need a Python 3 environment with a CUDA-enabled PyTorch installed, plus a CUDA toolkit whose version matches the one used to compile PyTorch. You can check the PyTorch version with pip show torch and the CUDA version with nvcc -V.

In our tests, we used torch 2.3.1+cu118 and CUDA 11.8 on Ubuntu 20.04.6 LTS with gcc 9.4.0.

To install:

  1. Make sure that packaging is installed (pip install packaging)
  2. Make sure that ninja is installed (pip install ninja)
  3. Clone this repository and enter its directory (cd LowMemoryBP)
  4. Compile from source: pip install -e .
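
If the build succeeds, a quick smoke test confirms that the extension loads and that forward and backward passes run. This is a minimal sketch of our own; it assumes a visible CUDA device:

import torch
import lomem

# Run one forward/backward pass through a lomem layer on the GPU.
x = torch.randn(4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
y = lomem.nn.ReGELU2()(x)
y.sum().backward()
print(y.shape, x.grad.shape)  # torch.Size([4, 8]) torch.Size([4, 8])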

Usage and Features

lomem supports CUDA tensors with fp32, fp16, and bf16 dtypes.

First, import torch and lomem:

import torch
import lomem

Activation Layers

To use ReGELU2 and ReSiLU2, just replace torch.nn.GELU and torch.nn.SiLU with lomem.nn.ReGELU2 and lomem.nn.ReSiLU2, for example:

# act_layer_1 = torch.nn.GELU()
act_layer_1 = lomem.nn.ReGELU2()

# act_layer_2 = torch.nn.SiLU()
act_layer_2 = lomem.nn.ReSiLU2()

ReGELU2 and ReSiLU2 use approximate derivatives, so they only need to store 2-bit activation memory. The search programs used to derive ReGELU2 and ReSiLU2 are located in a subdirectory of this repository. The technique can be viewed as a 2-bit quantization of the derivative function while keeping the original forward function unchanged. As a result, they consume roughly $\times$ 8 less activation memory (for 16-bit dtypes) than GELU and SiLU.
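
Because only the derivative is approximated, ReGELU2's forward output should match GELU's, and in a Linear → activation → Linear block the activation no longer needs to keep the full 16-bit intermediate alive for backward. The measurement below is an illustrative sketch of our own (the layer sizes and measurement approach are assumptions, not part of lomem):

import torch
import lomem

# Forward outputs should agree, since only the backward pass is approximated.
x = torch.randn(1024, device="cuda", dtype=torch.float16)
print(torch.allclose(lomem.nn.ReGELU2()(x), torch.nn.GELU()(x), atol=1e-3))

def held_for_backward(act):
    # Bytes kept alive for the backward pass by a Linear -> act -> Linear block.
    lin1 = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
    lin2 = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
    inp = torch.randn(256, 4096, device="cuda", dtype=torch.float16)
    torch.cuda.synchronize()
    base = torch.cuda.memory_allocated()
    loss = lin2(act(lin1(inp))).sum()  # saved activations stay allocated here
    torch.cuda.synchronize()
    held = torch.cuda.memory_allocated() - base
    loss.backward()  # release the graph
    return held

print("GELU   :", held_for_backward(torch.nn.GELU()))     # ~2 fp16 intermediates
print("ReGELU2:", held_for_backward(lomem.nn.ReGELU2()))  # ~1 intermediate + 2-bit mask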

Normalization Layers

To use MS-LayerNorm and MS-RMSNorm, you should replace torch.nn.LayerNorm and RMSNorm with lomem.nn.MSLayerNorm and lomem.nn.MSRMSNorm, for example:

# norm_layer_1 = torch.nn.LayerNorm(2048, eps=1e-8, elementwise_affine=False)
norm_layer_1 = lomem.nn.MSLayerNorm(2048, eps=1e-8)

# norm_layer_2 = RMSNorm(2048, eps=1e-8, elementwise_affine=False)
norm_layer_2 = lomem.nn.MSRMSNorm(2048, eps=1e-8)

MS-LayerNorm and MS-RMSNorm use redesigned formulas that store their outputs in activation memory, thereby sharing that memory with the following layers. As a result, they can be viewed as activation-memory-free layers.

Note:

  1. MSLayerNorm and MSRMSNorm contain no affine parameters. When loading pretrained weights into a model that uses MSLayerNorm or MSRMSNorm, you should manually merge the affine parameters of LayerNorm or RMSNorm into the following linear layers, and possibly adjust the computation to preserve mathematical consistency (please refer to this code); a sketch of the standard folding follows after these notes.

  2. For unknown reasons, torch.nn.functional.linear (and hence torch.nn.Linear) cannot share activation memory with MSLayerNorm or MSRMSNorm. To work around this, use our custom implementation of the linear operation (please refer to this code).
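
For note 1, the standard folding of an affine LayerNorm into the next Linear follows from linear(γ ⊙ x̂ + β) = (W diag(γ)) x̂ + (Wβ + b). The helper below is a hypothetical sketch covering only this textbook case (affine LayerNorm, a single Linear consumer, Linear with bias); MSLayerNorm's redesigned formula may require further adjustment, for which the referenced code is the authoritative example:

import torch

@torch.no_grad()
def fold_layernorm_affine(ln, linear):
    # Hypothetical helper: fold LayerNorm's (gamma, beta) into the following
    # Linear, so the LayerNorm can be replaced by an affine-free counterpart.
    # linear(gamma * x_hat + beta) == (W * gamma) @ x_hat + (W @ beta + b)
    linear.bias.add_(linear.weight @ ln.bias)  # b' = b + W @ beta (original W)
    linear.weight.mul_(ln.weight)              # W' = W * diag(gamma), scales input columns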

Bool Packing

Apart from the low-memory activation and normalization layers, lomem provides utility functions that pack a bool tensor into a uint8 tensor (8 bools per byte, since PyTorch stores each bool in a full byte) and restore it.

shape = (11, 45, 77)
x = torch.rand(shape, device="cuda") > 0.5           # bool  (11, 45, 77)
z = lomem.functional.pack_bool_to_uint8(x)           # uint8 (4765)
y = lomem.functional.unpack_uint8_to_bool(z, shape)  # bool  (11, 45, 77)
print((x == y).all())  # tensor(True, device='cuda:0')

Experiments

Please see our experiments subdirectory, which contains:

  1. ViT experiments
  2. LLaMA experiments (Coming soon!)
  3. RoBERTa experiments (Coming soon!)

Citation

@inproceedings{yangreducing,
  title={Reducing Fine-Tuning Memory Overhead by Approximate and Memory-Sharing Backpropagation},
  author={Yang, Yuchen and Shi, Yingdong and Wang, Cheems and Zhen, Xiantong and Shi, Yuxuan and Xu, Jun},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}
