
distil-bilstm

Run on FloydHub

This repository contains scripts to train a tiny bidirectional LSTM classifier on the SST-2 dataset (the binary Stanford Sentiment Treebank task). It also contains a script to fine-tune bert-large-uncased on the same task. The procedure is inspired by the paper Distilling Task-Specific Knowledge from BERT into Simple Neural Networks (Tang et al., 2019).

Installing requirements

pip install -r requirements.txt  # Skip this if you are running on FloydHub
python -m spacy download en  # on spaCy v3+ the "en" shortcut was removed; use en_core_web_sm instead

Fine-tuning bert-large-uncased

>> python train_bert.py --help

usage: train_bert.py [-h] --data_dir DATA_DIR --output_dir OUTPUT_DIR
                     [--epochs EPOCHS] [--batch_size BATCH_SIZE] [--lr LR]
                     [--lr_schedule {constant,warmup,cyclic}]
                     [--warmup_steps WARMUP_STEPS]
                     [--epochs_per_cycle EPOCHS_PER_CYCLE] [--do_train]
                     [--seed SEED] [--no_cuda] [--cache_dir CACHE_DIR]

optional arguments:
  -h, --help            show this help message and exit
  --data_dir DATA_DIR   Directory containing the dataset.
  --output_dir OUTPUT_DIR
                        Directory where to save the model.
  --epochs EPOCHS
  --batch_size BATCH_SIZE
  --lr LR               Learning rate.
  --lr_schedule {constant,warmup,cyclic}
                        Schedule to use for the learning rate. Choices are:
                        constant, linear warmup & decay, cyclic.
  --warmup_steps WARMUP_STEPS
                        Warmup steps for the 'warmup' learning rate schedule.
                        Ignored otherwise.
  --epochs_per_cycle EPOCHS_PER_CYCLE
                        Epochs per cycle for the 'cyclic' learning rate
                        schedule. Ignored otherwise.
  --do_train
  --seed SEED           Random seed.
  --no_cuda
  --cache_dir CACHE_DIR
                        Custom cache for transformer models.

Example:

python train_bert.py --data_dir SST-2 --output_dir bert_output --epochs 1 --batch_size 16 --lr 1e-5 --lr_schedule warmup --warmup_steps 100 --do_train
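For orientation, the 'warmup' schedule is the standard linear warmup-and-decay commonly used for BERT fine-tuning. Below is a minimal sketch of how such a loop can be wired up with the transformers library; it is illustrative only (a recent transformers version is assumed, and train_bert.py's actual internals may differ):

import torch
from transformers import (BertForSequenceClassification, BertTokenizer,
                          get_linear_schedule_with_warmup)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-large-uncased", num_labels=2).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000)

# Toy batch standing in for a real SST-2 dataloader.
toy_batches = [(["a great movie", "a dull film"], torch.tensor([1, 0]))]

model.train()
for texts, labels in toy_batches:
    inputs = tokenizer(texts, padding=True, truncation=True,
                       return_tensors="pt").to(device)
    outputs = model(**inputs, labels=labels.to(device))
    outputs.loss.backward()  # .loss on transformers >= 3.1; outputs[0] on older versions
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()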

Generating the augmented dataset

The file used in my tests is available at https://www.floydhub.com/alexamadori/datasets/sst-2-augmented/1, but you may want to generate a new one with a different random seed or a different teacher model.

>> python generate_dataset.py --help

usage: generate_dataset.py [-h] --input INPUT --output OUTPUT --model MODEL
                           [--no_augment] [--batch_size BATCH_SIZE]
                           [--no_cuda]

optional arguments:
  -h, --help            show this help message and exit
  --input INPUT         Input dataset.
  --output OUTPUT       Output dataset.
  --model MODEL         Model to use to generate the labels for the augmented
                        dataset.
  --no_augment          Don't perform data augmentation
  --batch_size BATCH_SIZE
  --no_cuda

Example:

python generate_dataset.py --input SST-2/train.tsv --output SST-2/augmented.tsv --model bert_output
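For context, the paper augments the training set with three heuristics: randomly masking words, replacing words with others of the same POS tag (which is why spaCy is a dependency), and sampling n-grams from each sentence; the fine-tuned teacher then labels the synthetic sentences with its logits. A toy sketch of the masking and n-gram rules (the probabilities and mask token are illustrative, not necessarily what generate_dataset.py uses):

import random

def augment(tokens, p_mask=0.1, p_ngram=0.25, mask_token="[MASK]"):
    # Rule 1: randomly mask individual words.
    out = [mask_token if random.random() < p_mask else t for t in tokens]
    # Rule 2: occasionally keep only a random n-gram of the sentence.
    if len(out) > 2 and random.random() < p_ngram:
        n = random.randint(1, min(5, len(out)))
        start = random.randint(0, len(out) - n)
        out = out[start:start + n]
    return out

print(augment("the movie was surprisingly good".split()))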

Training the BiLSTM model

>> python train_bilstm.py --help

usage: train_bilstm.py [-h] --data_dir DATA_DIR --output_dir OUTPUT_DIR
                       [--augmented] [--epochs EPOCHS]
                       [--batch_size BATCH_SIZE] [--lr LR]
                       [--lr_schedule {constant,warmup,cyclic}]
                       [--warmup_steps WARMUP_STEPS]
                       [--epochs_per_cycle EPOCHS_PER_CYCLE] [--do_train]
                       [--seed SEED] [--no_cuda]

optional arguments:
  -h, --help            show this help message and exit
  --data_dir DATA_DIR   Directory containing the dataset.
  --output_dir OUTPUT_DIR
                        Directory where to save the model.
  --augmented           Whether to use the augmented dataset for knowledge
                        distillation.
  --epochs EPOCHS
  --batch_size BATCH_SIZE
  --lr LR               Learning rate.
  --lr_schedule {constant,warmup,cyclic}
                        Schedule to use for the learning rate. Choices are:
                        constant, linear warmup & decay, cyclic.
  --warmup_steps WARMUP_STEPS
                        Warmup steps for the 'warmup' learning rate schedule.
                        Ignored otherwise.
  --epochs_per_cycle EPOCHS_PER_CYCLE
                        Epochs per cycle for the 'cyclic' learning rate
                        schedule. Ignored otherwise.
  --do_train
  --seed SEED
  --no_cuda

Example:

python train_bilstm.py --data_dir SST-2 --output_dir bilstm_output --epochs 1 --batch_size 50 --lr 1e-3 --lr_schedule warmup --warmup_steps 100 --do_train --augmented
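For orientation, the student is a small bidirectional LSTM over word embeddings with a linear classification head. A minimal PyTorch sketch (the dimensions and pooling are illustrative; the repo's actual architecture and hyperparameters may differ):

import torch
import torch.nn as nn

class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=150, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True,
                            bidirectional=True)
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)            # (B, T, E)
        _, (h_n, _) = self.lstm(embedded)               # h_n: (2, B, H)
        # Concatenate the final forward and backward hidden states.
        features = torch.cat([h_n[0], h_n[1]], dim=-1)  # (B, 2H)
        return self.classifier(features)                # logits: (B, C)

logits = BiLSTMClassifier(vocab_size=10000)(torch.randint(1, 10000, (4, 12)))
print(logits.shape)  # torch.Size([4, 2])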

distil-bilstm's People

Contributors

redeipirati, tacchinotacchi


distil-bilstm's Issues

Model name 'bert_output' was not found in model name list

generate_dataset.py does not work as expected:

!python generate_dataset.py --input /tmp/stanford/SST-2/train.tsv --output /tmp/stanford/SST-2/augmented.tsv --model bert_output

Output:

2020-01-17 13:31:21.462906: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory
2020-01-17 13:31:21.463138: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory
2020-01-17 13:31:21.463172: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Loading dataset: 100%|████████████████████| 67349/67349 [15:06<00:00, 74.27it/s]
Generation: 100%|█████████████████████████| 67349/67349 [11:42<00:00, 95.89it/s]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/transformers/configuration_utils.py", line 160, in from_pretrained
    config = cls.from_json_file(resolved_config_file)
  File "/opt/conda/lib/python3.7/site-packages/transformers/configuration_utils.py", line 213, in from_json_file
    with open(json_file, "r", encoding='utf-8') as reader:
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/torch/transformers/c25684d5ade5636f37761238680b40384cf72feb7f62908661756e71d9e2318a'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "generate_dataset.py", line 83, in <module>
    model = BertForSequenceClassification.from_pretrained(args.model).to(device)
  File "/opt/conda/lib/python3.7/site-packages/transformers/modeling_utils.py", line 350, in from_pretrained
    **kwargs
  File "/opt/conda/lib/python3.7/site-packages/transformers/configuration_utils.py", line 173, in from_pretrained
    raise EnvironmentError(msg)
OSError: Model name 'bert_output' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-japanese, bert-base-japanese-whole-word-masking, bert-base-japanese-char, bert-base-japanese-char-whole-word-masking, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1). We assumed 'https://s3.amazonaws.com/models.huggingface.co/bert/bert_output/config.json' was a path or url to a configuration file named config.json or a directory containing such a file but couldn't find any such file at this path or url.
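This error usually means that the bert_output directory does not exist or does not contain a saved model, so from_pretrained falls back to treating 'bert_output' as a Hugging Face model identifier and fails. Running train_bert.py with --do_train first should create it; roughly speaking, from_pretrained expects what save_pretrained writes out (a sanity-check sketch, assuming the transformers library):

from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model.save_pretrained("bert_output")      # writes config.json + the weights
tokenizer.save_pretrained("bert_output")  # writes the vocab files
# Now this resolves to the local directory instead of a hub model name:
reloaded = BertForSequenceClassification.from_pretrained("bert_output")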

augment_dataset.py -> generate_dataset.py

Hey there - thanks for posting your code, it's really easy to read.

Slight typo: it looks like you didn't update the README.

Could you please fix the filename: augment_dataset.py -> generate_dataset.py?

Thanks!

Loss calculation

Thanks for sharing the repo and blog. I need some clarity on the loss calculation. As described in the blog, the loss is computed as:

L = (1 − α) · L_H + α · L_KL

However, the implementation (trainer.py, get_loss()) only computes one of the two terms. Is this in line with the theory in the blog, or am I missing something?
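For reference, the formula above combines a hard-label cross-entropy term with a KL-divergence distillation term. A minimal PyTorch sketch of that combined loss (α and the temperature T are illustrative defaults; as the issue notes, the repo's get_loss() may compute only one term at a time):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=1.0):
    # Hard-label cross-entropy against the ground-truth labels: L_H.
    l_hard = F.cross_entropy(student_logits, labels)
    # KL divergence between softened student and teacher distributions: L_KL.
    l_kl = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    F.softmax(teacher_logits / T, dim=-1),
                    reduction="batchmean") * (T * T)
    return (1 - alpha) * l_hard + alpha * l_kl

loss = distillation_loss(torch.randn(4, 2), torch.randn(4, 2),
                         torch.tensor([0, 1, 1, 0]))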
