cutmix's Introduction

cutmix

A ready-to-use PyTorch extension providing an unofficial CutMix implementation.

This re-implementation improves on the original in several ways:

  • Fixes issue #1 in the original repository.
  • Issue #3: Crop regions are chosen randomly per sample, even within the same batch.
  • Issue #4: Lambda values (the sizes of crop regions) are chosen randomly per sample, even within the same batch.
  • Images to be cropped are chosen randomly from the whole dataset. The original implementation selects images only from within the same batch (by shuffling it).
  • Easy to install and use in your existing project.
  • With additional augmentations (Fast AutoAugment), performance improves further.

As a result, training results may also be slightly better.
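The per-sample sampling described in the list above can be sketched in plain Python. This is a minimal illustration, assuming the box-sampling rule from the CutMix paper (a box whose side ratio is sqrt(1 - lambda), centred uniformly at random and clipped to the image). The function name `sample_cutmix_box` and its signature are hypothetical and are not part of this package.

```python
import math
import random


def sample_cutmix_box(height, width, beta=1.0, rng=random):
    """Sample one crop box and the label weight it implies (illustrative only)."""
    lam = rng.betavariate(beta, beta)      # mixing ratio drawn from Beta(beta, beta)
    cut_ratio = math.sqrt(1.0 - lam)       # box side relative to the image side
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)

    # The box centre is uniform over the image; the box is clipped at the borders.
    cy, cx = rng.randrange(height), rng.randrange(width)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)

    # Recompute lambda from the clipped box so that the label weight matches
    # the fraction of pixels actually kept from the first image.
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (height * width)
    return (y1, x1, y2, x2), lam


# Drawing a fresh box for every sample gives each image in a batch its own
# region and its own lambda.
rng = random.Random(0)
boxes = [sample_cutmix_box(32, 32, beta=1.0, rng=rng) for _ in range(4)]
```

Calling the function once per sample, rather than once per batch, is what gives every image its own crop region and lambda.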

Requirements

  • python3
  • torch >= 1.1.0

Install

This repository is pip-installable:

$ pip install git+https://github.com/ildoonet/cutmix

Alternatively, copy the 'cutmix' folder into your project.

Usage

Our CutMix is inherited from the PyTorch Dataset class, so you can wrap your own dataset (e.g. CIFAR-10, ImageNet, ...). We also provide CutMixCrossEntropyLoss, a soft version of the cross-entropy loss that accepts the soft labels required by CutMix.

from cutmix.cutmix import CutMix
from cutmix.utils import CutMixCrossEntropyLoss
...

dataset = datasets.CIFAR100(args.cifarpath, train=True, download=True, transform=transform_train)
# This is the paper's original setting for CIFAR.
dataset = CutMix(dataset, num_class=100, beta=1.0, prob=0.5, num_mix=2)
...

criterion = CutMixCrossEntropyLoss(True)
for _ in range(num_epoch):
    # input is the normalized tensor of the CutMix-ed image;
    # target is a soft label made by mixing 2 or more labels.
    for input, target in loader:
        output = model(input)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
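
To make the soft-label loss concrete, here is a small plain-Python sketch of cross-entropy against a soft target, i.e. the negative sum of the target weights times the log-softmax outputs. It illustrates the idea only and is not the code of CutMixCrossEntropyLoss; the helper names `log_softmax`, `soft_cross_entropy` and `mix_labels` are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def soft_cross_entropy(logits, soft_target):
    """Cross-entropy between one prediction and one soft label."""
    return -sum(t * lp for t, lp in zip(soft_target, log_softmax(logits)))


def mix_labels(onehot_a, onehot_b, lam):
    """Soft label for an image that keeps a fraction `lam` of image A."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(onehot_a, onehot_b)]


logits = [2.0, 0.5, -1.0]
# 70% of the pixels come from a class-0 image, 30% from a class-1 image.
target = mix_labels([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], lam=0.7)
loss = soft_cross_entropy(logits, target)
```

With a one-hot target this reduces to the ordinary cross-entropy for that class, and with a mixed target it is the lambda-weighted sum of the per-class losses.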

Result

PyramidNet-200 + ShakeDrop + CutMix on CIFAR-100

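|                         | Top-1 Error (@300epoch) | Top-1 Error (Best) | Model File      |
|-------------------------|-------------------------|--------------------|-----------------|
| Paper's Reported Result | N/A                     | 13.81              | N/A             |
| Our Re-implementation   | 13.68                   | 13.15              | Download(12.88) |
| + Fast AutoAugment      | 13.3                    | 12.95              |                 |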

We ran 6 independent experiments with our re-implemented code and got top-1 errors of 13.09, 13.29, 13.27, 13.24, 13.15 and 12.88, using the command below. (At epoch 300, the runs converged to top-1 errors of 13.55, 13.66, 13.95, 13.9, 13.8 and 13.32.)

$ python train.py -c conf/cifar100_pyramid200.yaml

ResNet + CutMix on ImageNet

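| Model     |                                            | Top-1 Error (@300epoch) | Top-1 Error (Best) | Model File      |
|-----------|--------------------------------------------|-------------------------|--------------------|-----------------|
| ResNet18  | Reported Result (w/o CutMix)               | N/A                     | 30.43              |                 |
|           | Ours                                       | 29.674                  | 29.56              |                 |
| ResNet34  | Reported Result (w/o CutMix)               | N/A                     | 26.456             |                 |
|           | Ours                                       | 24.7                    | 24.57              | Download        |
| ResNet50  | Paper's Reported Result                    | N/A                     | 21.4               | N/A             |
|           | Author's Code (Our Re-run)                 | 21.768                  | 21.586             | N/A             |
|           | Our Re-implementation                      | 21.524                  | 21.340             | Download(21.25) |
| ResNet200 | Our Re-implementation + Fast AutoAugment   | 19.058                  | 18.858             |                 |

$ python train.py -c conf/imagenet_resnet50.yaml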

We ran 5 independent experiments on ResNet50.

  • Author's codes

    • 300epoch : 21.762, 21.614, 21.762, 21.644, 21.810
    • best : 21.56, 21.556, 21.666, 21.498, 21.648
  • Our Re-implementation

    • 300epoch : 21.53, 21.408, 21.55, 21.4, 21.73
    • best : 21.392, 21.328, 21.386, 21.256, 21.34
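
For a quick summary of the five ResNet50 runs listed above, the per-run numbers can be averaged with the standard library. This is only arithmetic on the values given here; the dictionary keys and variable names are illustrative.

```python
from statistics import mean, stdev

# Top-1 errors copied from the lists above (five runs each).
runs = {
    "author_300epoch": [21.762, 21.614, 21.762, 21.644, 21.810],
    "author_best": [21.56, 21.556, 21.666, 21.498, 21.648],
    "ours_300epoch": [21.53, 21.408, 21.55, 21.4, 21.73],
    "ours_best": [21.392, 21.328, 21.386, 21.256, 21.34],
}

# Mean and sample standard deviation per configuration.
summary = {name: (mean(vals), stdev(vals)) for name, vals in runs.items()}
for name, (avg, sd) in summary.items():
    print(f"{name}: mean={avg:.3f} std={sd:.3f}")
```

Reporting the spread alongside the mean makes it easier to judge whether the gap between two implementations is larger than the run-to-run variation.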

cutmix's Issues

Reproducibility on Resnet-110 on Cifar-100

Hello,
Thank you for the nice, highly modular code.

I ran your code on CIFAR-100 using ResNet-110.
I used the default parameters given in clovaai/CutMix-PyTorch#10:

dataset: cifar100
net_type: resnet
depth: 110

epochs: 300
batch_size: 64
lr: 0.1
momentum: 0.9

weight_decay: 0.0001

cutmix:
  beta: 1.0
  prob: 1.0
  num: 1

Here is what I got:
try1: 20.36 4.66
try2: 20.21 4.67
try3: 20.30 4.65

However, 20.11 and 4.43 are reported in the original paper (Table 6).
Did you manage to reproduce those scores? Maybe the gap is due to the slight differences between your implementation and the official one (diverse lambdas, diverse crops, and the fixed boundary effect)?

'CutMix' object is not callable

I have installed cutmix and tried to use it for training on CIFAR100:

def load_dataset(dataset):
    train_transform = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(size=32, padding=4),  # ImageNet 224
        T.ToTensor(),
        T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),  # T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100

        CutMixCollator(alpha=0.5)
    ])
    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    if dataset == 'cifar100':
        data_train = CIFAR100('../cifar100', train=True, download=True, transform=train_transform)
        data_unlabeled = MyDataset(dataset, True, test_transform)
        data_test = CIFAR100('../cifar100', train=False, download=True, transform=test_transform)
        NO_CLASSES = 100
        adden = ADDENDUM
        no_train = NUM_TRAIN

in default_collate
return torch.stack(batch, 0, out=out)
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

Dimension issue in correct prediction

I am trying to use CutMix in my network, but when I run the code I get an error on the line
correct += (lam * predicted.eq(targets.data).cpu().sum()
about the dimensions of the predicted and target values:
predicted.size() = [64]
targets.size() = [64,7], because the target is one-hot. How can I match the dimensions of the two?
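
One common way to compare class-index predictions with one-hot or soft targets is to reduce each target row to its arg-max class first. The sketch below shows the idea with plain Python lists; with PyTorch tensors the analogous reduction would be `targets.argmax(dim=1)`. The function name `accuracy_from_soft_targets` is hypothetical, and counting only the dominant label as correct is an assumption about what the metric should mean for mixed images.

```python
def argmax(row):
    """Index of the largest entry in a list."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_soft_targets(predicted, targets):
    """Fraction of samples whose predicted class equals the dominant target class.

    predicted: list of class indices, length N
    targets:   N rows of per-class weights (one-hot or CutMix soft labels)
    """
    hard = [argmax(row) for row in targets]   # shape [N, C] becomes [N]
    correct = sum(int(p == h) for p, h in zip(predicted, hard))
    return correct / len(predicted)


predicted = [0, 2, 1]
targets = [
    [1.0, 0.0, 0.0],   # plain one-hot label
    [0.3, 0.0, 0.7],   # mixed label, class 2 dominates
    [0.6, 0.4, 0.0],   # mixed label, class 0 dominates
]
acc = accuracy_from_soft_targets(predicted, targets)
```

After the reduction both sides have one entry per sample, so an element-wise equality check is well defined.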

Getting empty dataset on using Cutmix Dataloader

I used this class to create the dataset for my flower data.

Defining the dataset:

from PIL import Image
import cv2
import albumentations
import torch
import numpy as np
import io
from torch.utils.data import Dataset

class FlowerDataset(Dataset):
    def __init__(self, id , classes , image , img_height , img_width, mean , std , is_valid):
        self.id = id
        self.classes = classes
        self.image = image
        if is_valid == 1:
            self.aug = albumentations.Compose([
               albumentations.Resize(img_height , img_width, always_apply = True) ,
               albumentations.Normalize(mean , std , always_apply = True) 
            ])
        else:
            self.aug = albumentations.Compose([
                albumentations.Resize(img_height , img_width, always_apply = True) ,
                albumentations.Normalize(mean , std , always_apply = True),
                albumentations.ShiftScaleRotate(shift_limit = 0.0625,
                                                scale_limit = 0.1 ,
                                                rotate_limit = 5,
                                                p = 0.9)
            ]) 
        
    def __len__(self):
        return len(self.id)
    
    def __getitem__(self, index):
        id = self.id[index]
        img = np.array(Image.open(io.BytesIO(self.image[index]))) 
        img = cv2.resize(img, dsize=(128, 128), interpolation=cv2.INTER_CUBIC)
        img = self.aug(image = img)['image']
        img = np.transpose(img , (2,0,1)).astype(np.float32)
       
        
        return {
            'image' : torch.tensor(img, dtype = torch.float),
            'class' : torch.tensor(self.classes[index], dtype = torch.long) 
        } 

Then I did a sanity check to make sure it was good to go:

# sanity check for FlowerDataset class created

train_dataset = FlowerDataset(id = train_ids, classes = train_class, image = train_images, 
                        img_height = 128 , img_width = 128, 
                        mean = (0.485, 0.456, 0.406),
                        std = (0.229, 0.224, 0.225) , is_valid = 0)

import matplotlib.pyplot as plt
%matplotlib inline

idx = 0
img = train_dataset[idx]['image']

print(train_dataset[idx]['class'])

npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))

image

And then I did:

# setting up the dataloader with CutMix data augmentation
!pip install git+https://github.com/ildoonet/cutmix

# setting up the train data loader

from cutmix.cutmix import CutMix

train_dataloader = CutMix(train_dataset, 
                          num_class=104, 
                          beta=1.0, 
                          prob=0.5, 
                          num_mix=2)

This ran successfully, but when I did the sanity check:

batch = next(iter(train_dataloader))
len(batch)

it returned
image

and so I am unable to train the model.
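
One thing worth checking here, offered as a guess rather than a diagnosis: the Usage section iterates with `for input, target in loader`, which suggests that the wrapper expects every sample to be an (image, label) pair, whereas FlowerDataset returns a dict. In Python, unpacking a two-key dict into two names yields its keys, not its values. The sketch below shows that behaviour together with a small adapter that turns dict samples into pairs; `DictToPair` and `ToyDictDataset` are hypothetical names and are not part of cutmix.

```python
class DictToPair:
    """Wrap a dataset whose samples are dicts so that it yields (image, label) pairs."""

    def __init__(self, dataset, image_key="image", label_key="class"):
        self.dataset = dataset
        self.image_key = image_key
        self.label_key = label_key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        return sample[self.image_key], sample[self.label_key]


class ToyDictDataset:
    """Stand-in for a dataset that returns dict samples."""

    def __len__(self):
        return 3

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        return {"image": [float(index)], "class": index}


raw = ToyDictDataset()
a, b = raw[0]                      # unpacking a dict gives its keys
img, label = DictToPair(raw)[0]    # the adapter gives the values
```

Since CutMix is described above as a Dataset rather than a loader, the adapted dataset would presumably be wrapped in CutMix first and then passed to a `torch.utils.data.DataLoader` for batching.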

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I am getting this error from the loss function.

Setting up the training function:

loss_fn = CutMixCrossEntropyLoss(True)
if __name__ == "__main__":
    
    set_parameters_requires_grad(model , True)

    epochs = 25

    for epoch in range(epochs):
        print('Epoch ', epoch,'/',epochs-1)
        print('-'*15)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            
            # Iterate over data.
            for inputs,labels in dataloaders[phase]:
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = loss_fn(outputs, labels)

                    # we backpropagate to set our learning parameters only in training mode
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == torch.argmax(labels)) # (preds == labels.data) as the usage of .data is not recommended, as it might have unwanted side effect.

            # scheduler for weight decay
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / float(dataset_sizes[phase])
            epoch_acc = running_corrects / float(dataset_sizes[phase])

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    optimizer.swap_swa_sgd()       

image
