
1.58 BitNet Ternary Model Implementation

This project focuses on creating and training a LLaMA model using ternary quantization techniques. The goal is to optimize the model's performance while reducing its memory footprint. This is my 1.58 BitNet implementation based on this paper: https://arxiv.org/abs/2402.17764

When you generate the model, it is blank (randomly initialized) and needs to be trained. This is where I'm having the biggest issues: I still can't get the training to work properly. I've taken this implementation as far as I can with my knowledge, and I'd appreciate help getting the training working.

I've been testing models of different parameter counts against the implementation, and this is what I've observed:

| Parameter Size | Model Size |
| --- | --- |
| 350M | 350 MB |
| 750M | 750 MB |
| 1B | 1 GB |
| 3B | 3 GB |
| 7B | 7 GB |
| 14B | - |
| 24B | - |
| 34B | - |
| 70B | - |
| 100B | - |
| 120B | - |
| 300B | - |

I was able to create models of these sizes on my 96GB M2 Max MacBook Pro. Just an FYI, these scripts are written specifically to run on MPS with a CPU fallback. I'm hoping to get help moving this to MLX once the fine-tuning/training issues are fixed.
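For reference, the MPS-with-CPU-fallback device selection boils down to something like the following sketch (the variable names are illustrative, not the exact code in this repo):

```python
import torch

# Prefer Apple's Metal Performance Shaders (MPS) backend when available,
# otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Models and batches created by the scripts would then be moved to this device,
# e.g. model.to(device) and batch.to(device).
```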

I also think this isn't completely optimized for memory management, and there's probably room for improvement there.

Table of Contents

  • Installation
  • Usage
  • Dataset
  • Model Architecture
  • Ternary Quantization
  • Custom Gradient Checkpointing
  • Training
  • Evaluation
  • Saving and Loading Models
  • Contributing

Installation

Git clone the repo.

Usage

  1. Run new-model-architecture-creation.py.
  2. You'll be prompted for how many parameters you want the model to have. The script will create the model and save it in the same directory as the repo files.
  3. Once the model is created, fine-tune it with the trainingv2.py script and the appropriate command-line arguments:
python trainingv2.py --dataset <dataset_path> --model_path <model_path> --batch_size <batch_size> --num_epochs <num_epochs> --learning_rate <learning_rate> --output_dir <output_directory> --iters <num_iterations> --max_length <max_sequence_length> --grad_accum_steps <gradient_accumulation_steps>
  • --dataset: Path to the dataset file.
  • --model_path: Path to the pre-trained LLaMA model.
  • --batch_size: Batch size for training.
  • --num_epochs: Number of training epochs.
  • --learning_rate: Learning rate for the optimizer.
  • --output_dir: Output directory where the fine-tuned model is saved.
  • --iters: Number of training iterations.
  • --max_length: Maximum sequence length for input tokens.
  • --grad_accum_steps: Number of steps for gradient accumulation.

Sample training command:

python trainingv2.py --dataset /Users/user/folder/Datasets/codeDataset/data/train.jsonl --batch_size 8 --num_epochs 5000 --output_dir /Users/user/Downloads/llama_750m_finetune_tritnet-v2 --iters 10000 --max_length 4096 --learning_rate 1e-4 --grad_accum_steps 10

Dataset

The dataset should be in one of the following formats: txt, json, or jsonl. The preprocess_dataset function in trainingv2.py handles preprocessing based on the file format. A jsonl file should contain one JSON object per line with a "text" field:

{"text": "This is an example for the model."}

For example: {"text": "<s>[INST] Create an array of length 5 which contains all even numbers between 1 and 10. [/INST]arr = [2, 4, 6, 8, 10]</s>"}

{"text": "<s>[INST] Formulate an equation to calculate the height of a triangle given the angle, side lengths and opposite side length. [/INST]Height of triangle = opposite side length * sin (angle) / side length</s>"}

{"text": "<s>[INST] Write a replace method for a string class which replaces the given string with a given set of characters.string = \"Hello World!\"\nreplace_with = \"Greetings!\" [/INST]def replace(self, replace_with):\n new_string = \"\"\n for char in self:\n if char == \" \":\n new_string += replace_with\n else:\n new_string += char\n return new_string</s>"}

{"text": "<s>[INST] Create an array of length 15 containing numbers divisible by 3 up to 45. [/INST]arr = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45]</s>"}

Model Architecture

The LLaMA model architecture is defined in llama_model.py. It consists of an embedding layer, multiple decoder layers, and a language model head. The model uses RMSNorm for normalization and applies rotary position embeddings to the attention mechanism.
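For reference, RMSNorm itself is a small module; the sketch below shows the standard formulation and is not necessarily identical to the class in llama_model.py:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square norm: scales x by 1/sqrt(mean(x^2) + eps) and a learned weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the last (hidden) dimension, then apply the learned scale.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * rms
```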

Ternary Quantization

Ternary quantization is applied to the model's weights to reduce memory usage. The QuantizedEmbedding and BitLinear classes in llama_model.py handle the quantization of the embedding layer and linear layers, respectively. The quantize_tensor function in quantization_utils.py performs the actual quantization.
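The BitNet b1.58 paper quantizes weights with an absmean scheme: divide by the mean absolute value, round, and clamp to {-1, 0, +1}. A minimal sketch of that idea follows; the function name and signature are illustrative and may differ from quantize_tensor in quantization_utils.py:

```python
import torch

def ternary_quantize(w: torch.Tensor, eps: float = 1e-5):
    """Absmean ternary quantization (BitNet b1.58): w ~= scale * w_q, w_q in {-1, 0, +1}."""
    scale = w.abs().mean().clamp(min=eps)    # gamma = mean(|W|)
    w_q = (w / scale).round().clamp_(-1, 1)  # RoundClip(W / gamma, -1, 1)
    return w_q, scale
```

During training, a straight-through estimator is typically used so gradients still flow to the full-precision latent weights.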

Custom Gradient Checkpointing

To reduce memory consumption during training, custom gradient checkpointing is implemented in custom_gradient_checkpointing.py. The custom_checkpoint function is used to checkpoint the forward pass and compute gradients during the backward pass.
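For context, this is the same idea that torch.utils.checkpoint provides: a block's forward activations are discarded and recomputed during the backward pass, trading compute for memory. The repo's custom_checkpoint is its own variant, so the snippet below is only a generic illustration of the technique:

```python
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(decoder_layers, hidden_states):
    # Each layer's activations are recomputed during backward instead of being
    # stored, which lowers peak memory at the cost of extra forward compute.
    for layer in decoder_layers:
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    return hidden_states
```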

Training

The train function in trainingv2.py handles the training process. It iterates over the dataset in batches, computes the loss, and performs gradient accumulation. The model's parameters are updated using an optimizer.
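As a rough sketch of what such a loop with gradient accumulation looks like (illustrative only, assuming a forward pass that returns a loss; the actual train function in trainingv2.py may differ):

```python
def train_epoch(model, dataloader, optimizer, device, grad_accum_steps=1):
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        # Assumes the model returns a scalar loss when given labels.
        loss = model(input_ids, labels=labels).loss / grad_accum_steps
        loss.backward()
        # Update the weights only every grad_accum_steps mini-batches,
        # so the effective batch size is batch_size * grad_accum_steps.
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```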

Evaluation

The evaluate function in trainingv2.py evaluates the model on a validation set. It computes the average loss over the validation batches.
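A minimal version of such an evaluation loop might look like this (a sketch under the same assumptions as the training snippet above, not the exact evaluate function):

```python
import torch

@torch.no_grad()
def evaluate(model, dataloader, device):
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        total_loss += model(input_ids, labels=labels).loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)  # average validation loss
```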

Saving and Loading Models

The model generated from new-model-architecture-creation.py will be saved in the same directory where you ran the script from. When running trainingv2.py you will be prompted to enter the path of that model. The output directory you specify in the training command will be where the finetuned model is saved.
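In plain PyTorch, saving and reloading a fine-tuned checkpoint generally reduces to something like the following; the filenames here are hypothetical, and the repo's scripts may organize the output differently:

```python
import torch

# Save the fine-tuned weights to the chosen output directory.
torch.save(model.state_dict(), "llama_750m_finetune/model.pt")

# Later, rebuild a model with the same architecture and load the weights back.
model.load_state_dict(torch.load("llama_750m_finetune/model.pt", map_location="cpu"))
```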

Contributing

Contributions to this project are welcome. If you find any issues or have suggestions for improvements, please open an issue or submit a pull request.
