

MLP-MIXER-KERAS

This repo presents a framework for training a variety of MLP-Mixer models. The framework was further used to run experiments with non-linearities in the projection layers, aiming to match the performance of Vision Transformers while remaining memory- and compute-efficient.

The pipeline is implemented in Keras and provides functionality to train, fine-tune, and test models on different datasets.

The user can choose:

  • Model architecture : b16/l16/s16/b32/l32/s32
  • Dataset : Cifar10, Pets, Tiny Imagenet
  • Pretrained Model weights : B16 or L16 with ImageNet 1k or 21k
  • Load a locally saved model
  • Whether to fine-tune, train from scratch, or just test.

The user can further specify hyperparameters like:

  • Non-linear Activations and whether to use them (original paper did not use them for the Projection layers)
  • Un-Freezing of layers (Only top or all)
  • Optimizer, learning rate, decay, etc.
  • Batch Size
  • Dropout and Drop Connect
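All of these choices live in a single `user_configs` dictionary (see the Training section). The sketch below illustrates the idea; the key names and defaults here are hypothetical and may differ from the ones actually used in `MLP_MIXER_MAIN.ipynb` and `run.py`:

```python
# Hypothetical user_configs; the real key names in the repo may differ.
user_configs = {
    "architecture": "b16",         # one of b16/l16/s16/b32/l32/s32
    "dataset": "cifar10",          # cifar10 / pets / tiny-imagenet
    "pretrained": "imagenet21k",   # B16/L16 weights, or None for scratch
    "mode": "finetune",            # finetune / scratch / test
    "activation": "gelu",          # non-linearity for the projection layers
    "unfreeze": "top",             # un-freeze only the top layers, or "all"
    "optimizer": "adam",
    "learning_rate": 1e-4,
    "batch_size": 64,
    "dropout": 0.1,
}
```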

Environment setup

A reqs.txt file is provided; please use it to recreate the environment. Otherwise, the code runs fine in the default ecbm4040 tf24 environment after installing a few additional libraries such as pandas, scikit-learn, and seaborn.

Dataset Preparation

Cifar

We automatically handle the preparation for CIFAR if it has not been downloaded yet. We upscale CIFAR images to 224×224 when a pretrained model is being fine-tuned on CIFAR.
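The upscaling step just resizes each 32×32 CIFAR image to the 224×224 input size the pretrained Mixers expect. A minimal stand-in using NumPy nearest-neighbour indexing (the pipeline itself most likely uses a TensorFlow resize op with interpolation):

```python
import numpy as np

def upscale_nearest(images, size=224):
    """Nearest-neighbour upscale of a batch (N, H, W, C) to (N, size, size, C).
    A toy substitute for a proper tf.image resize with interpolation."""
    n, h, w, c = images.shape
    rows = np.arange(size) * h // size  # which source row each output row samples
    cols = np.arange(size) * w // size  # which source column each output column samples
    return images[:, rows][:, :, cols]

batch = np.zeros((8, 32, 32, 3), dtype=np.float32)  # a CIFAR-sized batch
print(upscale_nearest(batch).shape)  # (8, 224, 224, 3)
```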

Pets

Download the dataset from the official Oxford link, then extract the 'images' directory from the Pets dataset into ../Datasets/PETS/

We automatically handle the re-organization of directories after this.

Tiny Imagenet

Go to the ImageNet website and make an account. Download tiny-imagenet-200.zip and unzip it at ../Datasets/tiny-imagenet-200/

We handle the loading from here automatically.

Training

The user has two options:

Notebook

Open Jupyter Notebook using

jupyter notebook

Open MLP_MIXER_MAIN.ipynb and edit the user_configs dictionary as desired. Any parameters the user comments out fall back to defaults. The configs are displayed before training begins. Some default parameters may not be used if they are irrelevant to the chosen setup and can be safely ignored.

Run all cells sequentially to get all training and testing results. Plots are saved to disk as well as displayed in the notebook.

Python file

Open run.py and edit the user_configs dictionary as desired. The configs are displayed before training begins. Some default parameters may not be used if they are irrelevant to the chosen setup and can be safely ignored.

Run the Python file with nohup so training survives terminal hangups, which is useful when training for a high number of epochs (output is appended to nohup.out):

nohup python run.py &

Results

Training results are saved in the saved_models/ directory, under a folder named according to your configs.

The same directory contains:

  • Best saved model (h5 file, can be loaded directly)
  • TensorBoard logs in TBlogs/
  • CSV of the training history
  • Plots of training and validation accuracy
  • Plots of training and validation loss curves
  • Confusion-matrix plot
  • ROC-curve plot
  • Testing results
  • Time-taken result
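The history CSV can be post-processed with nothing but the standard library, e.g. to find the best validation epoch. The column names below (`epoch`, `accuracy`, `val_accuracy`) are an assumption based on Keras' usual CSVLogger output; the actual file written by this repo may use different headers:

```python
import csv, io

# A made-up training-history CSV standing in for the one in saved_models/.
history_csv = io.StringIO(
    "epoch,accuracy,val_accuracy\n"
    "0,0.61,0.58\n"
    "1,0.74,0.69\n"
    "2,0.79,0.71\n"
)
rows = list(csv.DictReader(history_csv))
best = max(rows, key=lambda r: float(r["val_accuracy"]))
print(best["epoch"], best["val_accuracy"])  # 2 0.71
```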

Description of Files

train.py

Has functions to train the model (plus supporting helper functions) and to save accuracy and validation plots in directories named according to the experiment.

data_utils.py

Contains code to automatically handle which Data Generators to load. Contains functions to re-organize directories in a way that it can be fed to the Data Generators. Throws exceptions if requested dataset is not present.
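The "re-organize directories" step exists because Keras directory-based generators expect one subdirectory per class. A small sketch of that layout convention, with hypothetical WordNet-style class names (the actual re-organization logic in data_utils.py is more involved):

```python
import os, tempfile

def class_index_map(dataset_dir):
    """Map class subdirectory names to integer labels, mirroring the
    one-folder-per-class layout Keras directory generators expect."""
    classes = sorted(d for d in os.listdir(dataset_dir)
                     if os.path.isdir(os.path.join(dataset_dir, d)))
    return {name: i for i, name in enumerate(classes)}

# demo on a throwaway directory tree with two made-up class folders
root = tempfile.mkdtemp()
for wnid in ["n01443537", "n01629819"]:
    os.makedirs(os.path.join(root, wnid))
print(class_index_map(root))  # {'n01443537': 0, 'n01629819': 1}
```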

model_utils.py

Has functions to create a new MLP-Mixer model based on the requested architecture. Also resolves the number of classes, the optimizer, etc.

test_utils.py

Has functions to test the model and plot confusion matrices and ROC curves.
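For reference, a confusion matrix is just a count of true-class vs. predicted-class pairs. A minimal NumPy stand-in (the repo likely relies on scikit-learn's `confusion_matrix`, which is already in its dependency list):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """cm[i, j] = number of samples of true class i predicted as class j.
    Toy substitute for sklearn.metrics.confusion_matrix."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

print(confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1], 2))
# [[1 1]
#  [0 2]]
```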

MLP_mixer.py

Has the whole model implementation for MLP-Mixer and its intermediate blocks. Handles whether we pretrain or fine-tune, and can reload from a locally saved model.
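The core of any Mixer implementation is the Mixer block from the paper: a token-mixing MLP applied across patches, then a channel-mixing MLP applied across channels, each with a residual connection. A bare NumPy sketch of the forward pass (the repo's Keras version will differ in details; LayerNorm is omitted here for brevity):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, w1, b1, w2, b2):
    """Two-layer MLP with GELU, applied along the last axis."""
    return gelu(x @ w1 + b1) @ w2 + b2

def mixer_block(x, tok_w, ch_w):
    """One Mixer layer on x of shape (patches, channels):
    token mixing across patches, then channel mixing across channels,
    each with a skip connection. LayerNorm omitted to keep it short."""
    y = x + mlp(x.T, *tok_w).T   # mix information across the patch axis
    return y + mlp(y, *ch_w)     # mix information across the channel axis

# toy shapes: 4 patches, 8 channels, hidden widths 16 (token) and 32 (channel)
S, C, Ds, Dc = 4, 8, 16, 32
rng = np.random.default_rng(0)
tok_w = (rng.normal(size=(S, Ds)), np.zeros(Ds),
         rng.normal(size=(Ds, S)), np.zeros(S))
ch_w = (rng.normal(size=(C, Dc)), np.zeros(Dc),
        rng.normal(size=(Dc, C)), np.zeros(C))
x = rng.normal(size=(S, C))
print(mixer_block(x, tok_w, ch_w).shape)  # (4, 8)
```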

Acknowledgements:

Reference: "MLP-Mixer: An all-MLP Architecture for Vision" by Ilya Tolstikhin et al. (Google Research).

mlp-mixer-trainer's People

Contributors

arastogi1997

Watchers

Kostas Georgiou
