LibMTL

LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and API instructions.

⭐ Star us on GitHub. It motivates us a lot!

Features

  • Unified: LibMTL provides a unified code base and a consistent evaluation procedure, covering data processing, metric objectives, and hyper-parameters, on several representative MTL benchmark datasets. This allows quantitative, fair, and consistent comparisons between different MTL algorithms.
  • Comprehensive: LibMTL supports many state-of-the-art MTL methods, including 8 architectures and 13 optimization strategies. It also provides a fair comparison on several benchmark datasets covering different fields.
  • Extensible: LibMTL follows modular design principles, which allow users to flexibly and conveniently add customized components or make personalized modifications. Users can therefore quickly and easily develop novel optimization strategies and architectures, or apply existing MTL algorithms to new application scenarios.

Overall Framework

[Figure: overall framework of LibMTL]

Each module is introduced in Docs.

Supported Algorithms

LibMTL currently supports the following algorithms:

| Optimization Strategies | Venues | Comments |
|---|---|---|
| Equal Weighting (EW) | - | Implemented by us |
| Gradient Normalization (GradNorm) | ICML 2018 | Implemented by us |
| Uncertainty Weights (UW) | CVPR 2018 | Implemented by us |
| MGDA | NeurIPS 2018 | Referenced from official PyTorch implementation |
| Dynamic Weight Average (DWA) | CVPR 2019 | Referenced from official PyTorch implementation |
| Geometric Loss Strategy (GLS) | CVPR 2019 workshop | Implemented by us |
| Projecting Conflicting Gradient (PCGrad) | NeurIPS 2020 | Implemented by us |
| Gradient sign Dropout (GradDrop) | NeurIPS 2020 | Implemented by us |
| Impartial Multi-Task Learning (IMTL) | ICLR 2021 | Implemented by us |
| Gradient Vaccine (GradVac) | ICLR 2021 | Implemented by us |
| Conflict-Averse Gradient descent (CAGrad) | NeurIPS 2021 | Referenced from official PyTorch implementation |
| Nash-MTL | ICML 2022 | Referenced from official PyTorch implementation |
| Random Loss Weighting (RLW) | TMLR 2022 | Implemented by us |

| Architectures | Venues | Comments |
|---|---|---|
| Hard Parameter Sharing (HPS) | ICML 1993 | Implemented by us |
| Cross-stitch Networks (Cross_stitch) | CVPR 2016 | Implemented by us |
| Multi-gate Mixture-of-Experts (MMoE) | KDD 2018 | Implemented by us |
| Multi-Task Attention Network (MTAN) | CVPR 2019 | Referenced from official PyTorch implementation |
| Customized Gate Control (CGC), Progressive Layered Extraction (PLE) | ACM RecSys 2020 Best Paper | Implemented by us |
| Learning to Branch (LTB) | ICML 2020 | Implemented by us |
| DSelect-k | NeurIPS 2021 | Referenced from official TensorFlow implementation |
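
To give a feel for what the gradient-based strategies listed above do, here is a minimal sketch of the projection step at the core of PCGrad, written in plain Python rather than PyTorch. It is not LibMTL's implementation: the function names `dot` and `project_conflicting` are hypothetical, gradients are plain lists of floats, and only a single pair of tasks is handled (the full method repeats this step over the other tasks in random order and then sums the adjusted gradients).

```python
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two flattened gradient vectors."""
    return sum(a * b for a, b in zip(u, v))


def project_conflicting(g_i: Sequence[float], g_j: Sequence[float]) -> List[float]:
    """Remove from g_i its component along g_j, but only when they conflict.

    Two task gradients conflict when their inner product is negative.
    (Hypothetical helper for illustration; not part of LibMTL.)
    """
    inner = dot(g_i, g_j)
    if inner >= 0:
        return list(g_i)  # no conflict: leave the gradient unchanged
    scale = inner / dot(g_j, g_j)
    return [a - scale * b for a, b in zip(g_i, g_j)]


# Toy gradients for two tasks that pull in partly opposite directions.
g_task1 = [1.0, 0.0]
g_task2 = [-1.0, 1.0]
g_task1_projected = project_conflicting(g_task1, g_task2)
print(g_task1_projected, dot(g_task1_projected, g_task2))
```

After projection, the adjusted gradient of the first task is orthogonal to the second task's gradient, so following it no longer works directly against the second task.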

Supported Benchmark Datasets

| Datasets | Problems | Task Number | Tasks | Multi-input (M) / Single-input (S) |
|---|---|---|---|---|
| NYUv2 | Scene Understanding | 3 | Semantic Segmentation + Depth Estimation + Surface Normal Prediction | S |
| Office-31 | Image Recognition | 3 | Classification | M |
| Office-Home | Image Recognition | 4 | Classification | M |
| QM9 | Molecular Property Prediction | 11 (default) | Regression | S |
| PAWS-X | Paraphrase Identification | 4 (default) | Classification | M |

Installation

  1. Create a virtual environment

    conda create -n libmtl python=3.8
    conda activate libmtl
    pip install torch==1.8.0 torchvision==0.9.0 numpy==1.20

  2. Clone the repository

    git clone https://github.com/median-research-group/LibMTL.git

  3. Install LibMTL

    cd LibMTL
    pip install -e .
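
Once the steps above have finished, a quick check like the following can confirm that the packages are visible to the active Python interpreter. This is only a sketch: `missing_packages` is a hypothetical helper, and the import name `LibMTL` is an assumption based on the repository name.

```python
import importlib.util
from typing import Iterable, List


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# "LibMTL" as the import name is an assumption based on the repository name.
required = ["torch", "torchvision", "numpy", "LibMTL"]
missing = missing_packages(required)
print("All packages found." if not missing else f"Missing: {missing}")
```

Run it inside the `libmtl` conda environment; any name reported as missing points to the installation step that needs repeating.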

Quick Start

We use the NYUv2 dataset as an example to show how to use LibMTL.

Download Dataset

The NYUv2 dataset we use was pre-processed by mtan. You can download this dataset here.

Run a Model

The complete training code for the NYUv2 dataset is provided in examples/nyu. The file train_nyu.py is the main file for training on the NYUv2 dataset.

You can find the command-line arguments by running the following command.

    python train_nyu.py -h

For instance, running the following command will train an MTL model with EW and HPS on the NYUv2 dataset.

    python train_nyu.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step
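
To compare several weighting strategies, the same command can be generated programmatically. The sketch below only builds and prints the command lines; it uses just the flags shown above, and `build_command` is a hypothetical helper. It also assumes that the abbreviations in the Supported Algorithms table (for example `UW` and `DWA`) are accepted as `--weighting` values, which `python train_nyu.py -h` can confirm.

```python
import shlex
from typing import List


def build_command(weighting: str, arch: str, dataset_path: str,
                  gpu_id: int = 0, scheduler: str = "step") -> List[str]:
    """Assemble the train_nyu.py command line as a list of arguments."""
    return [
        "python", "train_nyu.py",
        "--weighting", weighting,
        "--arch", arch,
        "--dataset_path", dataset_path,
        "--gpu_id", str(gpu_id),
        "--scheduler", scheduler,
    ]


# Assumed values: abbreviations taken from the Supported Algorithms table.
for weighting in ["EW", "UW", "DWA"]:
    print(shlex.join(build_command(weighting, "HPS", "/path/to/nyuv2")))
```

Each printed line can be pasted into a shell from examples/nyu, or the argument list can be passed to `subprocess.run` to launch the runs one after another.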

More details are provided in Docs.

Citation

If you find LibMTL useful for your research or development, please cite the following:

@article{LibMTL,
  title={{LibMTL}: A Python Library for Multi-Task Learning},
  author={Baijiong Lin and Yu Zhang},
  journal={arXiv preprint arXiv:2203.14338},
  year={2022}
}

Contributor

LibMTL is developed and maintained by Baijiong Lin.

Contact Us

If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to [email protected].

Acknowledgements

We would like to thank the authors who released the following public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, mtan, nash-mtl, pytorch_geometric, and xtreme.

License

LibMTL is released under the MIT license.
