
1.58 Bit Llama Model

Initial implementation of 1.58-bit Llama Model following the reference paper: https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf

The paper outlines the code changes needed to make 1.58-bit ternary quantized training work. This repo, together with my transformers fork, implements those changes. The main changes are in the transformers fork, where I have added the BitLlamaModel and BitLlamaForCausalLM classes.

Training code changes

Training hyperparams

See my transformers fork, the uploaded initialized model, or configuration_bitllama.py and modeling_bitllama.py for the modeling implementation.
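
For reference, here is a minimal sketch of instantiating a small BitLlama model from the fork. The import path and the BitLlamaConfig name follow the usual transformers layout but are assumptions about the fork (BitLlamaForCausalLM is the class named above); adjust them to the actual module names.

from transformers.models.bitllama import BitLlamaConfig, BitLlamaForCausalLM  # import path assumed

# Toy configuration; field names follow the standard LlamaConfig.
config = BitLlamaConfig(
    hidden_size=512,
    intermediate_size=1376,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    vocab_size=32000,
)
model = BitLlamaForCausalLM(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")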

I've also included a basic axolotl toy pretraining config (pretraining_bitllama.yaml) based on a small model that I initialized and uploaded to HF together with the modeling code. You can use this to test training. Currently, training is not working as expected: I'm getting NaN grads (see the screenshot in the repo).
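
As a debugging aid (not part of the repo), the sketch below shows one way to locate which parameters first produce non-finite gradients. It only assumes a standard PyTorch training loop with a model and a computed loss.

import torch

def report_nonfinite_grads(model):
    # Call right after loss.backward() to list parameters with NaN/inf gradients.
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(f"non-finite grad in {name}")

# loss.backward()
# report_nonfinite_grads(model)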

Installation

git clone -b add_bitllama git@github.com:bjoernpl/transformers.git
cd transformers
pip install -e .

For optional training with axolotl, you can install the following:

git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl

pip3 install packaging
pip3 install -e '.[flash-attn]'

Changes to code:

The relevant changes consist of adding the following quantization functions and a BitLinear layer to the modeling code:

import torch.nn as nn
import torch.nn.functional as F

from transformers.models.llama.modeling_llama import LlamaRMSNorm


def activation_quant(x):
    # Per-token absmax quantization of activations to 8 bits ([-128, 127]),
    # rescaled back to the original range.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y


def weight_quant(w):
    # Per-tensor ternary quantization of weights to {-1, 0, +1},
    # scaled by the mean absolute weight value.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u


class BitLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.norm = LlamaRMSNorm(in_features)

    def forward(self, x):
        w = self.weight
        x_norm = self.norm(x)
        # Straight-through estimator: the forward pass uses the quantized values,
        # while gradients flow through as if no quantization had happened.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x_quant, w_quant)

As the paper describes, all nn.Linear projections in the attention layers are also replaced with BitLinear; see the sketch below.
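
Concretely, this amounts to swapping the projection layers in the attention module's constructor. The sketch below mirrors the constructor signatures of the standard transformers LlamaAttention and is an illustration of the substitution, not the exact diff:

# Inside the (Bit)LlamaAttention constructor, nn.Linear projections become BitLinear.
self.q_proj = BitLinear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = BitLinear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)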

Notes

This is very much a work in progress and experimental. Contributions and ideas towards fixing the NaN grads are welcome. It also does not include the custom kernels mentioned (but not described) in the paper.


bitllama's Issues

o_proj initialization

Hi, nice to see some work in this area.
I just checked the diff to see why you are getting NaNs in training.

In LlamaAttention you are using this:

self.o_proj = BitLinear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

while the transformers original is using:

self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)

Is there a specific reason for using num_heads * head_dim instead of hidden_size?
