nola's Introduction

NOLA: Compressing LoRA using Linear Combination of Random Basis

Paper: NOLA: Compressing LoRA using Linear Combination of Random Basis (ICLR 2024)

Soroush Abbasi Koohpayegani*, Navaneet K L*, Parsa Nooralinejad, Soheil Kolouri, Hamed Pirsiavash

This repository is the official implementation of NOLA (arXiv, ICLR 2024).

Our code is based on LoRA and QLoRA.


NOLA is a novel approach for fine-tuning large models such as LLMs and Vision Transformers. Similar to LoRA, NOLA uses a low-rank decomposition of weight matrices for the fine-tuning step. However, LoRA faces two primary limitations:

  • The parameter count is lower-bounded by the rank-one decomposition.
  • The extent of reduction is heavily influenced by both the model architecture and the chosen rank.

We introduce NOLA, which brings parameter-count flexibility to LoRA. NOLA achieves this by re-parameterizing the low-rank matrices in LoRA as linear combinations of randomly generated matrices (the basis) and optimizing only the linear mixture coefficients. This approach allows us to decouple the number of trainable parameters from both the choice of rank and the network architecture.
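
To make the decoupling concrete, the following sketch counts trainable parameters for a single adapted weight matrix. The layer size, rank, and basis counts are illustrative assumptions rather than settings from the paper, and `lora_params` and `nola_params` are hypothetical helper names.

```python
def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Trainable parameters of a LoRA pair A (rank x in) and B (out x rank)."""
    return rank * (in_features + out_features)


def nola_params(num_basis_a: int, num_basis_b: int) -> int:
    """Trainable parameters of NOLA: one mixture coefficient per basis matrix."""
    return num_basis_a + num_basis_b


# Illustrative (assumed) sizes for a single square projection layer.
in_features, out_features = 4096, 4096

# LoRA cannot go below its rank-one count for this layer shape.
print(lora_params(in_features, out_features, rank=1))   # 8192
print(lora_params(in_features, out_features, rank=8))   # 65536

# NOLA's count depends only on the number of basis matrices,
# even when the underlying factors keep rank 8.
print(nola_params(num_basis_a=1024, num_basis_b=1024))  # 2048
```

In this toy setting the NOLA count sits below the rank-one LoRA bound, and it would stay the same if the layer were wider or the rank were larger.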

For example, on LLaMA-2 70B, NOLA is almost 20 times more compact than the most compressed LoRA without degradation in accuracy. Remarkably, we are able to fine-tune LLaMA-2 70B with only 0.6M parameters.


Simple Code of NOLA

Please use GenerateParams to generate parameters on the fly as linear combinations of randomly generated matrices (the basis).

It's important to note that the random basis remains fixed during optimization, and we only need to compute gradients with respect to coefficients. As the random basis is static, we can discard it after computing the output of the forward function. This technique significantly reduces the GPU's memory footprint since there's no need to retain the basis. For the backward pass, we must again use the same seed as in the forward pass to compute the basis and determine the gradient for the coefficients. Although there is a slight overhead in calculating the dot product between the basis and coefficients, this overhead is negligible when the number and dimensions of the basis are small. For a more detailed analysis, please refer to our paper.
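
The regenerate-from-seed idea can be sketched without PyTorch. The example below is a minimal illustration, not the repository's implementation: it uses Python's `random` module, a toy quadratic loss, and hypothetical helpers (`make_basis`, `forward`, `backward`, `loss`). The basis is rebuilt from the seed in both passes and never stored, and the gradient for each coefficient is the inner product of the weight gradient with the matching basis matrix.

```python
import random


def make_basis(seed, num_basis, out_dim, in_dim):
    """Rebuild the random basis from its seed; entries are uniform in [-1, 1]."""
    rng = random.Random(seed)
    return [[[rng.uniform(-1.0, 1.0) for _ in range(in_dim)]
             for _ in range(out_dim)]
            for _ in range(num_basis)]


def forward(coefficients, seed, out_dim, in_dim):
    """W[o][i] = sum_b coefficients[b] * basis[b][o][i]; the basis is not kept."""
    basis = make_basis(seed, len(coefficients), out_dim, in_dim)
    return [[sum(c * basis[b][o][i] for b, c in enumerate(coefficients))
             for i in range(in_dim)]
            for o in range(out_dim)]


def backward(grad_w, seed, num_basis, out_dim, in_dim):
    """dL/dc_b = sum over (o, i) of grad_w[o][i] * basis[b][o][i]."""
    basis = make_basis(seed, num_basis, out_dim, in_dim)  # same seed, same basis
    return [sum(grad_w[o][i] * basis[b][o][i]
                for o in range(out_dim) for i in range(in_dim))
            for b in range(num_basis)]


def loss(coefficients, seed, out_dim, in_dim):
    """Toy loss 0.5 * ||W||^2, whose gradient with respect to W is W itself."""
    w = forward(coefficients, seed, out_dim, in_dim)
    return 0.5 * sum(v * v for row in w for v in row)


seed, out_dim, in_dim = 1234, 3, 4
coeffs = [0.5, -1.0, 0.25]  # made-up coefficient values

w = forward(coeffs, seed, out_dim, in_dim)
analytic = backward(w, seed, len(coeffs), out_dim, in_dim)

# Central finite differences as an independent check of the analytic gradient.
eps = 1e-6
numeric = []
for b in range(len(coeffs)):
    plus = coeffs[:b] + [coeffs[b] + eps] + coeffs[b + 1:]
    minus = coeffs[:b] + [coeffs[b] - eps] + coeffs[b + 1:]
    numeric.append((loss(plus, seed, out_dim, in_dim)
                    - loss(minus, seed, out_dim, in_dim)) / (2 * eps))

print(all(abs(a - n) < 1e-5 for a, n in zip(analytic, numeric)))
```

Only the seed and the coefficients persist between the two passes, which is the source of the memory saving described above. The cost is generating the basis twice.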

import math

import torch
import torch.nn as nn


class GenerateParams(torch.autograd.Function):
    # Generate parameters on the fly as a linear combination of a random basis

    @staticmethod
    def forward(ctx, coefficients, out_dim, in_dim, seed):
        num_basis = coefficients.shape[0]
        # Draw a fresh seed first so the global RNG can be re-randomized afterwards
        rand_seed = torch.randint(int(1e10), (1,))
        # Fix the seed so the same basis can be regenerated in backward
        torch.manual_seed(int(seed))
        W = torch.zeros(num_basis, out_dim, in_dim,
                        device=coefficients.device, dtype=coefficients.dtype)
        nn.init.uniform_(W, a=-1.0, b=1.0)
        Out = torch.einsum('b,boi->oi', coefficients, W)
        # Save only the coefficients and the shape/seed, not the basis itself
        params = torch.tensor([out_dim, in_dim, int(seed)])
        ctx.save_for_backward(coefficients, params)
        torch.manual_seed(int(rand_seed))
        return Out

    @staticmethod
    def backward(ctx, grad_output):
        coefficients, params = ctx.saved_tensors
        num_basis = coefficients.shape[0]
        out_dim, in_dim, seed = params.tolist()
        rand_seed = torch.randint(int(1e10), (1,))
        # Regenerate the same basis as in forward
        torch.manual_seed(seed)
        W = torch.zeros(num_basis, out_dim, in_dim,
                        device=coefficients.device, dtype=coefficients.dtype)
        nn.init.uniform_(W, a=-1.0, b=1.0)
        W = W.permute(1, 2, 0).reshape(-1, num_basis)
        # Gradient for each coefficient: inner product of grad_output with its basis
        grad_coefficients = torch.einsum('d,dl->l', grad_output.flatten(), W)
        torch.manual_seed(int(rand_seed))
        # No gradients for out_dim, in_dim, and seed
        return grad_coefficients, None, None, None

The example below shows how to use this re-parameterization trick on a linear layer and on LoRA:

class NOLALinear(nn.Linear):
    # NOLA Linear Layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_basis = 256,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        self.num_basis = num_basis
        self.generate_params = GenerateParams.apply
        # Fixed seed that identifies this layer's random basis
        self.seed = nn.Parameter(torch.randint(int(1e10), (1,)), requires_grad=False)
        # Only the mixture coefficients are trained; the base weight is frozen
        self.coefficients = nn.Parameter(torch.zeros(num_basis), requires_grad=True)
        self.weight.requires_grad = False

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, num_basis={}'.format(
            self.in_features, self.out_features, self.num_basis)

    def forward(self, x: torch.Tensor):
        W = self.generate_params(self.coefficients,
                                 self.out_features,
                                 self.in_features,
                                 self.seed) + self.weight
        return x @ W.t()

class LoRA(nn.Module):
    # LoRA with NOLA, implementation of a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        use_nola = False,
        num_basis = 1024,
        alpha = 1.0,
        **kwargs
    ):
        super(LoRA, self).__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.num_basis_A = num_basis
        self.num_basis_B = num_basis
        self.rank = rank
        self.alpha = alpha
        self.scale = self.alpha / self.rank
        self.generate_params = GenerateParams.apply
        self.use_nola = use_nola
        if use_nola:
            self.lora_A = NOLALinear(in_features, rank, num_basis=self.num_basis_A, bias=False)
            self.lora_B = NOLALinear(rank, out_features, num_basis=self.num_basis_B, bias=False)
        else:
            self.lora_A = nn.Linear(in_features, rank, bias=False)
            self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.reset_lora_parameters()

    def extra_repr(self) -> str:
        return 'NOLA: rank={}, in_features={}, out_features={}, num_basis_A={}, num_basis_B={}'.format(
            self.rank, self.in_features, self.out_features, self.num_basis_A, self.num_basis_B)

    def reset_lora_parameters(self):
        if self.use_nola:
            # start from zero mixture coefficients
            nn.init.zeros_(self.lora_A.coefficients)
            nn.init.zeros_(self.lora_B.coefficients)
        # initialize A the same way as the default for nn.Linear and B to zero
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor):
        return self.lora_B(self.lora_A(x)) * self.scale


All our experiments use the PyTorch library. Instructions for installing PyTorch can be found on the official PyTorch website. We primarily use GPT-2 and LLaMA-2 for our experiments on natural language generation tasks.


We use the E2E NLG, DART, and WebNLG datasets for our experiments on natural language generation with GPT-2. Moreover, we use the Alpaca dataset for instruction fine-tuning with LLaMA-2.

Getting Started -- Large Language Model Finetuning (LLaMA)

We modify the QLoRA code and add NOLA to the LoRA model.

Getting Started -- Large Language Model Finetuning (GPT-2)

Please install dependencies in a virtual environment:

pip install -r requirement.txt
cd ./eval
cd ..

Finetuning with NOLA (float)

  1. Finetune GPT-2 with NOLA:
python -m torch.distributed.launch --nproc_per_node=1 src/ \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.1 \
    --weight_decay 0.0 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 500 \
    --max_epoch 5 \
    --save_interval 1000 \
    --nola_rank 8 \
    --nola_c 1.0 \
    --nola_qv \
    --label_smooth 0.1 \
    --work_dir ./trained_models/GPT2_M/e2e \
    --random_seed 110
  2. Use beam search to generate outputs from the finetuned model:
python -m torch.distributed.launch --nproc_per_node=1 src/ \
    --data ./data/e2e/test.jsonl \
    --batch_size 1 \
    --seq_len 512 \
    --eval_len 64 \
    --model_card \
    --init_checkpoint ./trained_models/GPT2_M/e2e/ \
    --platform local \
    --nola_rank 8 \
    --nola_c 1.0 \
    --nola_qv \
    --beam 10 \
    --length_penalty 0.8 \
    --no_repeat_ngram_size 4 \
    --repetition_penalty 1.0 \
    --eos_token_id 628 \
    --work_dir ./trained_models/GPT2_M/e2e \
    --output_file predict.26290.b10p08r4.jsonl
  3. Decode outputs:
python src/ \
    --vocab ./vocab \
    --sample_file ./trained_models/GPT2_M/e2e/predict.26290.b10p08r4.jsonl \
    --input_file ./data/e2e/test_formatted.jsonl \
    --output_ref_file e2e_ref.txt \
    --output_pred_file e2e_pred.txt
  4. Run evaluation on the E2E test set:
python eval/e2e/ e2e_ref.txt e2e_pred.txt -p

Finetuning with NOLA (Quantized)

  1. Finetune GPT-2 with NOLA (add the --qnola flag and set the number of bits per parameter, here --qbits 3):
python -m torch.distributed.launch --nproc_per_node=1 src/ \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.1 \
    --weight_decay 0.0 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 500 \
    --max_epoch 5 \
    --save_interval 1000 \
    --nola_rank 8 \
    --nola_c 0.01 \
    --qnola \
    --qbits 3 \
    --nola_qv \
    --label_smooth 0.1 \
    --work_dir ./trained_models/GPT2_M/e2e \
    --random_seed 110
  2. Use beam search to generate outputs from the finetuned model:
python -m torch.distributed.launch --nproc_per_node=1 src/ \
    --data ./data/e2e/test.jsonl \
    --batch_size 1 \
    --seq_len 512 \
    --eval_len 64 \
    --model_card \
    --init_checkpoint ./trained_models/GPT2_M/e2e/ \
    --platform local \
    --nola_rank 8 \
    --nola_c 0.01 \
    --nola_qv \
    --qnola \
    --qbits 3 \
    --beam 10 \
    --length_penalty 0.8 \
    --no_repeat_ngram_size 4 \
    --repetition_penalty 1.0 \
    --eos_token_id 628 \
    --work_dir ./trained_models/GPT2_M/e2e \
    --output_file predict.26290.b10p08r4.jsonl
  3. Decode outputs:
python src/ \
    --vocab ./vocab \
    --sample_file ./trained_models/GPT2_M/e2e/predict.26290.b10p08r4.jsonl \
    --input_file ./data/e2e/test_formatted.jsonl \
    --output_ref_file e2e_ref.txt \
    --output_pred_file e2e_pred.txt
  4. Run evaluation on the E2E test set:
python eval/e2e/ e2e_ref.txt e2e_pred.txt -p
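
The --qbits flag above sets the number of bits per parameter. This README does not show the repository's quantizer, so the uniform min-max scheme below is an assumption, used only to illustrate what 3 bits imply: each coefficient is stored as one of 2^3 = 8 levels. The helper names and coefficient values are hypothetical.

```python
def quantize_uniform(values, num_bits):
    """Map each value to one of 2**num_bits evenly spaced levels between min and max."""
    lo, hi = min(values), max(values)
    levels = 2 ** num_bits - 1  # highest integer code
    if hi == lo:
        return [0] * len(values), lo, 0.0
    step = (hi - lo) / levels
    codes = [round((v - lo) / step) for v in values]
    return codes, lo, step


def dequantize_uniform(codes, lo, step):
    """Recover approximate values from the integer codes."""
    return [lo + c * step for c in codes]


coefficients = [-0.70, -0.10, 0.00, 0.15, 0.40, 0.70]  # made-up values
codes, lo, step = quantize_uniform(coefficients, num_bits=3)
restored = dequantize_uniform(codes, lo, step)

# Every code fits in 3 bits, and the round-trip error is at most half a step.
print(codes)
print(max(abs(a - b) for a, b in zip(coefficients, restored)))
```

Under this assumed scheme, storage drops from 32 bits to 3 bits per coefficient, plus two floats (`lo` and `step`) per quantized group.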


If you make use of the code, please cite the following work:

 author = { Koohpayegani, Soroush Abbasi and Navaneet, K L and Nooralinejad, Parsa and Kolouri, Soheil and Pirsiavash, Hamed},
 title = {NOLA: Networks as Linear Combination of Low Rank Random Basis},
 year = {2023}


This project is under the MIT license.

nola's Issues

Can not reproduce the results of "Finetuning with NOLA (float)"

Hey @klnavaneet, thank you for your nice work! I followed the instructions in "Finetuning with NOLA (float)" and ran the code. The final results are:

  • BLEU: 0.6977
  • NIST: 8.7974
  • METEOR: 0.4656
  • ROUGE_L: 0.7148
  • CIDEr: 2.4907

These are a bit lower than those reported in your paper.

I'd greatly appreciate it if you could share the latest version of the code. I'm wondering if I made any mistakes while editing it. (There seem to be some bugs in the current version, and I attempted to correct them. Maybe I messed up here. 😂)
