FMix
This repository contains the official implementation of the paper 'Understanding and Enhancing Mixed Sample Data Augmentation'
ArXiv • Papers With Code • About • Getting Started • Pre-trained Models (Coming Soon)
About
FMix is a variant of MixUp, CutMix, etc. introduced in our paper 'Understanding and Enhancing Mixed Sample Data Augmentation'. It uses masks sampled from Fourier space to mix training examples.
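To make the idea concrete, here is a simplified sketch of Fourier-space mask sampling: draw Gaussian noise over frequency bins, attenuate high frequencies with a 1/f^d filter, invert back to image space, and binarise at the lam-quantile. This is a sketch only — the function name `sample_fmix_mask` and its defaults are illustrative, not the API of `fmix.py`.

```python
import numpy as np

def sample_fmix_mask(shape, decay_power=3.0, alpha=1.0, seed=None):
    """Sketch of FMix-style mask sampling (illustrative, not fmix.py's API):
    low-frequency noise in Fourier space, binarised at the lam-quantile."""
    rng = np.random.default_rng(seed)
    h, w = shape
    # Mixing coefficient lam ~ Beta(alpha, alpha), as in MixUp.
    lam = rng.beta(alpha, alpha)
    # Complex Gaussian noise for every rfft2 frequency bin.
    noise = (rng.normal(size=(h, w // 2 + 1))
             + 1j * rng.normal(size=(h, w // 2 + 1)))
    # Attenuate high frequencies with a 1/f^decay_power filter.
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.rfftfreq(w)[None, :]
    freqs = np.sqrt(fy ** 2 + fx ** 2)
    spectrum = noise / np.maximum(freqs, 1.0 / max(h, w)) ** decay_power
    grey = np.fft.irfft2(spectrum, s=(h, w))
    # Binarise: the top lam-fraction of pixels become 1 (keep original image).
    threshold = np.quantile(grey, 1 - lam)
    mask = (grey > threshold).astype(np.float32)
    return mask, lam
```

Because the noise is dominated by low frequencies, the resulting binary mask forms large contiguous regions rather than salt-and-pepper speckle.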
Getting Started
The core implementation of FMix uses numpy and can be found in `fmix.py`. We provide bindings for this in PyTorch (with Torchbearer or PyTorch-Lightning) and TensorFlow.
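Whichever binding you use, the underlying operation is the same: each input is mixed with a permuted partner from the same batch under the sampled mask. A minimal numpy sketch of that step (the helper name `fmix_batch` is hypothetical, not part of `fmix.py`):

```python
import numpy as np

def fmix_batch(x, mask, seed=None):
    """Mix each image in a batch with a partner from the same batch:
    mask == 1 keeps the original pixel, mask == 0 takes the partner's."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(x))              # partner index for each example
    x_mixed = mask * x + (1 - mask) * x[perm]   # mask broadcasts over channels
    return x_mixed, perm

# Example: batch of 4 single-channel 8x8 images; the mask keeps the
# left half of each image and takes the right half from its partner.
x = np.arange(4 * 1 * 8 * 8, dtype=np.float32).reshape(4, 1, 8, 8)
mask = np.zeros((1, 8, 8), dtype=np.float32)
mask[..., :4] = 1.0
x_mixed, perm = fmix_batch(x, mask, seed=0)
```

The permutation indices are kept so that the loss can later be computed against both the original and the partner targets.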
Torchbearer
The `FMix` callback in `torchbearer_implementation.py` can be added directly to your torchbearer code:
```python
from implementations.torchbearer_implementation import FMix

fmix = FMix()
trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix])
```
See an example in `test_torchbearer.py`.
PyTorch-Lightning
For PyTorch-Lightning, we provide a class, `FMix`, in `lightning.py` that can be used in your `LightningModule`:
```python
from implementations.lightning import FMix


class CoolSystem(pl.LightningModule):
    def __init__(self):
        ...
        self.fmix = FMix()

    def training_step(self, batch, batch_nb):
        x, y = batch
        x = self.fmix(x)
        x = self.forward(x)
        loss = self.fmix.loss(x, y)
        return {'loss': loss}
```
See an example in `test_lightning.py`.
Tensorflow
For TensorFlow, we provide a class, `FMix`, in `tensorflow_implementation.py` that can be used in your TensorFlow code:
```python
from implementations.tensorflow_implementation import FMix

fmix = FMix()

def loss(model, x, y, training=True):
    x = fmix(x)
    y_ = model(x, training=training)
    return tf.reduce_mean(fmix.loss(y_, y))
```
See an example in `test_tensorflow.py`.
Pre-trained Models
COMING SOON