A Deep Generative Distance-Based Classifier for Out-of-Domain Detection with Mahalanobis Space

This repository is the official implementation of A Deep Generative Distance-Based Classifier for Out-of-Domain Detection with Mahalanobis Space (COLING 2020) by Hong Xu, Keqing He, Yuanmeng Yan, Sihong Liu, Zijun Liu, Weiran Xu.

Introduction

The repository provides a deep generative distance-based model that uses the Mahalanobis distance to detect out-of-domain (OOD) samples.
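
The distance-based decision can be illustrated with a minimal NumPy sketch. It assumes each seen class is modeled as a Gaussian with its own mean and a covariance matrix shared across classes, and that a sample is flagged as OOD when its Mahalanobis distance to the nearest class mean exceeds a threshold. The function names, the "unseen" label and the toy data are hypothetical and are not taken from experiment.py; in the real model the features would presumably come from the trained network rather than from hand-written points.

```python
import numpy as np


def fit_gaussians(features, labels):
    """Estimate one mean per seen class and one shared covariance matrix."""
    classes, index = np.unique(labels, return_inverse=True)
    means = np.stack(
        [features[index == k].mean(axis=0) for k in range(len(classes))]
    )
    # Center every sample on its own class mean, then pool the residuals.
    centered = features - means[index]
    cov = centered.T @ centered / len(features)
    # The pseudo-inverse keeps the sketch usable if the covariance is singular.
    precision = np.linalg.pinv(cov)
    return classes, means, precision


def nearest_class(x, means, precision):
    """Return (index of the nearest class mean, Mahalanobis distance to it)."""
    diffs = means - x  # shape: (n_classes, dim)
    d2 = np.einsum("ij,jk,ik->i", diffs, precision, diffs)
    d2 = np.maximum(d2, 0.0)  # guard against tiny negative rounding errors
    k = int(np.argmin(d2))
    return k, float(np.sqrt(d2[k]))


def predict(x, classes, means, precision, threshold):
    """Return the nearest seen class, or "unseen" if it is too far away."""
    k, dist = nearest_class(x, means, precision)
    return classes[k] if dist <= threshold else "unseen"


# Toy 2-D features for two seen classes (hypothetical data).
features = np.array(
    [[0, 0], [1, 0], [0, 1], [1, 1], [10, 10], [11, 10], [10, 11], [11, 11]],
    dtype=float,
)
labels = np.array([0, 0, 0, 0, 1, 1, 1, 1])
classes, means, precision = fit_gaussians(features, labels)

print(predict(np.array([1.5, 0.5]), classes, means, precision, threshold=5.0))
print(predict(np.array([5.0, 4.0]), classes, means, precision, threshold=5.0))
```

The first query lies close to class 0 and is accepted; the second lies between the two clusters, far from both means relative to the shared covariance, and is rejected as unseen.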

The architecture of the proposed model:

Dependencies

We use Anaconda to create the Python environment (replace <env_name> with a name of your choice):

conda create --name <env_name> python=3.6

Install all required libraries:

pip install -r requirements.txt

How to run

1. Train (only):

python3 experiment.py --dataset SNIPS --proportion 50 --mode train

   When only training the model, the setting parameter is not needed.

2. Predict (only):

  • LOF
python3 experiment.py --dataset <dataset> --proportion <proportion> --mode test --setting lof --model_dir <model_dir>
  • GDA
python3 experiment.py --dataset <dataset> --proportion <proportion> --mode test --setting gda_lsqr_800 --model_dir <model_dir>
  • MSP
python3 experiment.py --dataset <dataset> --proportion <proportion> --mode test --setting msp --model_dir <model_dir>

  When only predicting (no training), the model_dir parameter is required; it points to the folder that contains the trained model (the model.h5 file).

3. Train the model, then use the trained model to predict:

python3 experiment.py --dataset <dataset> --proportion <proportion> --mode both --setting <setting>

   The setting parameter is required to specify which algorithm to use for prediction; the model_dir parameter is not required.

4. Specify the seen classes:

python3 experiment.py --dataset <dataset> --proportion <proportion> --mode both --setting <setting> --seen_classes SearchCreativeWork RateBook
python3 experiment.py --dataset SNIPS --proportion 50 --mode test --setting msp_0.5 msp_0.6 msp_0.7 msp_0.8 msp_0.9 --model_dir ./outputs/SNIPS-50-06112350 --seen_classes AddToPlaylist BookRestaurant PlayMusic RateBook
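
The msp_0.5 through msp_0.9 settings in the last command appear to sweep a confidence threshold for the MSP (maximum softmax probability) baseline. A minimal sketch of that decision rule follows, assuming the number after the underscore is the threshold and that a sample is rejected when its highest softmax probability falls below it. The function names and the "unseen" label are hypothetical and not taken from experiment.py.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp_predict(logits, class_names, threshold):
    """Return the top class if its probability reaches the threshold, else "unseen"."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best] if probs[best] >= threshold else "unseen"


seen = ["AddToPlaylist", "BookRestaurant", "PlayMusic", "RateBook"]

# A peaked distribution is accepted, a flat one is rejected.
print(msp_predict([4.0, 0.0, 0.0, 0.0], seen, threshold=0.9))
print(msp_predict([1.0, 0.9, 0.8, 0.7], seen, threshold=0.5))
```

Raising the threshold rejects more samples as unseen, which is why evaluating several thresholds against one trained model is a natural experiment.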

Parameters

The main parameters:

  • dataset, required. The dataset to use: ATIS, SNIPS or CLINC.
  • proportion, required. The proportion of seen classes, in the range 0 to 100.
  • seen_classes, optional. The class names to treat as seen, given explicitly instead of being chosen at random (e.g. --seen_classes SearchCreativeWork RateBook).
  • mode, optional. The running mode: train, test or both.
  • setting, required when mode is test or both. The method(s) used to detect OOD samples, e.g.
    • lof: use LOF for predicting.
    • gda_lsqr_800: use GDA for predicting, with lsqr as the solver and a threshold of 800 (Mahalanobis distance).
    • msp: use MSP for predicting.
  • model_dir, the directory that contains the model file (.h5); required when mode is test.
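
The setting strings follow a compact underscore-separated naming scheme. The sketch below shows one way such strings could be split into their parts. It assumes the gda form is always gda_<solver>_<threshold> and that an optional number after msp (as in msp_0.5) is a threshold; parse_setting is a hypothetical helper, not a function from this repository.

```python
def parse_setting(setting):
    """Split a setting string such as 'gda_lsqr_800' into its parts."""
    parts = setting.split("_")
    method = parts[0]
    if method == "gda":
        # Documented form: gda_<solver>_<threshold>, e.g. gda_lsqr_800.
        if len(parts) != 3:
            raise ValueError(f"expected gda_<solver>_<threshold>, got {setting!r}")
        return {"method": "gda", "solver": parts[1], "threshold": float(parts[2])}
    if method == "msp":
        # Assumption: an optional trailing number is the confidence threshold.
        if len(parts) > 2:
            raise ValueError(f"expected msp or msp_<threshold>, got {setting!r}")
        threshold = float(parts[1]) if len(parts) == 2 else None
        return {"method": "msp", "solver": None, "threshold": threshold}
    if method == "lof" and len(parts) == 1:
        return {"method": "lof", "solver": None, "threshold": None}
    raise ValueError(f"unknown setting: {setting!r}")


for name in ["lof", "gda_lsqr_800", "msp", "msp_0.5"]:
    print(name, parse_setting(name))
```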

The parameters that have default values (in general, these can be left unchanged):

  • gpu_device, default=1
  • output_dir, default="./outputs"
  • embedding_file, default="glove.6B.300d.txt"
  • embedding_dim, default=300
  • max_seq_len, default=None
  • max_num_words, default=10000
  • max_epoches, default=200
  • patience, default=20
  • batch_size, default=256
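
For reference, the documented parameters and defaults can be collected in an argparse sketch like the one below. This is an illustration only, not the parser in experiment.py: the argument types, the default value of mode, and the check_args helper are assumptions made here to show how the rules from "How to run" fit together.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented parameters and defaults."""
    p = argparse.ArgumentParser(description="OOD detection experiment (sketch)")
    # Main parameters.
    p.add_argument("--dataset", required=True)
    p.add_argument("--proportion", required=True, type=int)
    p.add_argument("--seen_classes", nargs="+", default=None)
    # Assumption: the default mode is not documented; "both" is a guess.
    p.add_argument("--mode", choices=["train", "test", "both"], default="both")
    p.add_argument("--setting", nargs="+", default=None)
    p.add_argument("--model_dir", default=None)
    # Parameters with documented defaults (the types are assumed).
    p.add_argument("--gpu_device", default="1")
    p.add_argument("--output_dir", default="./outputs")
    p.add_argument("--embedding_file", default="glove.6B.300d.txt")
    p.add_argument("--embedding_dim", type=int, default=300)
    p.add_argument("--max_seq_len", type=int, default=None)
    p.add_argument("--max_num_words", type=int, default=10000)
    p.add_argument("--max_epoches", type=int, default=200)
    p.add_argument("--patience", type=int, default=20)
    p.add_argument("--batch_size", type=int, default=256)
    return p


def check_args(args):
    """Enforce the cross-parameter rules described under 'How to run'."""
    if args.mode in ("test", "both") and not args.setting:
        raise ValueError("--setting is required when mode is test or both")
    if args.mode == "test" and not args.model_dir:
        raise ValueError("--model_dir is required when mode is test")
    return args


# Equivalent to: python3 experiment.py --dataset SNIPS --proportion 50 --mode train
args = check_args(
    build_parser().parse_args(
        ["--dataset", "SNIPS", "--proportion", "50", "--mode", "train"]
    )
)
print(args.mode, args.batch_size, args.patience)
```

Using nargs="+" for setting and seen_classes matches the commands above, which pass several values after a single flag.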

Results

  1. Macro F1-score of unknown intents when different proportions (25%, 50% and 75%) of classes are treated as known intents on the SNIPS, ATIS, CLINC-Full and CLINC-Imbal datasets.
                          |        Snips         |         ATIS         |      CLINC-Full      |     CLINC-Imbal
% of known intents        |   25%    50%    75%  |   25%    50%    75%  |   25%    50%    75%  |   25%    50%    75%
MSP                       |   0.0    6.2    8.3  |   8.1   15.3   17.2  |   0.0   21.3   40.4  |   0.0   27.8   40.4
DOC                       |  72.5   67.9   63.9  |  61.6   62.8   37.7  |     -      -      -  |     -      -      -
DOC(softmax)              |  72.8   65.7   61.8  |  63.6   63.3   38.7  |     -      -      -  |     -      -      -
LOF(softmax)              |  76.0   69.4   65.8  |  67.3   61.8   38.9  |  91.1   83.1   63.5  |  88.4   77.6   57.5
LOF(LMCL)                 |  79.2   84.1   78.8  |  68.6   63.4   39.6  |  91.3   83.3   62.8  |  88.7   78.9   56.7
GDA+Euclidean distance    |  85.6   85.6   82.9  |  77.9   75.4*  43.7* |  91.1   84.2   64.5  |  91.1   81.2   60.8
GDA+Mahalanobis distance  |  89.2*  87.4*  83.2  |  78.5*  72.8   42.1  |  91.4   84.4   65.1* |  91.5   81.5   61.3*
  2. Comparison between our unsupervised OOD detection method and supervised N+1 classification.
% of known intents        |           50            |           75
Macro F1-score            | overall   seen   unseen | overall   seen   unseen
GDA+Mahalanobis distance  |   80.2    80.1    84    |   79.4    79.6    65.7
N+1 classification(2000)  |   64.6    64.6    67.7  |   65.7    65.7    66.6
N+1 classification(4000)  |   45.3    44.9    77.7  |   46.3    46.1    78.9

Citation

@inproceedings{xu2020deep,
  title={A Deep Generative Distance-Based Classifier for Out-of-Domain Detection with Mahalanobis Space},
  author={Xu, Hong and He, Keqing and Yan, Yuanmeng and Liu, Sihong and Liu, Zijun and Xu, Weiran},
  booktitle={Proceedings of the 28th International Conference on Computational Linguistics},
  pages={1452--1460},
  year={2020}
}
