bLLaMa

bLLaMa is a 1.58-bit (BitNet b1.58-style) implementation of the LLaMa architecture.

Setup

Both the model configuration dataclass and the model itself are contained in the bllama module. By default, the configuration is a 1.7B model, which can be found in config.py.

from bllama import bLlamaConfig, bLlama

config = bLlamaConfig()
bllm = bLlama(config)
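
If you want a smaller model for quick experiments, you can override the defaults when building the config. The snippet below is only a sketch: the field names (n_layers, n_heads, hidden_size) are assumptions, so check the dataclass in config.py for the actual attribute names.

from bllama import bLlamaConfig, bLlama

# Hypothetical field names -- see config.py for the real dataclass attributes.
small_config = bLlamaConfig(
    n_layers=12,
    n_heads=12,
    hidden_size=768,
)
small_bllm = bLlama(small_config)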

Training

bLLaMa is built as a Lightning module, so you may pass pl.Trainer and pl.LightningDataModule instances for training tasks. To facilitate this, some example datasets and their corresponding datamodules are provided in utils.

from transformers import LlamaTokenizer
from utils import ShakespeareDataset, TextDataModule, GepetoDataset

tokenizer = LlamaTokenizer.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")
dataset = ShakespeareDataset.from_file("/path/to/shakespeare/input.txt", tokenizer=tokenizer, max_length=1024)
dataloader = TextDataModule(dataset, batch_size=config.batch_size, train_test_split=0.9)

To set up training, you may pass a pl.Trainer or manually set up a training run.

import pytorch_lightning as pl

bllm_trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=1,
)

bllm_trainer.fit(bllm, dataloader)

Inference

The BitLinear layers of bLLaMa have two modes: one for training (fp32) and one for quantized inference (int8). To perform quantized inference, the weights have to be quantized offline. bLLaMa has a built-in method to quantize the BitLinear modules for inference:

[figure: bLLaMa quantization]
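
The exact name of that built-in method isn't shown here, so the call below is a hedged sketch; quantize() is a hypothetical name, and you should check the bLlama class for the actual quantization entry point.

# Hypothetical entry point -- the repo provides a built-in quantization method,
# but the exact name may differ; check the bLlama class.
bllm.quantize()   # offline-quantize all BitLinear modules to int8
bllm.eval()       # standard PyTorch eval mode before generation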

After quantization, the model can then generate text with the generate method.

bllm.generate(prompt="In a shocking turn of events,", tokenizer=tokenizer, max_len=200, do_sample=False, top_k=3, repetition_penalty=2)

Full-precision inference is also allowed, but the model will warn about any BitLinear layers that are not quantized.

TODOS

  • Inference with int8 BitLinear quantization
  • Custom GEMM for lowbit inference
  • KV Cache
  • Model sharding / Parallel training optimizations
  • Custom model saving for quantized tensors
  • Full 1.7B model training

Resources and inspiration

Notes on training

This repo contains only the implementation and training code for bLLaMa. No (relevant) model checkpoints or weights have been produced yet, as doing so requires significantly more compute than I currently have at my disposal.

Nonetheless, small training runs of the 1.7B model were done to assess training performance, using ~15M tokens from wikitext and minipile.

On a single NVIDIA A6000, the batch size was limited to 1 due to the VRAM bottleneck. This may indicate memory-usage issues and/or optimization opportunities, as the full fp32 model alone uses ~41GB for training.
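
As a rough sanity check on that number, here is a back-of-the-envelope estimate of fp32 training memory. It assumes a plain Adam optimizer with fp32 moment states and ignores activations and framework overhead; these assumptions are mine, not the repo's.

# Back-of-the-envelope fp32 training memory for ~1.7B parameters.
# Assumes Adam with fp32 moments; activations and framework overhead excluded.
params = 1.7e9
bytes_fp32 = 4

weights_gb   = params * bytes_fp32 / 1e9        # ~6.8 GB
gradients_gb = params * bytes_fp32 / 1e9        # ~6.8 GB
adam_gb      = 2 * params * bytes_fp32 / 1e9    # ~13.6 GB (first + second moments)

print(f"states only: ~{weights_gb + gradients_gb + adam_gb:.1f} GB")  # ~27.2 GB

Under those assumptions the parameter, gradient, and optimizer states alone account for roughly 27GB, so the remainder of the observed ~41GB would plausibly be activations and framework overhead, even at batch size 1.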

The total test training time was 5 hours. Extrapolating from this, with the current configuration it would take ~472 hours of training on an A6000 to reach a Chinchilla-optimal 1.7B model.
