lightning-convai's Introduction

State-of-the-art Conversational AI

This code is based on the transfer-learning-conv-ai repo from Hugging Face. For background, see the accompanying blog post, "How to build a State-of-the-Art Conversational AI with Transfer Learning".

The major differences are that we use PyTorch Lightning instead of Ignite and a more recent version of Transformers. We also made an effort to document everything well and keep the code easy to understand.

Presentation Slides

Model Architecture

[Figure: model architecture]

Our model is built on top of a pretrained GPT2 model and is trained in a multi-task setting where we minimize the following losses:

  • Language modeling: we project the hidden-state on the word embedding matrix to get logits and apply a cross-entropy loss on the portion of the target corresponding to the gold reply (green labels on the above figure).
  • Next-sentence prediction: we pass the hidden-state of the last token (the end-of-sequence token) through a linear layer to get a score and apply a cross-entropy loss to correctly classify the gold answer among distractors.
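
As a rough illustration of the two objectives above, the following sketch computes both losses in plain Python on toy numbers. The function names, the -100 ignore label, and the lm_coef/mc_coef weights are assumptions made for illustration. This is not the repository's code, which operates on GPT2 tensors.

```python
import math

IGNORE = -100  # assumed label marking positions excluded from the LM loss


def cross_entropy(logits, target):
    """Negative log-likelihood of `target` under softmax(logits)."""
    peak = max(logits)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_norm - logits[target]


def lm_loss(token_logits, labels):
    """Average cross-entropy over gold-reply positions only."""
    losses = [
        cross_entropy(logits, label)
        for logits, label in zip(token_logits, labels)
        if label != IGNORE
    ]
    return sum(losses) / len(losses)


def multitask_loss(token_logits, labels, candidate_scores, gold_index,
                   lm_coef=1.0, mc_coef=1.0):
    """Weighted sum of the language-modeling and next-sentence losses."""
    nsp = cross_entropy(candidate_scores, gold_index)
    return lm_coef * lm_loss(token_logits, labels) + mc_coef * nsp


# Toy example: 3 positions over a 4-word vocabulary. The first position
# belongs to the context, so it is masked out of the LM loss.
token_logits = [[0.1, 0.2, 0.3, 0.4], [2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0]]
labels = [IGNORE, 0, 2]
# One score per candidate reply; the gold reply is assumed to be at index 1.
candidate_scores = [0.5, 2.5, -1.0]
loss = multitask_loss(token_logits, labels, candidate_scores, gold_index=1)
print(f"{loss:.4f}")
```

Masking the context positions means only the gold reply contributes to the language-modeling term, while the classification term sees one score per candidate reply.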

Install:

virtualenv -p python3.6 convai-env
source convai-env/bin/activate

git clone https://github.com/HLT-MAIA/lightning-convai
cd lightning-convai
pip install -r requirements.txt

Command Line Interface:

Train:

To set up training, define your model configuration. Take a look at example.yaml in the configs folder, where all hyperparameters are briefly described.

After defining your hyperparameters, run the following command:

python cli.py train -f configs/example.yaml

Monitor training with TensorBoard:

Launch TensorBoard with:

tensorboard --logdir="experiments/"

Test:

To test your model's ability to rank candidate answers and reply to user questions, run the following command:

python cli.py test --experiment experiments/{experiment_id}/ --test_set data/personachat_val.json

where experiment_id is the name of the experiment folder containing the model you want to test.

Options:
  --experiment PATH    Path to the experiment folder containing the checkpoint
                       we want to interact with.  [required]

  --test_set PATH      Path to the json file containing the testset.
                       [required]

  --cuda / --cpu       Flag that either runs inference on cuda or in cpu.
                       [default: True]

  --seed INTEGER       Seed value used during inference. This influences
                       results only when using sampling.

  --sample / --search  Flag that either runs Nucleus-Sampling or Beam search.
                       [default: True]

  --top_p FLOAT        Nucleus filtering (top-p) before sampling (<=0.0: no
                       filtering)

  --temperature FLOAT  Use temperature to decrease the sensitivity to low
                       probability candidates when sampling.

  --num_beams INTEGER  Number of beams during search.
  --to_json TEXT       Creates and exports model predictions to a JSON file.
                       [default: False]

  --help               Show this message and exit.
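
To make the --temperature and --top_p options more concrete, here is a minimal sketch of temperature scaling followed by nucleus (top-p) filtering and sampling over a toy list of logits. The function names and the rng parameter are hypothetical, and the code only illustrates the general technique. The repository's decoding code may differ in its details.

```python
import math
import random


def top_p_filter(probs, top_p):
    """Keep the smallest set of most-probable tokens whose mass reaches top_p."""
    if top_p <= 0.0:  # mirrors the help text: <=0.0 means no filtering
        return list(probs)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = set(), 0.0
    for i in order:
        kept.add(i)
        mass += probs[i]
        if mass >= top_p:
            break
    total = sum(probs[i] for i in kept)
    # Zero out the tail and renormalize what remains.
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


def sample_next_token(logits, temperature=1.0, top_p=0.9, rng=random):
    """Temperature-scale logits, apply nucleus filtering, then sample an index."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    norm = sum(exps)
    probs = [e / norm for e in exps]
    filtered = top_p_filter(probs, top_p)
    return rng.choices(range(len(filtered)), weights=filtered, k=1)[0]


# A seeded generator makes sampling repeatable, which is the role --seed plays.
rng = random.Random(12)
token = sample_next_token([2.0, 1.0, 0.1, -3.0], temperature=0.7, top_p=0.9, rng=rng)
print(token)
```

A temperature below 1.0 sharpens the distribution before filtering, so fewer low-probability tokens survive the top-p cut.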

Interact:

Fun command where we can interact with a trained model that impersonates a vegan who likes cooking and radical activities such as sky-diving.

python cli.py interact --experiment experiments/{experiment_id}/

Benchmarks:

Training with the example.yaml config should result in the following:

Metric        GPT2      DialoGPT-small
Hits@1↑       0.8023    0.8231
Hits@5↑       0.9721    0.9771
Hits@10↑      0.9948    0.9960
BLEU↑         2.7799    2.9633
TER↓          1.0497    1.0528
BERTScore↑    0.8548    0.8548
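
For readers unfamiliar with the ranking metrics in the table, the sketch below shows the usual definition of Hits@k: the fraction of test examples where the gold reply is ranked among the k highest-scoring candidates. The function names and toy scores are hypothetical, and this is not the repository's evaluation code.

```python
def hits_at_k(candidate_scores, gold_index, k):
    """Return 1.0 if the gold candidate is ranked within the top k, else 0.0."""
    ranked = sorted(range(len(candidate_scores)),
                    key=lambda i: candidate_scores[i], reverse=True)
    return 1.0 if gold_index in ranked[:k] else 0.0


def mean_hits_at_k(examples, k):
    """Average Hits@k over (candidate_scores, gold_index) pairs."""
    return sum(hits_at_k(scores, gold, k) for scores, gold in examples) / len(examples)


# Two toy examples with four candidates each (hypothetical scores).
examples = [
    ([0.1, 0.7, 0.2, 0.0], 1),  # gold candidate is ranked 1st
    ([0.4, 0.1, 0.3, 0.2], 2),  # gold candidate is ranked 2nd
]
print(mean_hits_at_k(examples, k=1))  # 0.5
print(mean_hits_at_k(examples, k=2))  # 1.0
```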

Download DialoGPT2-small trained with PersonaChat:

cd experiments
wget https://unbabel-experimental-models.s3.amazonaws.com/maia/persona/dialogpt2-small.zip
unzip dialogpt2-small.zip

Test the model:

python cli.py test --experiment experiments/dialogpt2-small/ --test_set data/personachat_val.json --to_json

Code Style:

All code follows the same style: we use Black.
