
Pyramid Vision Transformer


This repo contains the unofficial JAX/Flax implementation of PVT v2: Improved Baselines with Pyramid Vision Transformer.
All credits to the authors Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao for their wonderful work.

Dependencies

It is recommended to create a new virtual environment so that updates/downgrades of packages do not break other projects.

  • Environment characteristics:
    python = 3.9.12
    cuda = 11.3
    jax = 0.3.16
    flax = 0.6.0

  • Follow the instructions in the official JAX/Flax documentation to install those packages, then install the remaining dependencies:

    pip install -r requirements.txt
    

Note: Flax does not depend on TensorFlow itself; however, we make use of methods that take advantage of tf.io.gfile, so we only install tensorflow-cpu. The same applies to PyTorch: we only install it in order to use torch.utils.data.DataLoader.
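To illustrate why PyTorch is needed only for data loading, a collate function along the following lines (a minimal sketch; the repo's actual implementation lives in data/numpyloader.py) keeps batches as NumPy arrays that JAX consumes directly:

```python
import numpy as np

def numpy_collate(batch):
    """Collate a list of (image, label) samples into NumPy arrays.

    Passed as `collate_fn` to torch.utils.data.DataLoader, this keeps
    batches as np.ndarray instead of torch.Tensor, so they can be fed
    straight into jitted JAX functions."""
    images, labels = zip(*batch)
    return np.stack(images), np.asarray(labels)
```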

Run

To get started, clone this repo and install the required dependencies.

Datasets

  • TensorFlow Datasets - Refer to TensorFlow Dataset Image Classification Catalog and accordingly modify the following keys in config/default.py.

    config.dataset_name = "imagenette"
    config.data_shape = [224, 224]
    config.num_classes = 10
    config.split_keys = ["train", "validation"]
  • PyTorch DataLoader - To load datasets in PyTorch style, use the wrapper around torch.utils.data.DataLoader in data/numpyloader.py -> NumpyLoader along with a custom collate function.

  • Custom Dataset - Currently, this repo does not offer out-of-the-box support for custom image classification datasets. However, you can adapt NumpyLoader to accomplish this.
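For the custom-dataset route, any map-style object exposing `__len__` and `__getitem__` can be handed to NumpyLoader. A hypothetical minimal dataset (the class name and normalization are assumptions, not repo code):

```python
import numpy as np

class ArrayDataset:
    """Hypothetical map-style dataset wrapping in-memory (image, label) pairs.

    Any object exposing __len__ and __getitem__ like this can be passed
    to NumpyLoader (or torch.utils.data.DataLoader) unchanged."""

    def __init__(self, images, labels):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Scale uint8 pixels to [0, 1] floats, a typical preprocessing step.
        image = np.asarray(self.images[idx], dtype=np.float32) / 255.0
        return image, int(self.labels[idx])
```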

Training

  • Configure the {key: value pairs} in the config file present at config/default.py.

  • Execute train.py with the desired model name and a working directory for checkpoints. Example usage:

    python train.py --model-name "PVT_V2_B0" --work-dir "output/"

Evaluation

  • Execute train.py with appropriate arguments. Example usage:

    python train.py --model-name "PVT_V2_B0" \
                    --eval-only \
                    --checkpoint_dir "output/"

To do

  • Convert ImageNet pretrained PyTorch weights (.pth) to Flax weights
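For anyone attempting the weight conversion, the main wrinkle is parameter layout: PyTorch's nn.Linear stores its weight as (out_features, in_features), while Flax's nn.Dense expects a kernel of shape (in_features, out_features). A hypothetical per-layer helper (the function name is mine, not repo code):

```python
import numpy as np

def torch_linear_to_flax(weight, bias):
    """Convert one nn.Linear's parameters to Flax nn.Dense layout.

    PyTorch stores the weight with shape (out_features, in_features);
    Flax expects a kernel of shape (in_features, out_features), so the
    weight must be transposed. Biases carry over unchanged."""
    return {"kernel": np.asarray(weight).T, "bias": np.asarray(bias)}
```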

Note: Since my undergrad studies are resuming after summer break, I may or may not be able to find time to complete the above tasks. If you want to implement the aforelisted tasks, I'll be more than glad to merge your pull request. ❤️

Acknowledgements

We acknowledge the excellent implementations of PVT in MMDetection, PyTorch Image Models, and the official repository, which served as references for this implementation.

Citing PVT

  • PVT v1

    @inproceedings{wang2021pyramid,
      title={Pyramid vision transformer: A versatile backbone for dense prediction without convolutions},
      author={Wang, Wenhai and Xie, Enze and Li, Xiang and Fan, Deng-Ping and Song, Kaitao and Liang, Ding and Lu, Tong and Luo, Ping and Shao, Ling},
      booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
      pages={568--578},
      year={2021}
    }
    
  • PVT v2

    @article{wang2021pvtv2,
      title={Pvtv2: Improved baselines with pyramid vision transformer},
      author={Wang, Wenhai and Xie, Enze and Li, Xiang and Fan, Deng-Ping and Song, Kaitao and Liang, Ding and Lu, Tong and Luo, Ping and Shao, Ling},
      journal={Computational Visual Media},
      volume={8},
      number={3},
      pages={1--10},
      year={2022},
      publisher={Springer}
    }
    


Known Issues

Distributed training on TPU is non-functional

Currently, trying to train on multiple cloud TPUs results in the training being stuck at 0%.

JAX Process: 0 / 1
JAX Local Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
[15:59:22] Epoch: 1
           Training:   0%|                              | 0/937 [00:00<?, ?it/s]

It is highly likely that the issue lies in the TFRecord dataset-building pipeline, but I couldn't definitively single out a root cause.
Note: Training works just fine on GPU.
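One way to narrow this down is to exercise the input pipeline with no model attached; if pulling batches alone stalls, the problem is in the data pipeline rather than in the pmap'd train step. A hypothetical debugging helper (not part of this repo):

```python
import itertools
import time

def time_first_batches(batches, num_batches=5):
    """Pull a few batches from an input iterator and record per-batch
    latency. If this loop itself hangs at the first batch, the stall is
    in the dataset pipeline, not in TPU compilation or execution."""
    timings = []
    start = time.monotonic()
    for _ in itertools.islice(batches, num_batches):
        now = time.monotonic()
        timings.append(now - start)
        start = now
    return timings
```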
