
FMix

This repository contains the official implementation of the paper 'FMix: Enhancing Mixed Sampled Data Augmentation'.


Dive in with our example notebook in Colab!

About

FMix is a variant of MixUp, CutMix and related methods, introduced in our paper 'FMix: Enhancing Mixed Sampled Data Augmentation'. It mixes training examples using masks sampled from Fourier space. Take a look at our example notebook in Colab, which shows how you can generate masks in two dimensions and in three!
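To give a rough feel for what "masks sampled from Fourier space" means, here is a minimal NumPy sketch. It is not the code in fmix.py: the function name `fourier_mask` and its parameters `alpha` and `decay_power` are hypothetical, and the steps follow the general recipe described in the paper as we understand it (treat the details as assumptions). Complex Gaussian noise is attenuated by 1/f^decay_power, the real part of its inverse FFT gives a smooth grey-scale image, and the largest λ fraction of its values is set to 1, with λ drawn from Beta(alpha, alpha). Because it uses the n-dimensional FFT, the same function produces 2D and 3D masks.

```python
import numpy as np


def fourier_mask(shape, alpha=1.0, decay_power=3.0, rng=None):
    """Hypothetical sketch: sample a binary mask from low-frequency Fourier noise.

    shape       -- spatial shape of the mask, e.g. (32, 32) or (16, 16, 16)
    alpha       -- assumed Beta(alpha, alpha) parameter for the mixing ratio lam
    decay_power -- assumed exponent of the 1 / f**decay_power low-pass filter
    Returns (lam, mask), where about a fraction lam of the mask entries are 1.
    """
    rng = np.random.default_rng() if rng is None else rng

    # Magnitude of the sample frequency at each point of the n-D FFT grid.
    grids = np.meshgrid(*[np.fft.fftfreq(n) for n in shape], indexing="ij")
    freq = np.sqrt(sum(g ** 2 for g in grids))

    # Attenuate high frequencies; clamp the zero (DC) frequency to avoid 1/0.
    scale = 1.0 / np.maximum(freq, 1.0 / max(shape)) ** decay_power

    # Complex Gaussian spectrum -> real-valued, spatially smooth "grey" image.
    spectrum = scale * (rng.standard_normal(shape) + 1j * rng.standard_normal(shape))
    grey = np.real(np.fft.ifftn(spectrum))

    # Mixing ratio lam ~ Beta(alpha, alpha).
    lam = rng.beta(alpha, alpha)

    # Binarise: the largest round(lam * size) values become 1, the rest 0.
    num_ones = int(round(lam * grey.size))
    order = np.argsort(grey, axis=None)[::-1]
    mask = np.zeros(grey.size)
    mask[order[:num_ones]] = 1.0
    return lam, mask.reshape(shape)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    lam, mask_2d = fourier_mask((32, 32), rng=rng)
    lam3, mask_3d = fourier_mask((16, 16, 16), rng=rng)
    print(mask_2d.shape, round(lam, 3), mask_2d.mean())
    print(mask_3d.shape, round(lam3, 3), mask_3d.mean())
```

Larger `decay_power` values suppress high frequencies more strongly, so the resulting masks consist of fewer, smoother regions.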

Experiments

Core Experiments

Shell scripts for our core experiments can be found in the experiments folder. For example,

bash cifar_experiment cifar10 resnet fmix ./data

will train a PreAct-ResNet18 on CIFAR-10 with FMix. More information can be found at the start of each of the shell files.

Additional Experiments

All additional classification experiments can be run via trainer.py.

Analyses

For Grad-CAM, take a look at the Grad-CAM notebook in Colab.

For the other analyses, have a look in the analysis folder.

Implementations

The core implementation of FMix uses NumPy and can be found in fmix.py. We provide bindings for it in PyTorch (with Torchbearer or PyTorch-Lightning) and TensorFlow.
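As a framework-free illustration of the kind of step such bindings perform in MixUp-style training, the NumPy sketch below mixes each example with a randomly permuted partner from the same batch using a mask, then interpolates the cross-entropy between the two sets of labels by λ, the fraction of the mask that keeps the original example. It is not taken from fmix.py or the bindings; the helpers `mix_batch` and `mixed_cross_entropy` are hypothetical, and the (batch, channels, height, width) layout is an assumption.

```python
import numpy as np


def mix_batch(x, mask, rng=None):
    """Hypothetical helper: combine each example with a random partner.

    x    -- batch of shape (batch, channels, height, width)
    mask -- binary mask of shape (height, width); 1 keeps the original pixel
    Returns the mixed batch and the permutation used to pick partners.
    """
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(len(x))
    # The (H, W) mask broadcasts over the batch and channel axes.
    mixed = mask * x + (1.0 - mask) * x[perm]
    return mixed, perm


def mixed_cross_entropy(logits, y, perm, lam):
    """Hypothetical helper: lam * CE(original labels) + (1 - lam) * CE(partner labels)."""
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    rows = np.arange(len(y))
    ce_original = -log_probs[rows, y].mean()
    ce_partner = -log_probs[rows, y[perm]].mean()
    return lam * ce_original + (1.0 - lam) * ce_partner


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 3, 8, 8))
    y = np.array([0, 1, 2, 3])
    mask = np.zeros((8, 8))
    mask[:, :4] = 1.0                       # keep the left half of each original image
    lam = mask.mean()                       # fraction kept from the original (0.5 here)
    mixed, perm = mix_batch(x, mask, rng=rng)
    logits = rng.standard_normal((4, 10))   # stand-in for model(mixed)
    print(mixed_cross_entropy(logits, y, perm, lam))
```

In real training the mask would come from a Fourier-space sampler such as the one in fmix.py; a fixed half-image mask is used here only to keep the example self-contained.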

Torchbearer

The FMix callback in torchbearer_implementation.py can be added directly to your torchbearer code:

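from torchbearer import Trial
from implementations.torchbearer_implementation import FMix

# model and optimiser are your own PyTorch model and optimiser
fmix = FMix()
# FMix is passed as a callback and also supplies the matching loss
trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix])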

See an example in test_torchbearer.py.

PyTorch-Lightning

For PyTorch-Lightning, we provide a class, FMix, in lightning.py that can be used in your LightningModule:

import pytorch_lightning as pl
from implementations.lightning import FMix

class CoolSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        ...

        self.fmix = FMix()

    def training_step(self, batch, batch_nb):
        x, y = batch
        x = self.fmix(x)  # mix the input batch with FMix

        x = self.forward(x)

        loss = self.fmix.loss(x, y)  # loss for the mixed batch
        return {'loss': loss}

See an example in test_lightning.py.

TensorFlow

For TensorFlow, we provide a class, FMix, in tensorflow_implementation.py that can be used in your TensorFlow code:

import tensorflow as tf
from implementations.tensorflow_implementation import FMix

fmix = FMix()

def loss(model, x, y, training=True):
    x = fmix(x)  # mix the input batch with FMix
    y_ = model(x, training=training)
    return tf.reduce_mean(fmix.loss(y_, y))

See an example in test_tensorflow.py.

Pre-trained Models

We provide pre-trained models via torch.hub (more coming soon). To use them, run

import torch
model = torch.hub.load('ecs-vlc/FMix:master', ARCHITECTURE, pretrained=True)

where ARCHITECTURE is one of the following:

CIFAR-10

PreAct-ResNet-18

| Configuration | ARCHITECTURE | Accuracy (%) |
|---|---|---|
| Baseline | 'preact_resnet18_cifar10_baseline' | -------- |
| + MixUp | 'preact_resnet18_cifar10_mixup' | -------- |
| + FMix | 'preact_resnet18_cifar10_fmix' | -------- |
| + MixUp + FMix | 'preact_resnet18_cifar10_fmixplusmixup' | -------- |

PyramidNet-200

| Configuration | ARCHITECTURE | Accuracy (%) |
|---|---|---|
| Baseline | 'pyramidnet_cifar10_baseline' | 98.31 |
| + MixUp | 'pyramidnet_cifar10_mixup' | 97.92 |
| + FMix | 'pyramidnet_cifar10_fmix' | 98.64 |

ImageNet

ResNet-101

| Configuration | ARCHITECTURE | Top-1 Accuracy (%) |
|---|---|---|
| Baseline | 'renset101_imagenet_baseline' | 76.51 |
| + MixUp | 'renset101_imagenet_mixup' | 76.27 |
| + FMix | 'renset101_imagenet_fmix' | 76.72 |

Contributors

ethanwharris, mattpainter01, antoniamarcu
