RWKV_LM_EXT

This project extends RWKV LM's capabilities, including sequence classification, embedding, PEFT, cross encoder, bi encoder, multi-modality, etc.

src/model_ext.py

We extend RWKV (5, 6) with two model types.

  • RwkvForClassification

This class is used to do sequence classification.

graph LR
    A(idx) --> B[embeddings]
    B --> C[Apply rwkv blocks]
    C --> D[Find the embedding at the eos id]
    D --> E[Score the embeddings]
    E --> F(Return scores)
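The classification flow above can be sketched in plain Python. This is only an illustration of the "find the eos position, then score that embedding" step, not the code in src/model_ext.py: the function name `score_at_eos`, the list-based tensors, the linear scoring head, and the choice of the first eos occurrence are all assumptions.

```python
from typing import List, Sequence


def score_at_eos(
    hidden_states: Sequence[Sequence[float]],  # one vector per token, after the rwkv blocks
    idx: Sequence[int],                        # input token ids
    eos_id: int,
    weight: Sequence[Sequence[float]],         # hypothetical head, shape [num_labels][hidden]
    bias: Sequence[float],                     # hypothetical head bias, shape [num_labels]
) -> List[float]:
    """Score the hidden state located at the first eos token (assumed convention)."""
    pos = list(idx).index(eos_id)  # raises ValueError if eos_id is missing
    h = hidden_states[pos]
    # Linear scoring head: one dot product per label.
    return [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weight, bias)]
```

Under these assumptions, a sequence must contain the eos id, otherwise the lookup fails with a `ValueError`.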
  • RwkvForSequenceEmbedding

This class is used to do sequence embedding.

graph LR
    A(idx) --> B[embeddings]
    B --> C[Apply rwkv blocks]
    C --> D[Apply pooling method]
    D --> E(Return embeddings)
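The pooling step in the diagram above is not specified further here. As a rough illustration, two common choices are sketched below in plain Python: mean pooling over non-padding tokens and last-token pooling. The function names and the mask convention (1 for real tokens, 0 for padding) are assumptions, and the class may use a different method.

```python
from typing import List, Sequence


def mean_pool(hidden_states: Sequence[Sequence[float]], mask: Sequence[int]) -> List[float]:
    """Average the hidden states of all tokens whose mask entry is non-zero."""
    kept = [h for h, m in zip(hidden_states, mask) if m]
    if not kept:
        raise ValueError("mask selects no tokens")
    dim = len(kept[0])
    return [sum(h[d] for h in kept) / len(kept) for d in range(dim)]


def last_token_pool(hidden_states: Sequence[Sequence[float]], mask: Sequence[int]) -> List[float]:
    """Return the hidden state of the last non-padding token."""
    positions = [i for i, m in enumerate(mask) if m]
    if not positions:
        raise ValueError("mask selects no tokens")
    return list(hidden_states[positions[-1]])
```

For a recurrent model, the last real token has seen the whole sequence, which is one reason last-token pooling is a natural candidate.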

Some lora checkpoints:

Run python peft_train/peft_test.py to see how the two SFT adapters work together seamlessly. In the future, more SFT parameters can work together to build a more sophisticated AI assistant.

Furthermore, some utilities:

  • peft_train/hf2rwkv_lora.py converts a LoRA checkpoint trained with Hugging Face peft to the RWKV LoRA format. This utility can be used to convert a BiEncoder checkpoint.

  • peft_train/peft_test.py is a better example of how to use one base model for various tasks at runtime by switching only the LoRA adapters.

    • If you're using a PiSSA checkpoint, set --chat_lora_alpha 8 instead of the --chat_lora_alpha 32 used for LoRA.
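To make the adapter-switching idea concrete, here is a toy sketch of a single linear layer that keeps one base weight and several named LoRA adapters. It uses the standard LoRA update W + (alpha / r) · B · A, which also shows why the alpha value matters when a checkpoint was trained with a different setting. The class `ToyLoraLinear` and its methods are hypothetical and are not the API of peft_train/peft_test.py.

```python
from typing import Dict, List, Optional, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


class ToyLoraLinear:
    """One shared base weight plus switchable LoRA adapters (illustrative only)."""

    def __init__(self, base_weight: Matrix) -> None:
        self.base_weight = base_weight  # shape [out][in]
        self.adapters: Dict[str, Tuple[Matrix, Matrix, float]] = {}
        self.active: Optional[str] = None

    def add_adapter(self, name: str, lora_a: Matrix, lora_b: Matrix, alpha: float) -> None:
        # lora_a has shape [r][in], lora_b has shape [out][r].
        self.adapters[name] = (lora_a, lora_b, alpha)

    def set_adapter(self, name: Optional[str]) -> None:
        if name is not None and name not in self.adapters:
            raise KeyError(name)
        self.active = name

    def effective_weight(self) -> Matrix:
        if self.active is None:
            return [row[:] for row in self.base_weight]
        lora_a, lora_b, alpha = self.adapters[self.active]
        scale = alpha / len(lora_a)  # alpha / r
        delta = matmul(lora_b, lora_a)
        return [[w + scale * d for w, d in zip(w_row, d_row)]
                for w_row, d_row in zip(self.base_weight, delta)]
```

Switching tasks is then a matter of calling `set_adapter` with another name, while the base weight stays loaded once.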

Training process for SFT using RWKV's LoRA/PiSSA/state tuning.

python data/SftUtilities.py --input_dir JSON_DIR --output_dir DATASET_DIR --tokenizer_file PATH_TO_rwkv_vocab_v20230424.txt
  • The datasets are now in DATASET_DIR, so we can train the SFT model. The following scripts are the ones we use to train with lora/pissa/state tuning. Let's assume the SFT peft model is written to OUTPUT_DIR. On a 4090ti, we can train with batch sizes 32, 16, 8, 4, 2, 1 for max lengths 64, 128, 256, 512, 1024, 2048 respectively.

    • Lora
    RWKV_TRAIN_TYPE=lora python peft_train/peft_train_sft.py --train_data /home/rwkv/data/tigerbot_sft_dataset --model_file MODEL_DIR/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth --output_dir OUTPUT_DIR/lora_rwkv --train_type lora --target_modules att ffn --log_every_n_steps 100000 --wandb lora_tigerbots --num_epochs 3 --train_lengths 64 128 256 512 1024 2048 --train_batch_sizes 32 16 8 4 2 1
    • Pissa
    RWKV_TRAIN_TYPE=pissa python peft_train/peft_train_sft.py --train_data /home/rwkv/data/tigerbot_sft_dataset --model_file MODEL_DIR/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth --output_dir OUTPUT_DIR/pissa_rwkv --train_type pissa --target_modules att ffn --log_every_n_steps 100000 --wandb pissa_tigerbots --num_epochs 3 --train_lengths 64 128 256 512 1024 2048 --train_batch_sizes 32 16 8 4 2 1
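The paired --train_lengths and --train_batch_sizes values keep the token count per batch constant: every pair multiplies to 2048. The sketch below shows one plausible way a sample could be assigned to a bucket. The helper `pick_bucket` and the clamp-to-largest-bucket rule are assumptions for illustration, and peft_train_sft.py may bucket differently.

```python
import bisect
from typing import Tuple

TRAIN_LENGTHS = [64, 128, 256, 512, 1024, 2048]
TRAIN_BATCH_SIZES = [32, 16, 8, 4, 2, 1]


def pick_bucket(num_tokens: int) -> Tuple[int, int]:
    """Return (max_length, batch_size) of the smallest bucket that fits num_tokens.

    Samples longer than the largest bucket fall into it (assumed to be truncated).
    """
    i = bisect.bisect_left(TRAIN_LENGTHS, num_tokens)
    i = min(i, len(TRAIN_LENGTHS) - 1)
    return TRAIN_LENGTHS[i], TRAIN_BATCH_SIZES[i]


def tokens_per_batch() -> list:
    """Padded token count per batch for each bucket."""
    return [length * batch for length, batch in zip(TRAIN_LENGTHS, TRAIN_BATCH_SIZES)]
```

Because each bucket holds the same number of padded tokens, the memory use per step stays roughly level across sequence lengths.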
    

Beam search with logits processors

Users can now use beam search to get varied results. Try src/tests/TestBeamSearch.py.
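As a rough picture of how logits processors plug into beam search, the toy sketch below runs beam search over a caller-supplied next-token function and applies each processor to the candidate log-probabilities before expansion. Every name here (`beam_search`, `ban_repeat`, the dict-based log-probabilities) is hypothetical; this is not the implementation tested by src/tests/TestBeamSearch.py, and it omits eos handling and length penalties.

```python
from typing import Callable, Dict, List, Sequence, Tuple

LogProbs = Dict[int, float]
Processor = Callable[[Sequence[int], LogProbs], LogProbs]


def ban_repeat(tokens: Sequence[int], log_probs: LogProbs) -> LogProbs:
    """Example processor: forbid emitting the same token twice in a row."""
    if not tokens:
        return log_probs
    return {t: lp for t, lp in log_probs.items() if t != tokens[-1]}


def beam_search(
    next_log_probs: Callable[[Sequence[int]], LogProbs],
    start: Sequence[int],
    num_beams: int,
    max_new_tokens: int,
    processors: Sequence[Processor] = (),
) -> List[Tuple[float, List[int]]]:
    """Return up to num_beams (score, tokens) pairs, best first."""
    beams: List[Tuple[float, List[int]]] = [(0.0, list(start))]
    for _ in range(max_new_tokens):
        candidates: List[Tuple[float, List[int]]] = []
        for score, tokens in beams:
            log_probs = next_log_probs(tokens)
            for processor in processors:  # processors run in order on each beam
                log_probs = processor(tokens, log_probs)
            for token, lp in log_probs.items():
                candidates.append((score + lp, tokens + [token]))
        if not candidates:
            break
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:num_beams]
    return beams
```

Processors only reshape the per-step distribution, so constraints such as banned tokens can be added without touching the search loop itself.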


Encoders for inference

Please refer to infer/encoders to use multiple LoRA adapters with pure RWKV.
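For orientation, a bi-encoder typically embeds the query and each document separately and then compares the vectors. The sketch below ranks documents by cosine similarity. It assumes the embeddings have already been produced by an encoder, and the helper names are hypothetical rather than part of infer/encoders.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_documents(query_emb: Sequence[float], doc_embs: Sequence[Sequence[float]]) -> List[int]:
    """Return document indices ordered from most to least similar to the query."""
    scored = [(cosine_similarity(query_emb, d), i) for i, d in enumerate(doc_embs)]
    scored.sort(key=lambda s: s[0], reverse=True)
    return [i for _, i in scored]
```

A cross encoder, by contrast, scores the query and document together in one pass, so it cannot reuse precomputed document embeddings this way.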

Contributors

yynil, jl-er
