DePT: Decomposed Prompt Tuning for Parameter-Efficient Fine-tuning

This repository provides the code for the paper DePT: Decomposed Prompt Tuning for Parameter-Efficient Fine-tuning, organised so that our code can be integrated into other projects easily.

Paper: http://arxiv.org/abs/2309.05173 | Framework: PyTorch | License: MIT


Overview

You can reproduce the experiments of our paper DePT: Decomposed Prompt Tuning for Parameter-Efficient Fine-tuning.

Abstract: Prompt tuning (PT), where a small amount of trainable soft (continuous) prompt vectors is affixed to the input of language models (LM), has shown promising results across various tasks and models for parameter-efficient fine-tuning (PEFT). PT stands out from other PEFT approaches because it maintains competitive performance with fewer trainable parameters and does not drastically scale up its parameters as the model size expands. However, PT introduces additional soft prompt tokens, leading to longer input sequences, which significantly impacts training and inference time and memory usage due to the Transformer's quadratic complexity. This is particularly concerning for Large Language Models (LLMs) that face heavy daily querying. To address this issue, we propose Decomposed Prompt Tuning (DePT), which decomposes the soft prompt into a shorter soft prompt and a pair of low-rank matrices that are then optimised with two different learning rates. This allows DePT to achieve better performance while saving over 20% in memory and time costs compared to vanilla PT and its variants, without changing trainable parameter sizes. Through extensive experiments on 23 natural language processing (NLP) and vision-language (VL) tasks, we demonstrate that DePT outperforms state-of-the-art PEFT approaches, including the full fine-tuning baseline in some scenarios. Additionally, we empirically show that DePT grows more efficient as the model size increases. Our further study reveals that DePT integrates seamlessly with parameter-efficient transfer learning in the few-shot learning setting and highlights its adaptability to various model architectures and sizes.
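The claim that DePT keeps the number of trainable parameters unchanged while shortening the input can be checked with simple arithmetic. The sketch below is an illustration under stated assumptions, not code from this repository: it assumes the shorter soft prompt has shape m × d, the low-rank pair has shapes s × r and r × d (with s the maximum sequence length), a hidden size d = 768 for t5-base, and a vanilla PT baseline of 100 prompt tokens. The function names are hypothetical; m = 40, r = 45 and s = 256 are taken from the example script in section 3.

```python
# Illustrative arithmetic only. The shapes, the hidden size and the
# 100-token vanilla PT baseline are assumptions, not repository code.


def vanilla_pt_params(prompt_len: int, hidden: int) -> int:
    """Trainable parameters of a soft prompt of shape (prompt_len, hidden)."""
    return prompt_len * hidden


def dept_params(short_prompt_len: int, rank: int, seq_len: int, hidden: int) -> int:
    """Shorter prompt (m x d) plus low-rank pair A (s x r) and B (r x d)."""
    return short_prompt_len * hidden + rank * (seq_len + hidden)


HIDDEN = 768        # assumed hidden size of t5-base
SEQ_LEN = 256       # MAX_LENGTH in the example script
BASELINE_LEN = 100  # assumed vanilla PT prompt length
SHORT_LEN = 40      # PREFIX_LENGTH in the example script
RANK = 45           # R in the example script

pt = vanilla_pt_params(BASELINE_LEN, HIDDEN)
dept = dept_params(SHORT_LEN, RANK, SEQ_LEN, HIDDEN)

print(pt, dept)                # 76800 76800
print(SEQ_LEN + BASELINE_LEN)  # 356 input positions with vanilla PT
print(SEQ_LEN + SHORT_LEN)     # 296 input positions with DePT
```

Under these assumptions both configurations train the same number of parameters, while the DePT input is 60 positions shorter. With attention cost growing quadratically in sequence length, that shorter input is the plausible source of the memory and time savings described in the abstract.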

1. Requirements and Installation

To run the prompt-based or cls-based fine-tuning, you need to install the following packages.

  • Transformers
  • PyTorch

2. Prepare the datasets

We use the following NLP datasets in our experiments: GLUE, SuperGLUE, MRQA 2019 Shared Task, WinoGrande, Yelp-2, SciTail and PAWS-Wiki. All of them are available through Hugging Face Datasets and are downloaded automatically. See src/tasks.py for the details of each dataset.

3. Run Experiments

We provide scripts to reproduce the main experiments in our paper. For example, the following script reproduces the results of DePT on the GLUE benchmark. PREFIX_LENGTH is the length of the soft prompt (m in the paper), and R is the rank of the low-rank matrices (r in the paper).

MODEL=t5-base
MAX_LENGTH=256
MAX_STEPS=40000
PREFIX_LENGTH=40   # length of the soft prompt (m in the paper)
R=45               # rank of the low-rank matrices (r in the paper)
for TASK_NAME in cola mrpc mnli qnli qqp rte sst2 stsb; do
  for LORA_LR in 5e-3 3e-1 5e-4; do   # values swept for --lora_embedding_lr
    for lr in 3e-1 4e-1; do           # values swept for --learning_rate
      CUDA_VISIBLE_DEVICES=0 python train.py \
        --peft_type PROMPT_TUNING_LORA \
        --lora_embedding_lr ${LORA_LR} \
        --learning_rate ${lr} \
        --prefix_length ${PREFIX_LENGTH} \
        --r ${R} \
        --task_name ${TASK_NAME} \
        --dataset_config_name en \
        --model_name_or_path ${MODEL} \
        --do_train \
        --do_eval \
        --do_predict \
        --per_device_train_batch_size 32 \
        --per_device_eval_batch_size 32 \
        --max_seq_length ${MAX_LENGTH} \
        --save_strategy steps \
        --evaluation_strategy steps \
        --max_steps ${MAX_STEPS} \
        --eval_steps 1000 \
        --save_steps 1000 \
        --warmup_steps 500 \
        --weight_decay 1e-5 \
        --load_best_model_at_end \
        --save_total_limit 1 \
        --output_dir "saved_${MODEL}/${TASK_NAME}_lr${lr}_loralr${LORA_LR}_pl${PREFIX_LENGTH}_r${R}_st${MAX_STEPS}"
    done
  done
done
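
The nested loops above launch one run per combination of task, --lora_embedding_lr value and --learning_rate value. As a quick sanity check before committing GPU time, the following sketch enumerates that grid and the output directory each run would write to. It is a standalone helper, not part of the repository: the function name `sweep_output_dirs` is hypothetical, and the directory pattern is copied from the --output_dir argument of the script.

```python
from itertools import product

MODEL = "t5-base"
MAX_STEPS = 40000
PREFIX_LENGTH = 40
R = 45
TASKS = ["cola", "mrpc", "mnli", "qnli", "qqp", "rte", "sst2", "stsb"]
# Learning rates stay as strings so the names match the shell expansion.
LORA_LRS = ["5e-3", "3e-1", "5e-4"]
LRS = ["3e-1", "4e-1"]


def sweep_output_dirs():
    """Return every run's output directory, in the order the script launches them."""
    return [
        f"saved_{MODEL}/{task}_lr{lr}_loralr{lora_lr}"
        f"_pl{PREFIX_LENGTH}_r{R}_st{MAX_STEPS}"
        for task, lora_lr, lr in product(TASKS, LORA_LRS, LRS)
    ]


dirs = sweep_output_dirs()
print(len(dirs))  # 48 runs: 8 tasks x 3 x 2 learning rates
print(dirs[0])    # saved_t5-base/cola_lr3e-1_loralr5e-3_pl40_r45_st40000
```

Because the script pins CUDA_VISIBLE_DEVICES=0, these 48 runs execute one after another on a single GPU, each for 40000 steps. Trimming the task or learning-rate lists is the simplest way to shorten a first trial.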

To run other benchmarks, replace the TASK_NAME values with the following:

  • SuperGLUE: superglue-multirc superglue-wic superglue-wsc.fixed superglue-cb superglue-boolq
  • MRQA 2019 Shared Task: newsqa searchqa hotpotqa nq
  • WinoGrande: winogrande
  • Yelp-2: yelp_polarity
  • SciTail: scitail
  • PAWS-Wiki: paws

Additionally, you can add the argument --peft_model_id to initialize the soft prompt and the pair of low-rank matrices with the pretrained prompt vectors. You can add the argument --k_shot_examples to specify the number of examples used for the few-shot learning.
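
Few-shot learning here means training on only a handful of labelled examples. How train.py selects those examples is not described in this README, so the sketch below only illustrates the general idea: a seeded random subsample of k training examples. The function `sample_k_shot` and its `seed` parameter are hypothetical and are not the repository's API.

```python
import random


def sample_k_shot(examples, k, seed=0):
    """Return k examples drawn without replacement, reproducibly for a given seed."""
    if k > len(examples):
        raise ValueError(f"k={k} exceeds the {len(examples)} available examples")
    rng = random.Random(seed)  # local generator; global random state is untouched
    return rng.sample(examples, k)


# Toy training set standing in for a real dataset split.
train = [{"id": i, "text": f"example {i}"} for i in range(1000)]
subset = sample_k_shot(train, k=16, seed=42)
print(len(subset))  # 16
```

Fixing the seed keeps the few-shot subset identical across runs, which makes comparisons between learning rates or initialisations fair.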

Bugs or questions?

If you have any questions regarding the code or the paper, please feel free to reach out to Zhengxiang at [email protected]. If you experience any difficulties while using the code or need to report a bug, feel free to open an issue. We kindly ask that you provide detailed information about the problem to help us provide effective support.

Citation

@article{shi2023dept,
title = {DePT: Decomposed Prompt Tuning for Parameter-Efficient Fine-tuning},
author = {Shi, Zhengxiang and Lipani, Aldo},
journal = {ArXiv},
url = {http://arxiv.org/abs/2309.05173},
year = {2023},
}

Acknowledgement

This repository is built upon the following repositories:
