RLHF-Reward-Modeling

πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯

πŸ”₯πŸ”₯πŸš€πŸš€πŸš€ Check out our blog post! πŸš€πŸš€πŸš€πŸ”₯πŸ”₯

πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯

TL;DR: this is a repo for training the reward model for DRL-based RLHF (PPO), iterative SFT (rejection sampling fine-tuning), and iterative DPO.

  • 4 x A40 48G: we can train Gemma-7B-it with max_length 4096 by DeepSpeed Zero-3 + gradient checkpointing;
  • 4 x A100 80G: we can train Gemma-7B-it with max_length 4096 by gradient checkpointing;
  • The resulting reward models achieve SOTA performance among reward models with base models ≤ 13B on the RewardBench leaderboard.

Installation instructions

To be updated.

The current solution is based on the alignment-handbook repo and its environment, which should be sufficient for plain RM training. Before starting, please make sure your Linux machine has nvidia-cuda-toolkit installed.

conda create -n newhandbook python=3.10.9
conda activate newhandbook

git clone https://github.com/huggingface/alignment-handbook.git
cd ./alignment-handbook/
python -m pip install .

pip install flash-attn

git clone https://github.com/WeiXiongUST/RLHF-Reward-Modeling.git

Some possible problems:

The installation may fail with "CUDA_HOME does not exist, unable to compile CUDA op(s)", ending in an AssertionError. Installing the CUDA compiler via conda fixes this:

conda install nvidia/label/cuda-12.2.0::cuda-nvcc
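After installing the toolkit, you can optionally run a quick sanity check (a minimal sketch, assuming torch and flash-attn are already installed):

import torch

# PyTorch version, the CUDA version it was built against, and GPU visibility.
print(torch.__version__, torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

try:
    import flash_attn  # fails if the flash-attn build was broken
    print("flash-attn:", flash_attn.__version__)
except ImportError as e:
    print("flash-attn not installed correctly:", e)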

You also need to install wandb to record the training, and log in with your Hugging Face account to access Gemma.

pip install wandb
wandb login

huggingface-cli login

Dataset Preparation

The dataset should be preprocessed into the standard format, where each sample consists of two conversations, 'chosen' and 'rejected', that share the same prompt. Here is an example of the rejected conversation in a comparison pair.

[
{ "content": "Please identify the top 5 rarest animals in the world.", "role": "user" },
{ "content": "Do you mean animals that are really rare, or rare relative to the size of the human population?", "role": "assistant" },
{ "content": "The ones that are really rare.", "role": "user" },
{ "content": "Alright, here’s what I found:", "role": "assistant" }, 
]
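For illustration, here is a minimal sketch of how such a chosen/rejected pair can be rendered into plain text with the tokenizer's chat template before tokenization (the assistant contents are abbreviated placeholders, and the exact preprocessing in the training script may differ):

from transformers import AutoTokenizer

# One comparison pair in the standard format described above;
# both conversations share the same prompt.
sample = {
    "chosen": [
        {"role": "user", "content": "Please identify the top 5 rarest animals in the world."},
        {"role": "assistant", "content": "(a helpful, complete answer)"},
    ],
    "rejected": [
        {"role": "user", "content": "Please identify the top 5 rarest animals in the world."},
        {"role": "assistant", "content": "Do you mean animals that are really rare, or rare relative to the size of the human population?"},
    ],
}

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# Render each side with the same chat template; the reward model scores full conversations.
chosen_text = tokenizer.apply_chat_template(sample["chosen"], tokenize=False)
rejected_text = tokenizer.apply_chat_template(sample["rejected"], tokenize=False)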

We preprocess 4 datasets and upload them to the Hugging Face hub.

Version 1: the model is trained on mixture1 of the datasets listed below.

The total number of comparison pairs is 250K, where we perform the following data selection and cleaning strategies:

  • HH-RLHF: we use all the base, rejection-sampling, and online subsets but delete the samples whose chosen == rejected, leading to 115547 pairs;
  • SHP: we only use the samples with score ratio > 2 and take at most 1 comparison per prompt, leading to 55916 pairs;
  • UltraFeedback: similar to UltraFeedback-Binarized, we use the fine-grained score instead of the overall one to rank samples. For each prompt, we pair the best sample vs. a randomly chosen one from the remaining samples (see the sketch after this list), then delete pairs with equal scores, leading to 62793 pairs;
  • HelpSteer: we use the mean of helpfulness and correctness to rank samples. We pair the best sample vs. a randomly chosen one from the remaining samples, then delete pairs with equal scores, leading to 8206 pairs;
  • Capybara: we delete the pairs whose chosen and rejected samples have the same rating, leading to 7562 pairs;
  • Orca: we delete the pairs whose chosen and rejected samples have the same rating, leading to 6405 pairs.
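As referenced above, a minimal sketch of the best-vs.-random pairing used for UltraFeedback and HelpSteer (field names such as "score" are illustrative assumptions, not the repo's actual preprocessing code):

import random

def best_vs_random(responses):
    # responses: list of {"text": ..., "score": ...} for one prompt.
    ranked = sorted(responses, key=lambda r: r["score"], reverse=True)
    best, rest = ranked[0], ranked[1:]
    if not rest:
        return None
    rejected = random.choice(rest)
    if best["score"] == rejected["score"]:
        return None  # drop pairs with equal scores, as described above
    return {"chosen": best, "rejected": rejected}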

Version 2: the model is also trained on mixture2, which differs from mixture1 as follows:

  • SHP: we only use the samples with score ratio > 2 and take at most 5 comparisons per prompt, leading to 109526 pairs;
  • UltraFeedback: similar to UltraFeedback-Binarized, we use the fine-grained score instead of the overall one to rank samples. For each prompt, we take all possible 6 pairs of comparisons (see the sketch after this list), then delete pairs with equal scores, leading to 267416 pairs;
  • HelpSteer: we use the mean of helpfulness and correctness to rank samples. We take all possible 6 pairs of comparisons, then delete pairs with equal scores, leading to 21576 pairs.
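As referenced above, a minimal sketch of taking all pairwise comparisons per prompt (with 4 candidate responses this yields the 6 pairs mentioned above; field names are again illustrative assumptions):

from itertools import combinations

def all_pairs(responses):
    pairs = []
    for a, b in combinations(responses, 2):  # 4 responses -> 6 pairs
        if a["score"] == b["score"]:
            continue  # drop ties
        chosen, rejected = (a, b) if a["score"] > b["score"] else (b, a)
        pairs.append({"chosen": chosen, "rejected": rejected})
    return pairs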

Version 3: mixture2 + 30K safety, i.e., mixture2 plus the training set of PKU-Alignment/PKU-SafeRLHF-30K.

Version 4: mixture2 + 150K safety, i.e., mixture2 plus 150K samples from PKU-Alignment/PKU-SafeRLHF.

Version 5: we directly leverage the dataset from llm-blender/Unified-Feedback, which includes 886K preference samples from 8 prior datasets: openai/summarize_from_feedback, openai/webgpt_comparisons, Dahoas/instruct-synthetic-prompt-responses, Anthropic/hh-rlhf, lmsys/chatbot_arena_conversations, openbmb/UltraFeedback, argilla/ultrafeedback-binarized-preferences-cleaned, and berkeley-nest/Nectar.
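To inspect one of the preprocessed mixtures on the Hugging Face hub, a small sketch (here the mix2 dataset used in the training command below; the split name "train" is an assumption):

from datasets import load_dataset

ds = load_dataset("weqweasdas/preference_dataset_mix2", split="train")
print(len(ds))             # number of comparison pairs
print(ds[0]["chosen"][0])  # first turn of the chosen conversation, assuming the format above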

Running the Code

Run the code with Gemma-2b-it:

accelerate launch rm.py --model_name google/gemma-2b-it --max_length 4096 --train_set_path weqweasdas/preference_dataset_mix2

You can also modify the learning rate, batch size, output path, etc., either via command-line arguments or by editing the ScriptArguments in rm_gemma.py.
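For reference, this kind of reward-model training optimizes a pairwise Bradley-Terry objective; a minimal sketch of the loss (not the repo's exact implementation):

import torch
import torch.nn.functional as F

def pairwise_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    # chosen_rewards / rejected_rewards: one scalar reward per comparison pair, shape (batch,).
    # Maximize log sigmoid(r_chosen - r_rejected), i.e. the Bradley-Terry probability
    # that the chosen response is preferred over the rejected one.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()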

If you encounter an out-of-memory issue, run the code with Gemma-2b-it with DeepSpeed stage 3. If OOM still occurs, use a smaller max_length and per_device_batch_size.

accelerate launch rm.py --model_name google/gemma-2b-it --max_length 4096 --train_set_path weqweasdas/preference_dataset_mix2 --deepspeed deepspeed_3.json

REMARK: note that with DeepSpeed stage 3, the final model saving does not work normally. You should set save_every_steps to the total number of training steps - 1 so that the trainer saves a model for you just before training finishes.
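A small helper for computing that save_every_steps value (all inputs are placeholders for your own run configuration):

def save_every_steps(num_pairs, per_device_batch_size, num_gpus,
                     gradient_accumulation_steps=1, num_epochs=1):
    # Effective batch size per optimizer step, then total steps over the run.
    effective_batch = per_device_batch_size * num_gpus * gradient_accumulation_steps
    total_steps = (num_pairs // effective_batch) * num_epochs
    return total_steps - 1  # save just before training finishes

# Example: 250K pairs, batch size 1 per device, 4 GPUs, 1 epoch.
print(save_every_steps(num_pairs=250_000, per_device_batch_size=1, num_gpus=4))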

Evaluation Results

You can evaluate the resulting reward model on the benchmark dataset with the following command.

accelerate launch eval_bench_mark.py --reward_name_or_path ./models/gemma_2b_mixture2_last_checkpoint --record_dir ./bench_mark_eval.txt
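You can also score individual conversations directly. The sketch below assumes the reward model is saved as a single-logit sequence-classification head at the checkpoint path used above; the actual evaluation script may load it differently.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

path = "./models/gemma_2b_mixture2_last_checkpoint"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForSequenceClassification.from_pretrained(path, num_labels=1)
model.eval()

def reward(conversation):
    # conversation: list of {"role": ..., "content": ...} turns, as in the dataset format above.
    text = tokenizer.apply_chat_template(conversation, tokenize=False)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096)
    with torch.no_grad():
        return model(**inputs).logits[0, 0].item()

# A pair counts as correct if the chosen side receives the higher reward:
# accuracy = mean(reward(chosen) > reward(rejected)) over the evaluation set.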

Some models trained with our script are competitive on the leaderboard.


To Do

  • Bradley-Terry Reward Model based on Gemma and Mistral.
  • Bradley-Terry Reward Model based on Mixtral;
  • Preference model;
  • Regression-based reward model;
  • Multi-objective reward model.

Citation

This repo is part of our work on iterative rejection sampling fine-tuning and iterative DPO. If you find the content of this repo useful in your work, please consider citing it as follows:

@article{dong2023raft,
  title={Raft: Reward ranked finetuning for generative foundation model alignment},
  author={Dong, Hanze and Xiong, Wei and Goyal, Deepanshu and Pan, Rui and Diao, Shizhe and Zhang, Jipeng and Shum, Kashun and Zhang, Tong},
  journal={arXiv preprint arXiv:2304.06767},
  year={2023}
}

@misc{xiong2024iterative,
      title={Iterative Preference Learning from Human Feedback: Bridging Theory and Practice for RLHF under KL-Constraint}, 
      author={Wei Xiong and Hanze Dong and Chenlu Ye and Ziqi Wang and Han Zhong and Heng Ji and Nan Jiang and Tong Zhang},
      year={2024},
      eprint={2312.11456},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

rlhf-reward-modeling's Issues

Cannot run the training script

Thank you for sharing your scripts for training reward models based on mistral-7b and gemma-2b! May I ask which versions of transformers and accelerate you are using? I followed the instructions and built the environment from alignment-handbook. However, when I use their latest versions (transformers==4.39.0 and accelerate==0.27.2), I keep getting the following error when using deepspeed:

ValueError: DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`.

But in the training script we didn't set device_map, so it should be None.

And when running the script without deepspeed, I got another error:

File "/p/anaconda3/envs/reward-model/lib/python3.10/site-packages/accelerate/utils/dataclasses.py", line 1061, in set_auto_wrap_policy
    raise Exception("Could not find the transformer layer class to wrap in the model.")

Both errors disappear once I switch the transformers library back to 4.36.2 and the accelerate library back to 0.23.0 (which were used by alignment-handbook before Mar 1st, 2024). But in that case transformers doesn't support gemma...
