
halos's Introduction

Human-Aware Loss Functions (HALOs) 😇

This repo allows you to design new Human-Aware Loss Functions (HALOs) for aligning LLMs with offline human feedback at scale (read more in our technical report or our full paper). It was used to create Archangel, the largest-ever suite of human-feedback-aligned LLMs, and has been tested at scales from 1B to 30B.

This repo draws from the excellently written DPO repo and has preserved many design choices from the original. Some of the key changes we introduced are:

  • making data loading more modular, so that you can easily write your own dataloader
  • making trainers more modular, so that each HALO has its own trainer subclass
  • adding code for doing open-ended evaluation with GPT-4 as a judge
  • supporting losses beyond SFT and DPO (including KTO, PPO (offline, off-policy variant), and SLiC)

To first SFT a model, run a command like

python train.py loss=sft model=llama7b datasets=[shp,hh,oasst] exp_name=llama7b_sft mode=train ++cache_dir=/data/models

which will save a model to /data/models/llama7b_sft/LATEST/policy.pt. To then align a model with KTO, run a command like

python train.py loss=kto model=llama7b datasets=[shp,hh,oasst] exp_name=llama7b_kto mode=train ++cache_dir=/data/models ++model.load_from=llama7b_sft/LATEST/policy.pt

which will save a model to /data/models/llama7b_kto/LATEST/policy.pt.

Quickstart

Let's say we want to implement a new HALO called Kahneman-Tversky optimization (KTO). This is already implemented in this repo based on the details in our report, but let's pretend that it's not. What should we do?

  1. First, create and activate the conda environment.

    conda env create -f environment.yml

    conda activate halos

    If you can't create the conda environment, or you run into issues during installation, try the following instead:

    conda create -n halos3 python=3.10.12
    pip3 install numpy==1.24.3 ninja==1.11.1.1 packaging==23.1 
    conda install pytorch==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
    pip3 install flash-attn==2.3.3 
    pip3 install transformers==4.35.2 datasets hydra-core==1.3.2 wandb==0.15.3 openai==1.6.1 accelerate==0.21.0 tensor-parallel==1.2.4
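
    Either way, you can optionally sanity-check the install with a tiny script like the one below. This helper is not part of the repo; it only verifies that the key packages import and that CUDA is visible:

    # check_env.py -- hypothetical helper, not part of the repo
    import torch
    import transformers
    import flash_attn

    print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
    print('transformers', transformers.__version__)
    print('flash-attn', flash_attn.__version__)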
    
  2. Determine whether you need a new dataset. If you have a dataset called foo, add a function called get_foo to dataloader.py that will return a Dataset instance. This function should have the following signature, where the prefixes and suffixes determine how the dataset is formatted (see config.yaml) and split should be either train or test:

    def get_foo(split: str, human_prefix: str, human_suffix: str, assistant_prefix: str, assistant_suffix: str) -> Dataset:
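
    As a concrete reference point, here is a hedged sketch of what get_foo might look like, modeled loosely on the existing loaders for SHP and HH. It assumes the function lives in dataloader.py (so the datasets package and the repo's Dataset class are already available); the Example attribute names used below (prompt, generations, pairs, sft_index), the column names, and the dataset id 'some-org/foo' are illustrative assumptions, so check them against dataloader.py before copying:

    def get_foo(split: str, human_prefix: str, human_suffix: str, assistant_prefix: str, assistant_suffix: str) -> Dataset:
        """Hypothetical loader for a preference dataset called 'foo'."""
        assert split in ('train', 'test')
        raw = datasets.load_dataset('some-org/foo', split=split)  # hypothetical HF dataset id

        data = Dataset('foo')
        for row in raw:
            # Wrap the prompt in the control tokens specified in config.yaml.
            prompt = human_prefix + row['prompt'] + human_suffix + assistant_prefix
            data[prompt].prompt = prompt
            # Add the preferred and dispreferred completions, then record the preference
            # as a pair of indices (i, j), where generation i is preferred to generation j.
            data[prompt].generations.append(row['chosen'] + assistant_suffix)
            data[prompt].generations.append(row['rejected'] + assistant_suffix)
            i, j = len(data[prompt].generations) - 2, len(data[prompt].generations) - 1
            data[prompt].pairs.append((i, j))
            data[prompt].sft_index = i  # which generation to use as the SFT target

        return data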

  3. Determine whether you need a new dataloader. KTO doesn't use preference pairs, just knowledge of whether outputs are desirable or undesirable. This means we use dataloader.UnpairedPreferenceDataLoader. However, that dataloader assumes that you're working with datasets that originally contain preference pairs, like Anthropic HH or SHP. If you wanted a custom dataloader, you would implement it in the same Python file by extending the base DataLoader class.
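
    If you do go the custom route, a minimal, hedged sketch of the pattern is below. The helper and attribute names (self.full_data, self.batch_size, self.tokenize_batch_element, self.collate, example.generations, example.truncation_mode) are assumptions modeled on how the existing DataLoader subclasses appear to be structured, so verify each of them against dataloader.py:

    class FooDataLoader(DataLoader):
        """Hypothetical dataloader that yields flat batches of (prompt, generation) examples."""
        def __iter__(self):
            batch = []
            for prompt, example in self.full_data.items():  # assumed: dict mapping prompt -> Example
                for generation in example.generations:
                    # Assumed helper that tokenizes and truncates a single (prompt, generation) pair.
                    batch.append(self.tokenize_batch_element(example.prompt, generation, example.truncation_mode))
                    if len(batch) == self.batch_size:
                        yield self.collate(batch)  # assumed helper: pad and convert to tensors
                        batch = []
            # (any leftover partial batch is dropped in this sketch)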

  4. Write a trainer in trainers.py. This should subclass either UnpairedPreferenceTrainer or PairedPreferenceTrainer depending on whether it uses pairs of preferences or not. If you need highly custom behavior that is not in either, then you can subclass BasicTrainer directly.

    We can implement a simple version of KTO as follows (note that this is different from the proper version of KTO in KTOTrainer, which does not assume the existence of both chosen and rejected examples in each batch).

    To make SimpleKTOTrainer, we just subclass trainers.UnpairedPreferenceTrainer as trainers.SimpleKTOTrainer and override the loss function definition. KTO has one hyperparameter, beta, which we can access via self.config.loss.beta:

    class SimpleKTOTrainer(UnpairedPreferenceTrainer):
        """A simple version of KTO meant to introduce you to the HALOs repo."""
        def loss(self,
                 policy_chosen_logps: torch.FloatTensor,
                 policy_rejected_logps: torch.FloatTensor,
                 reference_chosen_logps: torch.FloatTensor,
                 reference_rejected_logps: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
            """Compute the Kahneman-Tversky loss for a batch of policy and reference model log probabilities.

            For each batch of n/2 chosen examples and n/2 rejected examples (belonging to n different inputs), calculate the loss as follows.
            If generation y ~ p_chosen, where x' are the examples with rejected generations, we have the 'chosen' loss:
                L(x, y) := 1 - sigmoid(beta * (log p_policy(y|x) - log p_reference(y|x) - KL(p_policy(y_rejected|x') || p_reference(y_rejected|x'))))
            If generation y ~ p_rejected, where x' are the examples with chosen generations, we have the 'rejected' loss:
                L(x, y) := 1 - sigmoid(beta * (KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - [log p_policy(y|x) - log p_reference(y|x)]))
            """
            # One way to realize the formulas above: estimate each KL term by averaging the
            # policy/reference log-ratios across the batch and clamping at zero.
            chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
            rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = policy_rejected_logps - reference_rejected_logps

            # Concatenate the 'chosen' and 'rejected' losses along the batch dimension.
            losses = torch.cat((1 - torch.sigmoid(self.config.loss.beta * (chosen_logratios - rejected_KL)),
                                1 - torch.sigmoid(self.config.loss.beta * (chosen_KL - rejected_logratios))), 0)

            # Implied rewards, returned for logging purposes.
            chosen_rewards = self.config.loss.beta * chosen_logratios.detach()
            rejected_rewards = self.config.loss.beta * rejected_logratios.detach()
            return losses, chosen_rewards, rejected_rewards
  5. Add a file to the config/loss folder specifying the details of the loss:

     name: kto-simple
     beta: 0.1 # the temperature parameter for simple KTO; lower values mean we care less about the reference model
     trainer: SimpleKTOTrainer # implemented in trainers.py
     dataloader: UnpairedPreferenceDataLoader # already exists in dataloader.py
     use_reference_model: true # true because the loss definition includes a reference model
  6. Now we can start training a model! Let's train a Llama-7B model on the SHP, Anthropic HH, and Open Assistant datasets. Since the corresponding entry for Llama-7B is config/model/llama7b.yaml, we run a command with Hydra:

    python train.py loss=kto-simple model=llama7b datasets=[shp,hh,oasst] exp_name=kto-simple_llama7b mode=train ++cache_dir=/data/models

    which will align a Llama-7B model from scratch. If we want to align a model that we've already finetuned with the HALOs repo, we can add something like ++model.load_from=/data/models/sft_llama7b/LATEST/policy.pt to the end of the command.

    That's it! Your model will be saved to /data/models/kto-simple_llama7b/LATEST/policy.pt.

  7. Let's sample some generations from our newly trained model. The sampling configs are in either config/config.yaml or in the model configs under config/model/. We can sample 512 generations from our newly trained model in batches of 32 with the following command, which will create a JSON file under samples/{config.exp_name}.json:

    python eval.py --config-path=/data/models/kto-simple_llama7b/config.yaml ++mode=sample ++n_samples=512 ++model.eval_batch_size=32 ++samples_dir=samples/

  8. After setting OPENAI_API_KEY, we can evaluate our aligned model with GPT-4 with the following command, which compares the aligned model's generations to the human-chosen response in the data:

    python compare.py -f samples/kto-simple_llama7b.json -mc 512 -bk chosen -ck policy -r result.jsonl

FAQs

  1. Do you support multi-node training?

    No, currently the repo only supports single-node training. Multi-node training will be added at some point in the future. Every model in the Archangel suite was trained with 8 x A100 GPUs on a single node.

  2. How do I save intermediate checkpoints?

    Set intermediate_checkpoints to true in config/config.yaml or on the command line with ++config.intermediate_checkpoints=true. Every config.eval_every steps, a checkpoint will be saved in the experiment directory ($cache_dir/$exp_name).

  3. Where do I find all the Archangel models?

    They are all on the Huggingface Hub:

| Model       | PPO     | DPO     | KTO     | SFT     | SLiC    | SFT+PPO | SFT+DPO | SFT+KTO | CSFT    | SFT+CSFT |
|-------------|---------|---------|---------|---------|---------|---------|---------|---------|---------|----------|
| pythia1-4b  | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |
| pythia2-8b  | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |
| pythia6-9b  | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |
| pythia12-0b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |
| llama7b     | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |
| llama13b    | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |
| llama30b    | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights  |


Citation

If you find this repo or the technical paper useful in your research, please feel free to cite our work:

@techreport{ethayarajh2023halos,
  author = {Ethayarajh, Kawin and Xu, Winnie and Jurafsky, Dan and Kiela, Douwe},
  title = {Human-Aware Loss Functions (HALOs)},
  institution = {Contextual AI},
  note = {https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf},
  year = {2023},
}


halos's Issues

In dataloader

pairs: List[Tuple[int, int]] = field(default_factory=list) # indices in responses, where i > j in pair (i,j) is a preference

In this line, does "i > j" mean "score(i) > score(j)"? I'm a bit confused, thank you for your clarification!

Llama 3 compatibility

I tried two llama 3 8B models from huggingface by creating a new config/model/.yaml with

name_or_path: meta-llama/Meta-Llama-3-8B
and also
name_or_path: NousResearch/Meta-Llama-3-8B

I am able to train the SFT model using the command
python train.py loss=sft model=llama_3_8b datasets=[shp,hh] exp_name=l38b_sft_0505 mode=train ++cache_dir=./data/models

However, I face this error when I train the PPO model using the command
python train.py loss=ppo model=llama_3_8b datasets=[shp,hh] exp_name=l38b_ppo_0517 mode=train ++cache_dir=./data/models ++model.load_from=l38b_sft_0505/LATEST/policy.pt

Stacktrace:

Making experiment directory ./data/models/l38b_ppo_0517
no FSDP port specified; using open port for FSDP: 46451
seed: 1
exp_name: l38b_ppo_0517
datasets:
- shp
- hh
mode: train
debug: false
use_fsdp: true
fsdp_port: 46451
wandb:
  enabled: true
  entity: null
  project: l38b_ppo_0517
cache_dir: ./data/models
local_run_dir: ./data/models/l38b_ppo_0517
do_first_eval: true
minimum_log_interval_secs: 1.0
intermediate_checkpoints: false
trainer: BasicTrainer
lr: 5.0e-07
n_epochs: 1
n_examples: null
optimizer: RMSprop
warmup_steps: 150
eval_every: 20000
n_samples: 128
samples_dir: samples/
n_eval_examples: 512
saved_policy: ./data/models/l38b_ppo_0517/LATEST/policy.pt
top_p: 0.95
human_prefix: '

  <|user|>

  '
assistant_prefix: '

  <|assistant|>

  '
human_suffix: ''
assistant_suffix: ''
frac_unique_desirable: 1.0
frac_unique_undesirable: 1.0
model:
  name_or_path: NousResearch/Meta-Llama-3-8B
  tokenizer_name_or_path: null
  load_from: l38b_sft_0505/LATEST/policy.pt
  block_name: LlamaDecoderLayer
  policy_dtype: bfloat16
  fsdp_policy_mp: null
  reference_dtype: bfloat16
  max_grad_norm: 10.0
  v_head_max_grad_norm: 0.1
  max_length: 2048
  max_prompt_length: 1024
  activation_checkpointing: true
  batch_size: 32
  gradient_accumulation_steps: 1
  eval_batch_size: 16
  use_flash_attention: true
loss:
  name: ppo
  ppo_epochs: 1
  cliprange: 0.5
  trainer: PPOTrainer
  dataloader: UnpairedPreferenceDataLoader
  lam: 0.95
  gamma: 0.99
  critic_coef: 0.01
  KL_coef: 0.1
  use_reference_model: true

================================================================================
Writing to aid-nrt-slurm-bm-gpu-b4-8-ad1-005:./data/models/l38b_ppo_0517
================================================================================
building policy
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 10.43it/s]
Error executing job with overrides: ['loss=ppo', 'model=llama_3_8b', 'datasets=[shp,hh]', 'exp_name=l38b_ppo_0517', 'mode=train', '++cache_dir=./data/models', '++model.load_from=l38b_sft_0505/LATEST/policy.pt', '++wandb.project=l38b_ppo_0517']
Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/models.py", line 95, in from_pretrained
    filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1674, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 369, in _request_wrapper
    response = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 393, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-66483dd3-1fff68591f1d7b9b7a1e9d08;04442bd4-7883-4026-8e20-334b328b960d)

Entry Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.index.json

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/train.py", line 231, in <module>
    main()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/rosridha/POCs/rlhf/HALOs/train.py", line 132, in main
    policy = model_class.from_pretrained(
  File "/home/rosridha/POCs/rlhf/HALOs/models.py", line 101, in from_pretrained
    index_file_name = hf_hub_download(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1674, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 369, in _request_wrapper
    response = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 393, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-66483dd3-495ad9e044bc65914d189248;d2584051-b677-439d-9982-cc1f1fd96021)

Entry Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.index.json.

ERROR:None of the inputs have requires_grad=True. Gradients will be None

Computing eval metrics: 0%| | 0/86 [00:00<?, ?it/s]/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

Losses list appears to be empty for loss=DPO

Hi,

I got this to run using FSDP. When I print the metrics, they are all strangely empty. I followed the PairedPreferenceTrainer class, which calls self.loss; for DPO, loss is defined in DPOTrainer. I added a print right after line 699 -> it is oddly an empty tensor. Any ideas?

Thanks!


Can you provide a clear description of the dataset structure we can use for our custom dataset?

Something like this?

From the Huggingface DPOTrainer docs:

dpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}

Comments in KTO Trainer `forward()`

Hi there,

I'm reading through the forward() function in the KTO Trainer, and the function signature states that the sizes of the chosen and rejected logps should be batch_size/2 if the data was read in correctly. However, this doesn't make sense to me, because it sounds like a limitation of paired preference training rather than of KTO's unpaired training method.

Here's comment from lines 875-877 of trainers.py:

chosen_logps: log probabilities of chosen examples (should be batch size / 2 if data was read in correctly)
rejected_logps: log probabilities of rejected examples (should be batch size / 2 if data was read in correctly)
KL_logps: log probabilities of the unmatched y'|x (used to estimate the KL divergence between policy and reference; should be batch size)

Please let me know if this makes sense; I'm happy to open a PR.

Gradient Clipping for FSDP

Hi! Thank you for maintaining such a valuable repository.
I would like to suggest a minor fix regarding gradient clipping. For FSDP, we should not use torch.nn.utils.clip_grad_norm_ (relevant issue), but instead directly call the clip_grad_norm_ method of the FSDP module. Thus, I would like to suggest modifying the following:

In HALOs/trainers.py, lines 453 to 455 (at commit f9f7826):

def clip_gradient(self):
    """Clip the gradient norm of the parameters of a non-FSDP policy."""
    return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.model.max_grad_norm).item()

into the following:

def clip_gradient(self):
    """Clip the gradient norm of the parameters."""
    if self.fsdp:
        return self.policy.clip_grad_norm_(self.config.model.max_grad_norm).item()
    return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.model.max_grad_norm).item()

Thank you so much!

How to sample from HF models?

Hi, this is a great project!

I'm thinking about reproducing the scores in your paper first, specifically for the model ContextualAI/archangel_sft_llama7b. How can I sample from this HF model? Right now, eval.py takes in a path under /data/..., which exists only if I trained the model myself.

Compatibility with quantized embeddings

Hi,

Firstly, thanks for the awesome work!

I want to use KTO with a quantized Mistral model but am getting pickle errors from the multiprocessing thread, probably since that changes the Embedding layers to be nn.Linear4bit instead of just nn.Linear.

File "~/HALOs/train.py", line 250, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, tokenizer, train_iterator, eval_iterator, policy, reference_model), join=True)
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    process.start()
  File "~/conda/env/halos/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function Embedding.forward at 0x7f75fa914160>: it's not the same object as torch.nn.modules.sparse.Embedding.forward

I'm thinking a workaround would be to use multiprocess instead of multiprocessing to use dill instead of pickle, but haven't been successful with that yet... do you have any suggestions?

Error fix for evaluation script `eval.py`

Thank you for maintaining such an important repository. While using it, I ran into some issues with eval.py and wanted to check whether the following changes are indeed valid.

  1. In line 86 of eval.py, we need to modify reference_model.resize_token_embeddings(len(tokenizer)) into the following:

if config.loss.name == 'ppo':
    reference_model.resize_token_embeddings(len(tokenizer))

This is because the token embeddings of the reference model were resized only if the loss was ppo (refer to line 170 of train.py).

  2. When we use the eval mode of eval.py, we need to change the trainer type from BasicTrainer to DPOTrainer; otherwise we are faced with a NotImplementedError when calling the get_batch_metrics method.

Once again, thank you for maintaining an amazing repository which is very easy to use. Thank you so much!

Question about the KL term in the loss function

For y ~ p_chosen, the KL term in the loss is derived as KL(p_policy(y_rejected|x') || p_reference(y_rejected|x')); however, according to the technical report, the KL divergence should be calculated over the entire dataset. Is there a reason x' above includes only the rejected inputs and not inputs x from the entire batch (i.e., including y_chosen too)? My guess is that numerical stability could be the reason, to make sure the two terms of the loss aren't correlated, but I want to make sure I'm not missing something here.

Hidden State mapping to two value nodes instead of 1

Hi,

I'm confused about why you've defined the value head the way you did in models.py. Namely, the value head as written will output two numbers instead of one, since you're mapping the (2, 4096) final hidden state to a (2, 1) tensor for the final value. It looks like you're missing half of the hidden states. I would expect it to map a flattened version of the final hidden state to a single node.

As a sanity check I looked for where this was used and in line 1114 of trainers.py, I noticed that you're only taking in the first value in this (2,1) vector.

Can you tell me why you've made this design choice? I feel like I'm misinterpreting something here.

Is there a problem with training?

Initially, I used KTO for training, and the loss did not converge at all, as shown in the following training result graph.
[training-curve screenshot omitted]

Later, based entirely on llama7b and hh data, I used the script you provided exactly: python train.py loss=sft model=llama7b datasets=[hh] exp_name=llama7b_sft mode=train ++cache_dir=/data/models, and the training result graph is as follows:
[training-curve screenshots omitted]

The only difference in the SFT training from yours is that I set use_flash_attention to false.

Request for details and assistance on PPO Experiments with SFT+PPO training

Hello Developers,

Firstly, I would like to thank you for the excellent work on this repository and for sharing the plots in other issues. I'm currently using your library to train a model with SFT+PPO, and I've successfully replicated the SFT experiment as per the results shared in ContextualAI/HALOs#13.

However, I'm experiencing negligible improvement with the PPO part of the training. Could you provide the details of your PPO experiments? I noted in a previous comment that the preferential tuning was run significantly longer, so I adjusted the PPO epochs to 3 in my experiments. Are these adjustments in line with what was done in your experiments? Additionally, could you elaborate on how and when to decide which checkpoint to use for downstream tasks, especially for PPO, DPO, and KTO scenarios?

Link to my plots: l7b_ppo_0419

Here are the commands I used for my experiments:

  • SFT Training Command
python train.py loss=sft model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_sft_0416 mode=train ++cache_dir=/data/models wandb.project=l7b_sft_0416
  • PPO Training (3 Epochs)
# Updated n_epochs in config.yaml to 3
python train.py loss=ppo model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_ppo_0419 mode=train ++cache_dir=./data/models ++model.load_from=l7b_sft_0416/LATEST/policy.pt wandb.project=l7b_ppo_0419

Additional Query:

  • When conducting SFT training, the train and validation losses are calculated using splits of the training data. If I then use the same dataset for PPO, how can I ensure that I am not inadvertently retraining on the train split? Furthermore, when using stratified datasets for both SFT and preferential tuning, do you recommend holding out different data points for each, and is this approach considered best practice?
  • Based on the plots and results you shared, am I correct in understanding that you had a batch size of 32 and conducted 200k steps of SFT training, which equates to training on approximately 6.4 million datapoints, DPO on roughly 9.6 million datapoints (300k * 32), and KTO on about 17.6 million datapoints?

Thank you for any guidance or insights you can provide.

a few queries

Hello Kawin and authors,

Thanks for sharing this excellent work. I enjoyed reading the technical report; it was crisp and concise, and I appreciate the effort put in to achieve that clarity.

I have a couple of questions:

  1. In the KTO loss function, the second term (the expected KL divergence) can be interpreted as the average reward obtained over the 'rejected' responses across all input prompts (in the desired-response case). In this sense, it is similar to DPO, where the second term is the 'reward' for the rejected response of that particular x, except that here it is replaced with the average reward for the 'rejected' responses across all x's. Is this understanding correct?
  2. I saw in the code that there is a KTOZero implementation where the expected KL divergence is replaced with 0. I am curious what its findings were.
  3. On a lighter note, I would like to know what font was used in the technical report.

thanks,
Onkar
