
mlx-tuning-fork's Introduction

mlx-tuning-fork

A very basic framework for parameterized Large Language Model (Q)LoRA fine-tuning with MLX. It uses mlx, mlx_lm, and OgbujiPT, and is based primarily on the excellent mlx-examples library, adding a minimal architecture for systematically running easily parameterized fine-tunes, hyperparameter sweeping, declarative prompt construction, an equivalent of HF's train-on-completions, and other capabilities.

Installation

Can be installed via:

$ pip install mlx-tuning-fork

Currently there is just a single Mistral prompt-format (-f/--prompt-format) module, but with mlx-lm and OgbujiPT you can do something similar for other models:

  • Llama
  • Mixtral
  • Qwen
  • [..]

Command-line options

You can get documentation of the command-line options for fine tuning via:

$ python -m mlx_tuning_fork.training --help
Usage: python -m mlx_tuning_fork.training [OPTIONS] CONFIG_FILE

Options:
  --verbose / --no-verbose
  --summary / --no-summary        Just summarize training data
  --loom-file TEXT                An OgbujiPT word loom file to use for prompt
                                  construction
  --loom-markers TEXT             Loom marker values
  -p, --prompt TEXT               Command-line prompt (overrides the prompt
                                   in the YAML configuration)
  -t, --temperature FLOAT         Prompt generation temperature
  -nt, --num-tokens INTEGER       Override number of tokens in config file
  --train-type [completion-only|self-supervised]
  -f, --prompt-format [mistral|chatml]
  -a, --adapter TEXT              Adapter to use instead of the one specified
                                  in the config file
  --wandb-project TEXT            Wandb project name
  --wandb-run TEXT                Wandb run name
  -rp, --repetition-penalty FLOAT
                                  The penalty factor for repeating tokens
                                  (none if not used)
  --repetition-context-size INTEGER
                                  The number of tokens to consider for
                                  repetition penalty
  -tp, --top-p FLOAT              Sampling top-p
  --build-prompt TEXT             Which word loom sections to use in building
                                  the prompt (space-separated list of sections)
  --help                          Show this message and exit.

The format of the prompts used to train the model is specified via the -f/--prompt-format option, which is currently one of mistral or chatml.
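
For example, a completion-only fine-tune using the Mistral prompt format might be launched as follows (the configuration path is illustrative):

$ python -m mlx_tuning_fork.training --train-type completion-only -f mistral /path/to/config.yaml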

Configuration

It uses mlx_lm's YAML config format, adding the following parameters and sections (a configuration sketch follows the list):

  • epochs (the number of full passes over the training data)
  • reporting_interval_proportion (the proportion of an epoch's iterations to wait between recordings of training loss - defaults to 0.01, i.e., 1%)
  • validation_interval_proportion (the same kind of proportion, for the interval between validations - defaults to 0.2, i.e., 20%)
  • validations_per_train_item (the ratio of validations per training record seen - defaults to 0.5, i.e., 1 validation per 2 training records)
  • adapter_save_interval_proportion (the same kind of proportion, for the interval between saves of the LoRA adapter - defaults to 0.1)
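
For illustration, here is a configuration sketch combining a few of mlx_lm's standard keys with the parameters above (the model name, data path, and values are placeholders, not recommendations):

model: "mlx-community/Mistral-7B-Instruct-v0.2-4bit"
data: "/path/to/data"
learning_rate: 2e-5
epochs: 2
reporting_interval_proportion: 0.01
validation_interval_proportion: 0.2
validations_per_train_item: 0.5
adapter_save_interval_proportion: 0.1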

Learning Rate Schedules

Learning rate schedules can be specified in the configuration file with a section such as the following (for cosine annealing):

learning_schedule:
  type: "cosine"
  max_lr: 2e-5 # upper bound for the learning rate
  cycle_length: -1 # -1 for the number of steps/iterations in 1 epoch, or a specific number otherwise (the LR is set to min_lr afterwards)

The following specifies cosine annealing with a proportional warmup:

learning_schedule:
  type: "cosine_w_warmup"
  start_lr: 1e-8 # learning rate used at the start of the warm-up, which ends at the top-level learning rate
  warmup_proportion: .1 # proportion of steps/iterations in 1 epoch to spend warming up
  min_lr: 1e-7
  cycle_length: -1

Otherwise, a constant learning rate (specified via the learning_rate top-level configuration variable) is used throughout.
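
For example, a fixed rate needs only the single top-level entry (the value is illustrative):

learning_rate: 2e-5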

Prompting

It also provides the ability to dispatch prompts to the model referenced in the config (in conjunction with any LoRA adapters specified). The -p/--prompt option can be used to provide a prompt, and the -t/--temperature, -rp/--repetition-penalty, --repetition-context-size, and -tp/--top-p options can be used to configure the evaluation of the prompt. There is also an additional colorize parameter (specified in the config) which, if true, will render the model's completion using a coloring scheme that captures the probability of each token, using mlx_lm's capability in this regard.
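
For example, a one-off generation with sampling options might look like the following (the prompt, sampling values, and configuration path are illustrative):

$ python -m mlx_tuning_fork.training -p "Summarize the patient's problems" \
         -t 0.7 -tp 0.9 -rp 1.1 -f mistral /path/to/config.yaml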

Declarative Prompts Construction

OgbujiPT's Word Loom can also be used for templated construction of prompts.

There are 3 command-line options for this:

--loom-file TEXT                An OgbujiPT word loom file to use for prompt
                                construction
--build-prompt TEXT             Which word loom sections to use in building
                                the prompt (space-separated list of sections)
--loom-markers TEXT             Loom marker values

The --loom-file option gives the location of the word loom file (a TOML file) to use for prompt construction.

The loom file provides a system prompt, context, and the user prompt. The system prompt and context are optional, but the user prompt is not.

The --build-prompt option is expected to be a single table header name or a space-separated list of them. If only one is provided, it is assumed to be the name of a table with a text key whose value will be used for the user prompt. If two values are provided, they should be quoted and separated by spaces; the first refers to a table that provides the system prompt and the second refers to the user prompt. Finally, if three values are provided, they are assumed to be the system prompt, context, and user prompt.

If they are not specified via --build-prompt, the system prompt is assumed to come from a table named system_prompt, the context from a table named context, and the user prompt from a table named question.

If the table header name for the context is of the form [filename.txt], the contents of the specified file are used as the context instead.

If any of the text values in the corresponding tables contain curly braces, the --loom-markers option can be used to provide values for the names specified between the braces. It is expected to be a string in the format name=[.. value ..].

So, the following command-line:

$ python -m mlx_tuning_fork.training --loom-file=loom.toml \
         --build-prompt "system_prompt_final templated_question_final" -f chatml \
         --loom-markers "medical_problems=[Lymphoid aggregate]" /path/to/loom.toml

where the contents of loom.toml are:

lang = "en"
[system_prompt_final]
text = """You are a medical professional.  If you cannot provide an answer based on the given context, please let me know."""

[context]
text = """Lymphoid aggregates are a collection of B cells, T cells, and supporting cells, present within the stroma of various organs"""

[templated_question_final]
text = """The patient has {Lymphoid aggregate}.  Summarize the patient's problems"""

will result in the following ChatML prompt being sent to the model:

<|im_start|>system
You are a medical professional.  If you cannot provide an answer based on the given context, please let me know.

Lymphoid aggregates are a collection of B cells, T cells, and supporting cells, present within the stroma of various organs
<|im_end|>
<|im_start|>user

The patient has Lymphoid aggregate.  Summarize the patient's problems
<|im_end|>
<|im_start|>assistant

Dataset format

The dataset files are expected to be in this format:

{"input": "[..]", 
 "output": "[..]"}

Learning (completion-only vs. self-supervised)

By default, mlx_tuning_fork will train on completions only, using the input field for the input prompt and the output field for the expected completion. However, you can use mlx_lm's default self-supervised learning via the --train-type option with a value of self-supervised. In this case, only the value of the output field in the training data is used.
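
For example, to fall back to mlx_lm's self-supervised behavior (the configuration path is illustrative):

$ python -m mlx_tuning_fork.training --train-type self-supervised /path/to/config.yaml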

Running Weights and Biases (Wandb) Hyperparameter Sweeps

mlx_tuning_fork also allows you to run Wandb hyperparameter sweeps/searches using the mlx_tuning_fork.wandb_sweep module. You can get the command-line options for this via:

$ python -m mlx_tuning_fork.wandb_sweep --help
Usage: python -m mlx_tuning_fork.wandb_sweep [OPTIONS] CONFIG_FILE

Options:
  --verbose / --no-verbose
  --wandb-project TEXT            Wandb project name
  --train-type [completion-only|self-supervised]
  -f, --prompt-format [mistral|chatml]
  --help                          Show this message and exit.

It takes a single argument, which is a Wandb sweep configuration (YAML) file. The --wandb-project option refers to the Wandb project where the sweep output will be stored.
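
A minimal sweep configuration sketch, assuming Wandb's standard sweep syntax (the metric name and the swept parameter here are assumptions, not keys mandated by mlx_tuning_fork):

method: bayes
metric:
  name: validation_loss
  goal: minimize
parameters:
  learning_rate:
    values: [1e-5, 2e-5, 5e-5]

It could then be run with, for example:

$ python -m mlx_tuning_fork.wandb_sweep --wandb-project my-sweeps sweep_config.yaml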

mlx-tuning-fork's Issues

__main__.py missing

It seems a __main__.py is missing? Can't execute it as indicated.
Thanks for your help

Dataset format example

Newbie here, and I'm not sure what exactly to put in here:

{
 "input": "[..]", 
 "output": "[..]"
}

Something like this?

[
  {
    "input": "Who killed Cedric Diggory?",
    "output": "Peter Pettigrew, acting under the orders of Lord Voldemort, during the Triwizard Tournament in Harry Potter and the Goblet of Fire."
  },
  {
    "input": "What are Dementors?",
    "output": "Spectral guards of Azkaban prison, feeding on human happiness and hope, leaving victims with feelings of despair and hopelessness."
  },
  {
    "input": "What is the Order of the Phoenix?",
    "output": "A secret society led by Albus Dumbledore, formed to oppose Lord Voldemort and his Death Eaters."
  },
  {
    "input": "What is the significance of the Marauder's Map?",
    "output": "A magical map revealing Hogwarts' secrets and the location of its occupants, created by James Potter, Sirius Black, Remus Lupin, and Peter Pettigrew during their school years."
  },
  {
    "input": "What is the purpose of the Triwizard Tournament?",
    "output": "A magical competition between three schools of witchcraft and wizardry, testing magical prowess and courage."
  },
  {
    "input": "Who are Harry's closest friends?",
    "output": "Ron Weasley and Hermione Granger, forming the \"Golden Trio\" who face challenges and adventures together."
  },
  {
    "input": "What is the story behind the Chamber of Secrets?",
    "output": "A hidden chamber under Hogwarts, opened by Tom Riddle (Voldemort) in his youth, unleashing a monstrous basilisk."
  },
  {
    "input": "What is the function of the Ministry of Magic?",
    "output": "The governing body of the magical world in Britain, responsible for regulating magic and maintaining secrecy from Muggles (non-magical people)."
  },
  {
    "input": "What are Horcruxes?",
    "output": "Objects imbued with part of a dark wizard's soul, creating immortality by splitting their soul."
  },
  {
    "input": "What is the prophecy about Harry and the Elder Wand?",
    "output": "The Elder Wand, one of the Deathly Hallows, will only obey a master who has conquered the previous owner, potentially influencing the choice of its wielder."
  }
]
