
icl-ceil's Introduction

Compositional Exemplars for In-context Learning

This repository contains the source code for CEIL, an in-context example retriever proposed in our paper "Compositional Exemplars for In-context Learning". Besides CEIL, this repo also implements several learning-free retrievers (e.g., Random, BM25, TopK-BERT, TopK-DPP) and learning-based retrievers (e.g., EPR).

Instead of independently retrieving each exemplar (or in-context example), CEIL models the full exemplar set by learning its joint probability with a conditional DPP, which is further trained to align with the LM score through a contrastive loss. For a given test input during inference, the optimal exemplar set is obtained from the learned DPP retriever through MAP inference. The black-box LM is frozen during the whole procedure.
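For intuition, the MAP step amounts to picking the subset of exemplars whose kernel submatrix has the largest determinant, trading off relevance against redundancy. Below is a minimal sketch of greedy MAP inference over a toy DPP kernel built from exemplar embeddings; it is illustrative only and not the repository's implementation (which conditions the kernel on the test input and uses the trained retriever's embeddings).

import numpy as np

def greedy_dpp_map(L, k):
    """Greedily pick k items approximately maximizing det(L_S), the DPP MAP objective."""
    n = L.shape[0]
    selected = []
    for _ in range(k):
        best_i, best_gain = None, -np.inf
        for i in range(n):
            if i in selected:
                continue
            sub = L[np.ix_(selected + [i], selected + [i])]
            sign, logdet = np.linalg.slogdet(sub)  # log-determinant of the candidate submatrix
            gain = logdet if sign > 0 else -np.inf
            if gain > best_gain:
                best_i, best_gain = i, gain
        selected.append(best_i)
    return selected

# Toy kernel from normalized exemplar embeddings (similar items lower the joint probability)
emb = np.random.randn(100, 32)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)
L = emb @ emb.T + 1e-6 * np.eye(100)
print(greedy_dpp_map(L, k=8))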

Contents

Setup

All required packages can be found in requirements.txt. You can install them in a new environment with

conda create -n icl python=3.7
conda activate icl

git clone git@github.com:HKUNLP/icl-ceil.git
# [Optional] If you want to experiment on the Break dataset with the LF-EM evaluation metric, clone recursively with the following command to include third-party dependencies:
# git clone --recurse-submodules git@github.com:HKUNLP/icl-ceil.git

# Replace the following line depending on your CUDA version.
pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html

cd icl-ceil
pip install -r requirements.txt
# If you don't want to use the OpenAI API, just comment out the `openai` package in `requirements.txt`.

Set up WandB to track training status for EPR and CEIL in scripts/run_epr.sh and scripts/run_dpp_epr.sh:

export WANDB_API_KEY=YOUR_WANDB_API_KEY
export WANDB_PROJECT=YOUR_PROJECT_NAME
export WANDB_ENTITY=YOUR_TEAM_NAME

Usage

Given an index dataset (by default the training set) and a test dataset (by default the validation set), we include scripts to run six in-context example retrievers under scripts/:

  • run_random.sh: random retrieval;
  • run_bm25.sh: sparse retrieval with BM25;
  • run_dense.sh: dense retrieval with a BERT encoder;
  • run_dense_dpp.sh: dense retrieval with a BERT encoder, additionally considering diversity among examples;
  • run_epr.sh: the learning-based retriever trained to retrieve a better singleton in-context example (Rubin et al., 2022);
  • run_ceil.sh: our proposed learning-based retriever. You need to run run_epr.sh first to get an initialization for training CEIL.

The config files and detailed explanations of each argument can be found in configs/.
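For example, a typical run of a learning-free retriever looks like the following (illustrative only; the variables actually read or exported inside each script may differ):

export CUDA_VISIBLE_DEVICES=0,1
bash scripts/run_bm25.sh   # retrieves in-context examples, then runs inference and reports the metric
# the task is selected via the task_name argument (e.g., task_name=geoquery); see "Add a New Task" below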

Modules

  1. bm25_retriever.py: retrieves examples from the training set with BM25; the saved JSON will have the following additional fields (an illustrative entry is shown after this list):
    • ctxs: a single in-context example sequence that can be directly used by the inferencer. This is a list of indices, where each index specifies the position of an example in the index set (e.g., the training set).
    • ctxs_candidates: multiple in-context example sequences that can be used to train a dense retriever by further running the scorer and retriever_trainer. This is a list of index lists.
  2. dense_retriever.py: similar to bm25_retriever, but retrieves examples with an embedding model specified in configs/dense_retriever.yaml.
  3. scorer.py: scores each candidate in ctxs_candidates and reranks them based on the scores to create training data.
  4. retriever_trainer.py: trains a dense retriever on the ordered ctxs_candidates with contrastive learning.
  5. inferencer.py: runs in-context learning inference with pre-retrieved in-context examples (i.e., ctxs) and reports the final metric.
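An illustrative entry of the saved retrieval file, shown as a Python-style dict (the ctxs and ctxs_candidates fields follow the description above; the remaining field name and the concrete values are assumptions for illustration):

entry = {
    "question": "how many rivers are in texas?",   # task input (illustrative field name)
    "ctxs": [381, 12, 907, 45],                    # one retrieved sequence: positions in the index set
    "ctxs_candidates": [                           # several candidate sequences, later reranked by scorer.py
        [381, 12, 907, 45],
        [60, 381, 233, 7],
        [907, 10, 5, 81],
    ],
}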

When using a local Hugging Face model as the inferencer LM, both the scorer and the inferencer use accelerate to run on multiple GPUs in parallel. For API-based inference, we also support multiprocessed API calls. For example, you can put multiple (say n) authenticated OpenAI keys in openai_keys.txt, and we will create n processes to run the inference.

For each task, a dataset_wrapper is needed, which defines the q (input or question) and a (output or answer) fields. In addition, the qa and gen_a fields define the format of each in-context example and of the whole prompt.

For classification tasks, an additional choices field pre-defines the description of each class. During answer generation for classification tasks, we select the class description with the lowest LM perplexity.
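A minimal sketch of that selection step, assuming a small Hugging Face causal LM for illustration (the repository's scorer and inferencer handle batching, truncation, and the exact prompt template differently):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def pick_choice(prompt, choices):
    """Return the class description with the lowest per-token loss (i.e., lowest perplexity)."""
    prompt_ids = tok(prompt, return_tensors="pt").input_ids
    losses = []
    for c in choices:
        choice_ids = tok(" " + c, return_tensors="pt", add_special_tokens=False).input_ids
        ids = torch.cat([prompt_ids, choice_ids], dim=1)
        with torch.no_grad():
            logits = lm(ids).logits
        n = choice_ids.shape[1]
        # score only the choice tokens, conditioned on the prompt (and any in-context examples in it)
        loss = torch.nn.functional.cross_entropy(logits[0, -n - 1:-1], ids[0, -n:])
        losses.append(loss.item())
    return choices[losses.index(min(losses))]

print(pick_choice("Review: a gorgeous, witty film.\nSentiment:", ["bad", "good"]))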

Add a New Task

Change the task by modifying the task_name argument; the currently available tasks are sst5, mrpc, qnli, mnli, cmsqa, swag, webqs, geoquery, nl2bash, mtop, break, and smcalflow. It's easy to add a new task to this repo by taking the following steps:

  1. (Optional) Define a dataset script under src/hf_datasets if the dataset is not available in Hugging Face Datasets.
  2. Create a file under src/dataset_readers/dataset_wrapper and define the interfaces for getting the question (task input), answer (task output), and prompt, which will be used by the different dataset_readers (e.g., base_dsr, inference_dsr); a sketch is shown after this list.
  3. (Optional) Define the task metric under src/metrics if you have ground-truth outputs for your test dataset.
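As a concrete illustration of step 2, a new wrapper might look roughly like the following. This is a sketch only: the registry class, prompt templates, import paths, and dataset id are assumptions modeled on the existing wrappers such as sst5.py.

from abc import ABC

class FieldGetterRegistry:
    """Hypothetical stand-in for the field-getter registry used by the existing wrappers."""
    def __init__(self):
        self.functions = {}

    def add(self, key):
        def decorator(fn):
            self.functions[key] = fn
            return fn
        return decorator

field_getter = FieldGetterRegistry()

@field_getter.add("q")
def get_q(entry):
    return entry["sentence"]                     # task input

@field_getter.add("a")
def get_a(entry):
    return get_choices(entry)[entry["label"]]    # task output as a class description

@field_getter.add("qa")
def get_qa(entry):
    return f"{get_q(entry)}\t{get_a(entry)}"     # format of one in-context example (template is an assumption)

@field_getter.add("gen_a")
def get_gen_a(entry):
    return f"{get_q(entry)}\t"                   # prompt shown at generation time, answer left blank

@field_getter.add("choices")
def get_choices(entry):
    return ["negative", "positive"]

class DatasetWrapper(ABC):
    name = "my_task"                   # value passed as task_name
    ice_separator = "\n"
    question_field = "sentence"
    answer_field = "label"
    hf_dataset = "my_org/my_task"      # hypothetical dataset id on the Hub, or a script under src/hf_datasets
    hf_dataset_name = None
    field_getter = field_getter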

Citation

If you find our work helpful, please cite us:

@article{ye2023ceil,
      title={Compositional Exemplars for In-context Learning},
      author={Jiacheng Ye and Zhiyong Wu and Jiangtao Feng and Tao Yu and Lingpeng Kong},
      year={2023},
      eprint={2302.05698},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

icl-ceil's People

Contributors

jiacheng-ye


icl-ceil's Issues

question about details of parameters

Hello, I am trying to reproduce the results in the paper. I ran scripts/run_epr.sh successfully and got an EM score of about 71 on mrpc, whereas the paper reports 75.98. Are the settings in run_epr.sh different from those in the paper? Could you provide the parameter settings used in the paper?
Thanks a lot!

Error while trying to implement for SST-2

Hello

I followed your instructions step by step. The instructions say:

run_ceil.sh: our proposed learning-based retriever. You need to run run_epr.sh first to get an initialization for training CEIL.

It worked fine with task_name=mrpc, but then I tried to apply your method to the sst2 dataset. I created sst2.py based on icl-ceil/src/dataset_readers/dataset_wrappers/sst5.py with a few modifications:

@field_getter.add("choices")
def get_choices(entry):
    return ["bad", "good"]


class DatasetWrapper(ABC):
    name = "sst2"
    ice_separator = "\n"
    question_field = "sentence"
    answer_field = "label"
    hf_dataset = "sst2"
    hf_dataset_name = None
    field_getter = field_getter

Then I ran run_epr.sh and got the following error:

Error executing job with overrides: ['task_name=sst2', 'dataset_reader.dataset_path=output/epr/sst2/EleutherAI/gpt-neo-2.7B/scored.json', 'index_reader.dataset_path=index_data/sst2/index_dataset.json', 'training_args.output_dir=output/epr/sst2/EleutherAI/gpt-neo-2.7B/bert-fix_ctx-shared-bs64', 'training_args.run_name=bert-fix_ctx-shared-bs64', 'model_config.ctx_model_name=null']
Traceback (most recent call last):
  File "/home/lavendermint/anaconda3/envs/ceil/lib/python3.7/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 92, in _call_target
    return _target_(*args, **kwargs)
  File "/home/lavendermint/distillation/ExamplesSelection/icl-ceil/src/dataset_readers/training_dsr.py", line 32, in __init__
    self.encoded_dataset = self.encode_field(dataset_wrapper, field)
  File "/home/lavendermint/distillation/ExamplesSelection/icl-ceil/src/dataset_readers/training_dsr.py", line 41, in encode_field
    'tokenizer': self.tokenizer}
  File "/home/lavendermint/anaconda3/envs/ceil/lib/python3.7/site-packages/datasets/dataset_dict.py", line 871, in map
    for k, dataset in self.items()
  File "/home/lavendermint/anaconda3/envs/ceil/lib/python3.7/site-packages/datasets/dataset_dict.py", line 871, in <dictcomp>
    for k, dataset in self.items()
  File "/home/lavendermint/anaconda3/envs/ceil/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 578, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/lavendermint/anaconda3/envs/ceil/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 543, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/lavendermint/anaconda3/envs/ceil/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 2992, in map
    f"Column to remove {list(filter(lambda col: col not in self._data.column_names, remove_columns))} not in the dataset. Current columns in the dataset: {self._data.column_names}"
ValueError: Column to remove ['train', 'validation', 'test'] not in the dataset. Current columns in the dataset: ['idx', 'sentence', 'label']

Since sst5 and sst2 have the same structure (both datasets have train, test and validation splits, and each split has idx, sentence/text and label columns), I don't get what's wrong, exactly. How can I fix that?

FSDP requires PyTorch >= 1.12.0

Hi! Thank you for making the code available so swiftly! I've been trying to get a working setup. With PyTorch 1.10, as given in the setup description, I get the error FSDP requires PyTorch >= 1.12.0. Can you specify the exact versions in the requirements?

As an aside, I've also tried PyTorch 1.13 (scripts/run_bm25.sh on geoquery), but then I get:

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 2016882 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -9) local_rank: 1 (pid: 2016883) of binary:
...
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

But maybe this is an issue specific to the newer PyTorch version.
