
Comments (13)

tmp-iclr commented on September 7, 2024

@dmezh We included the learning rate schedule in this repo, though you kind of have to hunt through the code to find it. The most important line is this one:

sched = lambda t, lr_max: np.interp([t], [0, self.t_initial*2//5, self.t_initial*4//5, self.t_initial], 
                                      [0, lr_max, lr_max/20.0, 0])[0]

Here, t_initial should be the total number of epochs you're going to train for, lr_max is your peak learning rate (we used 0.01 everywhere), and t should be current_epoch + (batch_idx + 1) / batches_per_epoch, with indices starting at 0.
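For reference, here's the same schedule as a tiny standalone function (a sketch based only on the line above; the example step at the bottom is made up):

import numpy as np

# Triangular schedule: ramp linearly up to lr_max over the first 2/5 of training,
# decay to lr_max/20 by 4/5 of training, then decay to 0 at the end.
def triangular_lr(t, t_initial, lr_max):
    return np.interp([t], [0, t_initial * 2 // 5, t_initial * 4 // 5, t_initial],
                     [0, lr_max, lr_max / 20.0, 0])[0]

# Example: LR at epoch 3, batch 10 of 390, for a 200-epoch run with lr_max = 0.01.
t = 3 + (10 + 1) / 390
print(triangular_lr(t, t_initial=200, lr_max=0.01))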

That said, I think you'll get approximately the same results using something more standard like cosine decay with one cycle.

Let me know if you have any other questions about your replication! Given the interest in our CIFAR-10 results, we'll try to release a more compact training script and model weights for it sometime soon (but in PyTorch).


tmp-iclr commented on September 7, 2024

Glad to help. I used https://github.com/knjcode/cifar2png to construct the dataset. I'll see if there's any difference from the one you linked.

By the way, the model I ran with your settings and -b 64 ended up getting 97% accuracy.


tmp-iclr commented on September 7, 2024

No problem! Glad we figured it out, and thanks.

I'm going to go ahead and close this issue, but feel free to reopen it or open another if you have more questions (likewise, @dmezh).


tmp-iclr commented on September 7, 2024

Thanks for pointing this out!

I think the key parameter we didn't clearly specify for CIFAR-10 is the allowable "scale" for random cropping. The default setting in timm allows the crop to be as low as 8% of the original area of the image (before being resized back to the original shape). We thought this didn't make sense for 32x32 CIFAR-10 images, so we changed it to 75%. It would also probably be a good idea to specify the CIFAR-10 mean and standard deviation, though I don't think this will change much.

In particular, try adding the following flags: --scale 0.75 1.0 --mean 0.4914 0.4822 0.4465 --std 0.2471 0.2435 0.2616.
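(For replication outside timm, those flags correspond roughly to the torchvision preprocessing below; this is just an illustrative sketch of the crop and normalization, not the full augmentation pipeline.)

import torchvision.transforms as T

# Restrict the random crop to 75-100% of the image area (timm's default lower bound is 8%)
# and normalize with CIFAR-10 channel statistics.
train_transform = T.Compose([
    T.RandomResizedCrop(32, scale=(0.75, 1.0)),
    T.ToTensor(),
    T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
])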

I've also updated the README to mention this.


K-H-Ismail commented on September 7, 2024

Hello, thanks for your help. Indeed, the crop parameter has an effect on the accuracy: with --scale 0.75 1.0 we could reach up to 93.94% accuracy with ConvMixer-256/16. This is still below the 96.74% announced in the paper. To reach that accuracy, I've tried different batch sizes, adding more epochs, etc., but I still couldn't manage to get it. Could you please give me a hint?
This is the timm command I used, and the model is implemented as in the paper:

sh distributed_train.sh 2 \
  --dataset cifar10 \
  /path/CIFAR-10-images/ \
  --train-split /path/CIFAR-10-images/train \
  --val-split /path/CIFAR-10-images/test \
  --model convmixer_256_16 \
  -b 128 \
  -j 2 \
  --opt adamw \
  --epochs 200 \
  --amp \
  --input-size 3 32 32 \
  --lr 0.01 \
  --num-classes 10 \
  --warmup-epochs 0 \
  --weight-decay 0.01 \
  --sched onecycle \
  --opt-eps=1e-3 \
  --clip-grad 1.0 \
  --scale 0.75 1.0 \
  --mean 0.4914 0.4822 0.4465 \
  --std 0.2471 0.2435 0.2616 \
  --aa rand-m9-mstd0.5-inc1 \
  --cutmix 0.5 \
  --mixup 0.5 \
  --reprob 0.25 \
  --remode pixel


tmp-iclr commented on September 7, 2024

What patch size and kernel size are you using? I think this should easily achieve >96% if patch_size=1 and around 95% for patch_size=2 for large kernels. We used batch size 128 (whereas yours is 2*128), but I'm not sure if that would cause such a big difference. I have trained a ConvMixer-256/16 with patch_size=1 and kernel_size=9 on CIFAR-10 with almost the same settings (except for batch size) that achieved 96% by epoch 140/200.

I'll see if increasing the batch size would actually have such a significant effect and get back to you.


K-H-Ismail commented on September 7, 2024

Hello, I use a patch size of 1 and a kernel size of 9. I tried a smaller batch size (64) and it changed nothing.

import torch.nn as nn

# Residual wrapper, as defined in the ConvMixer paper
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

def ConvMixer(dim, depth, kernel_size=9, dilation=1, patch_size=7, n_classes=1000):
    return nn.Sequential(
        # patch embedding
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
                # depthwise convolution with residual connection
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, dilation=dilation, groups=dim, padding=dilation * (kernel_size - 1) // 2),
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )),
                # pointwise convolution
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim)
        ) for i in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

and I register the model like this:

from timm.models.registry import register_model  # timm's model registry

@register_model
def convmixer_256_16(pretrained=False, **kwargs):
    model = ConvMixer(256, 16, kernel_size=9, patch_size=1, n_classes=10)
    model.default_cfg = _cfg  # _cfg: default config dict defined elsewhere in the model file
    return model
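For reference, once that registration is imported, the model can be instantiated by name through timm (a minimal sketch):

import timm

# assumes the convmixer_256_16 registration above has already been imported
model = timm.create_model("convmixer_256_16", pretrained=False)
print(sum(p.numel() for p in model.parameters()))  # quick parameter-count sanity check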


dmezh commented on September 7, 2024

Hi, the paper mentions using a "simple triangular learning rate schedule". We're trying to replicate your work on CIFAR-10 (in TensorFlow), and we're wondering which LR schedule and parameters you used for the results in Table B. Thank you!


tmp-iclr commented on September 7, 2024

@K-H-Ismail I'm currently training the same model as you, using a freshly cloned version of this repo with the same parameters other than the batch size, which I have set to -b 64. It's on epoch 114/200 and has already reached 94.5% accuracy. The reason for the difference isn't clear to me...


K-H-Ismail commented on September 7, 2024

Hello @tmp-iclr,
Thanks for your time and support. The only thing we have not checked so far that may differ is the dataset itself: since timm uses raw images for ImageNet, while the official PyTorch CIFAR-10 dataset comes as pickled batches rather than image files, I downloaded the raw CIFAR-10 images from this repository:
https://github.com/YoongiKim/CIFAR-10-images

Is it the same for you?


tmp-iclr commented on September 7, 2024

So I'm training the same model on the dataset you used, and it does indeed seem to be lagging behind by a few percent. It's too early to say for sure, but this might be the problem...


tmp-iclr commented on September 7, 2024

Upon inspection, it looks like the CIFAR dataset you used has substantial JPEG artifacts -- the images actually look noticeably less sharp and colorful. I'm now pretty sure the dataset discrepancy is, in fact, the problem.
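(If you want to check your own copy, here is a rough sketch comparing channel statistics of the official torchvision CIFAR-10 against an exported image-folder copy; the folder path below is a placeholder.)

import numpy as np
from torchvision import datasets

# Official CIFAR-10 training set vs. a re-exported image-folder copy (placeholder path).
# Heavy JPEG re-encoding tends to shift these statistics and wash out fine detail.
official = datasets.CIFAR10(root="./data", train=True, download=True)
folder = datasets.ImageFolder("/path/CIFAR-10-images/train")

off = np.stack([np.asarray(img) for img, _ in official]).astype(np.float32) / 255.0
fol = np.stack([np.asarray(img.convert("RGB")) for img, _ in folder]).astype(np.float32) / 255.0

print("official mean/std:", off.mean(axis=(0, 1, 2)), off.std(axis=(0, 1, 2)))
print("exported mean/std:", fol.mean(axis=(0, 1, 2)), fol.std(axis=(0, 1, 2)))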


K-H-Ismail commented on September 7, 2024

Hello,
Indeed, the dataset was the problem, sorry for that! Usually I am not very lucky when reproducing baselines 🤷‍♀️. I will open an issue on the repository https://github.com/YoongiKim/CIFAR-10-images. Thanks for your help, and very good article by the way!

