tf2qa

7th place solution to the TensorFlow 2.0 Question Answering competition

Solution summary: https://www.kaggle.com/c/tensorflow2-question-answering/discussion/127259

Environment: Python 3.6+, TensorFlow 1.15

Files

Most of the model code is based on bert-joint. The evaluation code is based on the official NQ metric, modified for this competition.

  • prepare_nq_data.py: pre-processing
  • jb_train_tpu.py: training on TPU
  • jb_pred_tpu.py: inference and evaluation of dev set on TPU
  • ensemble_and_tune.py: tuning ensemble weights and thresholds (a minimal sketch of the idea follows this list)
  • 7th-place-submission.ipynb: inference notebook (the submitted Kaggle notebook)
  • vocab_cased-nq.txt: vocab file for cased model with special NQ tokens added
  • bert_config_cased.json: config file for cased model
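
The sketch below illustrates the kind of search ensemble_and_tune.py performs; it is not the script itself. It assumes each model's dev predictions have been reduced to a single per-example answer score plus a binary has-answer label, which is a simplification: the actual script works on the full candidate outputs of models c, d and e and most likely handles long, short and yes/no answers separately.

# Hypothetical sketch of ensemble weight / threshold tuning (not the repo's code).
import itertools
import numpy as np

def ensemble_f1(model_scores, labels, weights, threshold):
    # model_scores: (n_models, n_examples) answer scores; labels: 1 if the
    # example has a gold answer, else 0. F1 of "answer when the weighted
    # average score exceeds the threshold".
    blended = np.average(model_scores, axis=0, weights=weights)
    pred = blended > threshold
    tp = np.sum(pred & (labels == 1))
    precision = tp / max(pred.sum(), 1)
    recall = tp / max((labels == 1).sum(), 1)
    return 2 * precision * recall / max(precision + recall, 1e-9)

def grid_search(model_scores, labels):
    # Brute-force search over per-model weights and a single decision threshold.
    best = (-1.0, None, None)
    for w in itertools.product(np.arange(0.1, 1.01, 0.1), repeat=model_scores.shape[0]):
        for th in np.arange(-3.0, 3.0, 0.25):
            f1 = ensemble_f1(model_scores, labels, np.array(w), th)
            if f1 > best[0]:
                best = (f1, w, th)
    return best

# Toy usage with random scores standing in for the dev predictions of 3 models.
rng = np.random.RandomState(0)
scores = rng.normal(size=(3, 1000))
labels = (rng.rand(1000) > 0.5).astype(int)
print(grid_search(scores, labels))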

Scripts for the 3 single models

Model c: whole word masking (wwm), negative sampling, max_contexts=200, dev score 64.5

# pre-processing (this step does not require TPU and could be distributed over multiple processes)
export do_lower_case=True
export max_contexts=200
export tfrecord_dir=fix_top_level_bug_max_contexts_200_0.01_0.04
for shard in {0..49} 
do 
	python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard=$shard --max_contexts=$max_contexts
done
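
As the comment above notes, pre-processing needs no TPU and can be spread over multiple processes. One possible way (a sketch, not part of the original scripts) is to launch all 50 shards from a small Python driver, reusing the same flags as the bash loop above; the cap of 8 concurrent shards is an arbitrary choice.

# Hypothetical parallel launcher for the pre-processing shards.
import subprocess
from concurrent.futures import ThreadPoolExecutor

def run_shard(shard):
    # Each call spawns an independent prepare_nq_data.py process for one shard,
    # with the same flags as the bash loop above.
    subprocess.run([
        "python3", "prepare_nq_data.py",
        "--do_lower_case=True",
        "--tfrecord_dir=fix_top_level_bug_max_contexts_200_0.01_0.04",
        "--include_unknowns_answerable=0.01",
        "--include_unknowns_unanswerable=0.04",
        f"--shard={shard}",
        "--max_contexts=200",
    ], check=True)

if __name__ == "__main__":
    # 8 shards in flight at a time; adjust to your CPU/RAM budget.
    with ThreadPoolExecutor(max_workers=8) as pool:
        list(pool.map(run_shard, range(50)))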

# training
export TPU_NAME=node-1
export train_batch_size=64
export learning_rate=4e-5
export model_suffix=_wwm_fix_top_level_bug_max_contexts_200_0.01_0.04
export train_precomputed_file=gs://<your_bucket>/tfrecords/fix_top_level_bug_max_contexts_200_0.01_0.04/nq-train.tfrecords-*
export init_checkpoint=gs://<your_bucket>/wwm_uncased_L-24_H-1024_A-16/bert_model.ckpt
python3 jb_train_tpu.py --tpu=$TPU_NAME --model_suffix=${model_suffix} --train_batch_size=${train_batch_size} --learning_rate=${learning_rate} --train_precomputed_file=$train_precomputed_file --init_checkpoint=$init_checkpoint --num_train_epochs=1 

# evaluation (ckpt 9500 turned out to be the best)
export MODEL_SUFFIX=_wwm_fix_top_level_bug_max_contexts_200_0.01_0.04-64-4.00E-05
export CKPT_FROM=8000
export CKPT_TO=10000
export doc_stride=256
export do_lower_case=True
python3 jb_pred_tpu.py --tpu=$TPU_NAME --doc_stride=$doc_stride --model_suffix=$MODEL_SUFFIX --ckpt_from=$CKPT_FROM --ckpt_to=$CKPT_TO --eval_set=dev --do_predict=True --do_lower_case=$do_lower_case

Model d: wwm, negative sampling, doc_stride=192, dev score 63.8

# pre-processing (this step does not require TPU and could be distributed over multiple processes)
export do_lower_case=True
export doc_stride=192
export tfrecord_dir=stride_192_0.01_0.04
for shard in {0..49} 
do 
	python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard=$shard --doc_stride=$doc_stride
done

# training
export TPU_NAME=node-1
export train_batch_size=64
export learning_rate=2e-5
export model_suffix=_wwm_stride_192_neg_0.01_0.04
export train_precomputed_file=gs://<your_bucket>/tfrecords/stride_192_0.01_0.04/nq-train.tfrecords-*
export init_checkpoint=gs://<your_bucket>/wwm_uncased_L-24_H-1024_A-16/bert_model.ckpt
python3 jb_train_tpu.py --tpu=$TPU_NAME --model_suffix=${model_suffix} --train_batch_size=${train_batch_size} --learning_rate=${learning_rate} --train_precomputed_file=$train_precomputed_file --init_checkpoint=$init_checkpoint --num_train_epochs=1 

# evaluation (ckpt 7000 turned out to be the best)
export MODEL_SUFFIX=_wwm_stride_192_neg_0.01_0.04-64-2.00E-05
export CKPT_FROM=5000
export CKPT_TO=8000
export doc_stride=256
export do_lower_case=True
python3 jb_pred_tpu.py --tpu=$TPU_NAME --doc_stride=$doc_stride --model_suffix=$MODEL_SUFFIX --ckpt_from=$CKPT_FROM --ckpt_to=$CKPT_TO --eval_set=dev --do_predict=True --do_lower_case=$do_lower_case

Model e: wwm, negative sampling, cased, dev score 63.3

# pre-processing (this step does not require TPU and could be distributed over multiple processes)
export do_lower_case=False
export tfrecord_dir=fix_top_level_bug_cased_0.01_0.04
for shard in {0..49} 
do 
	python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard=$shard
done

# training
export TPU_NAME=node-1
export train_batch_size=64
export learning_rate=4.5e-5
export model_suffix=_wwm_cased_fix_top_level_bug_0.01_0.04
export train_precomputed_file=gs://<your_bucket>/tfrecords/fix_top_level_bug_cased_0.01_0.04/nq-train.tfrecords-*
export init_checkpoint=gs://<your_bucket>/wwm_cased_L-24_H-1024_A-16/bert_model.ckpt
export do_lower_case=False
python3 jb_train_tpu.py --tpu=$TPU_NAME --model_suffix=${model_suffix} --train_batch_size=${train_batch_size} --learning_rate=${learning_rate} --train_precomputed_file=$train_precomputed_file --init_checkpoint=$init_checkpoint --num_train_epochs=1 --do_lower_case=${do_lower_case}

# evaluation (ckpt 8500 turned out to be the best)
export MODEL_SUFFIX=_wwm_cased_fix_top_level_bug_0.01_0.04-64-4.50E-05
export CKPT_FROM=6000
export CKPT_TO=8500
export doc_stride=256
export do_lower_case=False
python3 jb_pred_tpu.py --tpu=$TPU_NAME --doc_stride=$doc_stride --model_suffix=$MODEL_SUFFIX --ckpt_from=$CKPT_FROM --ckpt_to=$CKPT_TO --eval_set=dev --do_predict=True --do_lower_case=$do_lower_case
