jax-models's Introduction

JAX Models

Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open-source JAX/Flax implementations of research papers that were originally published without code or with code written in frameworks other than JAX. The goal of this project is to build a collection of models, layers, activations, and other utilities that are commonly used in research. All papers and any derived or translated code are cited in either the README or the docstrings. If you think a citation is missing, please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT: Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2022)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)
  9. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (Ze Liu et al., 2021)
  10. Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions (Wenhai Wang et al., 2021)
  11. Going deeper with Image Transformers (Hugo Touvron et al., 2021)
  12. Visual Attention Network (Meng-Hao Guo et al., 2022)

Available layers for out-of-the-box integration:
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2016)
  2. Squeeze-and-Excitation Layer (Jie Hu et al., 2019)
  3. Depthwise Convolution (François Chollet, 2017)

Prerequisites

Prerequisites can be installed separately through the requirements.txt file in the main directory using:

pip install -r requirements.txt

Using a virtual environment is highly recommended to avoid version incompatibilities.

Installation

This project is built with Python 3 for the latest JAX/Flax versions and can be directly installed via pip.

pip install jax-models

To use the latest development version, you can clone the repository directly instead:

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models import load_model
load_model('swin-tiny-224', attach_head=True, num_classes=1000, dropout=0.0, pretrained=True)

Note: You must pass attach_head=True and num_classes when loading pretrained models.

Contributing

Please raise an issue if any implementation gives incorrect results or crashes unexpectedly during training/inference, or if any citation is missing.

You can contribute to jax_models by supporting me with compute resources or by contributing your own resources to provide pretrained weights.

If you wish to donate to this initiative, please send me an email.

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out with any issues or requests related to these implementations.

Darshan Deshpande - Email | Twitter | LinkedIn

jax-models's People

Contributors

darshandeshpande, maxidl

jax-models's Issues

Missing Axis Swap in ExtractPatches and MergePatches

In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the reshapes, resulting in the extracted patches becoming horizontal stripes. For example, if we follow the code in ExtractPatches:

>>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
>>> inputs[0, :, :, 0]

DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11],
             [12, 13, 14, 15]], dtype=int32)

>>> patch_size = 2
>>> batch, height, width, channels = inputs.shape
>>> height, width = height // patch_size, width // patch_size
>>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
>>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
>>> x[0, 0, :]

DeviceArray([0, 1, 2, 3], dtype=int32)

We see that the first patch extracted is not the patch containing [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3].
To fix this problem, we should add an axis swap. For ExtractPatches, this should be:

batch, height, width, channels = inputs.shape
height, width = height // patch_size, width // patch_size
x = jnp.reshape(
    inputs, (batch, height, patch_size, width, patch_size, channels)
)
x = jnp.swapaxes(x, 2, 3)
x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))

For MergePatches, this should be:

batch, length, _ = inputs.shape
height = width = int(length**0.5)
x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
x = jnp.swapaxes(x, 2, 3)
x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))
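As a quick check of the two fixes above, the following self-contained sketch wraps them in plain functions and runs them on the same 4x4 example. The names extract_patches and merge_patches are hypothetical stand-ins for the ExtractPatches and MergePatches modules, whose surrounding code is not shown here. The sketch assumes square inputs whose sides are divisible by patch_size.

```python
import jax.numpy as jnp


def extract_patches(inputs, patch_size):
    """Split (batch, H, W, C) images into flattened square patches."""
    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = jnp.reshape(
        inputs, (batch, height, patch_size, width, patch_size, channels)
    )
    # Move the patch-column axis next to the patch-row axis so that each
    # patch's pixels are contiguous before flattening.
    x = jnp.swapaxes(x, 2, 3)
    return jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))


def merge_patches(inputs, patch_size):
    """Inverse of extract_patches; assumes a square grid of patches."""
    batch, length, _ = inputs.shape
    height = width = int(length ** 0.5)
    x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
    # Undo the swap so rows of pixels line up across neighbouring patches.
    x = jnp.swapaxes(x, 2, 3)
    return jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))


inputs = jnp.arange(16).reshape(1, 4, 4, 1)
patches = extract_patches(inputs, patch_size=2)
first_patch = patches[0, 0, :]  # the top-left 2x2 block: [0, 1, 4, 5]
restored = merge_patches(patches, patch_size=2)

print(first_patch)
print(bool(jnp.array_equal(restored, inputs)))  # round trip recovers the input
```

With the swap in place, the first patch is the top-left 2x2 block rather than a horizontal stripe, and merging the extracted patches reproduces the original array.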
