
L2R2

PyTorch implementation of L2R2: Leveraging Ranking for Abductive Reasoning.

Usage

Set up environment

L2R2 is tested on Python 3.6 and PyTorch 1.0.1.

$ pip install -r requirements.txt

Prepare data

αNLI

$ wget https://storage.googleapis.com/ai2-mosaic/public/alphanli/alphanli-train-dev.zip
$ unzip -d alphanli alphanli-train-dev.zip

Training

We train the L2R2 models on 4 K80 GPUs. The largest batch size that fits on a single K80 is 1, so the effective batch size in our experiments is 4.

The optimization criterion can be selected from the following options:

  • list_net: list-wise KLD loss used in ListNet
  • list_mle: list-wise Likelihood loss used in ListMLE
  • approx_ndcg: list-wise ApproxNDCG loss used in ApproxNDCG
  • rank_net: pair-wise Logistic loss used in RankNet
  • hinge: pair-wise Hinge loss used in Ranking SVM
  • lambda: pair-wise LambdaRank loss used in LambdaRank
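As an illustration of the default `list_net` criterion, here is a minimal sketch of the list-wise KLD (ListNet) loss in PyTorch. The function name and tensor shapes are assumptions for illustration, not the repo's actual API:

```python
import torch
import torch.nn.functional as F

def list_net_loss(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """List-wise KLD loss as in ListNet (top-1 approximation).

    scores: (batch, list_size) raw model scores for each hypothesis.
    labels: (batch, list_size) non-negative relevance labels.
    """
    true_dist = F.softmax(labels, dim=-1)     # target top-1 probabilities
    log_pred = F.log_softmax(scores, dim=-1)  # predicted log-probabilities
    # The KL divergence equals this cross-entropy up to a constant
    # (the entropy of the target distribution), so minimizing it is equivalent.
    return -(true_dist * log_pred).sum(dim=-1).mean()
```

Scores that rank the correct hypothesis higher yield a lower loss than scores that invert the order.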

Note that in our experiment, we manually reduce the learning rate instead of using any automatic learning rate scheduler.
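In the commands below, the manual reduction is done by restarting `run.py` with a smaller `--learning_rate`. Inside a single training loop, the equivalent operation is a sketch like the following (the helper name is hypothetical, assuming a standard PyTorch optimizer):

```python
import torch

def set_learning_rate(optimizer: torch.optim.Optimizer, lr: float) -> None:
    # Manually override the learning rate of every parameter group,
    # instead of relying on an automatic scheduler.
    for group in optimizer.param_groups:
        group["lr"] = lr

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
set_learning_rate(optimizer, 1e-6)  # drop the rate for the next stage
```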

For example, we first fine-tune the pre-trained RoBERTa-large model for up to 10 epochs with a learning rate of 5e-6 and save the model checkpoint which performs best on the dev set.

$ python run.py \
  --data_dir=alphanli/ \
  --output_dir=ckpts/ \
  --model_type='roberta' \
  --model_name_or_path='roberta-large' \
  --linear_dropout_prob=0.6 \
  --max_hyp_num=22 \
  --tt_max_hyp_num=22 \
  --max_seq_len=72 \
  --do_train \
  --do_eval \
  --criterion='list_net' \
  --per_gpu_train_batch_size=1 \
  --per_gpu_eval_batch_size=1 \
  --learning_rate=5e-6 \
  --weight_decay=0.0 \
  --num_train_epochs=10 \
  --seed=42 \
  --log_period=50 \
  --eval_period=100 \
  --overwrite_output_dir

Then, we continue fine-tuning the just-saved checkpoint for up to 3 epochs with a smaller learning rate, such as 3e-6, 1e-6, or 5e-7, until the performance on the dev set no longer improves.

$ python run.py \
  --data_dir=alphanli/ \
  --output_dir=ckpts/ \
  --model_type='roberta' \
  --model_name_or_path=ckpts/H22_L72_E3_B4_LR5e-06_WD0.0_MMddhhmmss/checkpoint-best_acc/ \
  --linear_dropout_prob=0.6 \
  --max_hyp_num=22 \
  --tt_max_hyp_num=22 \
  --max_seq_len=72 \
  --do_train \
  --do_eval \
  --criterion='list_net' \
  --per_gpu_train_batch_size=1 \
  --per_gpu_eval_batch_size=1 \
  --learning_rate=1e-6 \
  --weight_decay=0.0 \
  --num_train_epochs=3 \
  --seed=43 \
  --log_period=50 \
  --eval_period=100 \
  --overwrite_output_dir

Note: change the seed to reshuffle training samples.

Evaluation

Evaluate the performance on the dev set.

$ export MODEL_DIR="ckpts/H22_L72_E3_B4_LR5e-07_WD0.0_MMddhhmmss/checkpoint-best_acc/"
$ python run.py \
  --data_dir=alphanli/ \
  --output_dir=$MODEL_DIR \
  --model_type='roberta' \
  --model_name_or_path=$MODEL_DIR \
  --max_hyp_num=2 \
  --max_seq_len=72 \
  --do_eval \
  --per_gpu_eval_batch_size=1
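With `--max_hyp_num=2`, each dev instance scores the two candidate hypotheses, and accuracy is the fraction of instances where the correct one receives the higher score. A minimal sketch of that computation (names are illustrative, not the repo's evaluation code):

```python
import torch

def anli_accuracy(scores: torch.Tensor, gold: torch.Tensor) -> float:
    """scores: (n, 2) per-hypothesis scores; gold: (n,) index of the correct one."""
    preds = scores.argmax(dim=-1)          # predict the higher-scored hypothesis
    return (preds == gold).float().mean().item()
```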

Inference

$ ./run_model.sh

Citation

@inproceedings{10.1145/3397271.3401332,
  author = {Zhu, Yunchang and Pang, Liang and Lan, Yanyan and Cheng, Xueqi},
  title = {L2R²: Leveraging Ranking for Abductive Reasoning},
  year = {2020},
  url = {https://doi.org/10.1145/3397271.3401332},
  doi = {10.1145/3397271.3401332},
  booktitle = {Proceedings of the 43rd International ACM SIGIR Conference on Research and Development in Information Retrieval},
  series = {SIGIR '20}
}
