
barlow-twins-hsic's People

Contributors

yaohungt


barlow-twins-hsic's Issues

Result about tiny-imagenet

Hi, have you run the model on Tiny ImageNet? Could you share your results? I tried running the model on Tiny ImageNet, but the accuracy seems too low.

Is the feature normalization necessary?

return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)

Hi, your code is very helpful; first of all, thank you for sharing it.

I have a question about whether this feature normalization is necessary (to reach roughly 92% accuracy on CIFAR-10).

The original Barlow Twins does not include this step; instead, it defines all linear layers in the projector without a bias term.
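To make the question concrete, here is a minimal sketch of what the line in question does, using a toy feature tensor rather than the repo's actual encoder output:

```python
import torch
import torch.nn.functional as F

# Toy projector outputs (hypothetical shapes: batch of 4, feature dim 8)
torch.manual_seed(0)
out = torch.randn(4, 8)

# F.normalize(out, dim=-1) L2-normalizes each row, so every returned
# feature vector lies on the unit sphere:
out_norm = F.normalize(out, dim=-1)
print(out_norm.norm(dim=-1))  # each row norm is 1.0 (up to float error)
```

This only rescales each vector; whether that rescaling matters for the downstream accuracy is exactly what the issue asks.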

Question re: reproducing Fig 2 from the paper

Hello --

I'm interested in trying to reproduce the Barlow Twins curve from Fig 2 in the paper.

I'm running:

python main.py --lmbda 0.0078125 --corr_zero --batch_size 128 --feature_dim 128 --dataset cifar10

and getting:

Test Epoch: [5/1000] Acc@1:47.33% Acc@5:92.49%
Test Epoch: [10/1000] Acc@1:53.80% Acc@5:94.87%
Test Epoch: [15/1000] Acc@1:58.44% Acc@5:96.68%
Test Epoch: [20/1000] Acc@1:63.11% Acc@5:96.86%
Test Epoch: [25/1000] Acc@1:65.55% Acc@5:97.33%
Test Epoch: [30/1000] Acc@1:66.59% Acc@5:97.61%
Test Epoch: [35/1000] Acc@1:68.85% Acc@5:97.87%
Test Epoch: [40/1000] Acc@1:69.17% Acc@5:97.75%
Test Epoch: [45/1000] Acc@1:71.24% Acc@5:98.16%
Test Epoch: [50/1000] Acc@1:72.38% Acc@5:98.26%

In Fig 2, it looks like the accuracy after 50 epochs should be ~79%, but I'm only reaching ~72%.

Any ideas why there might be a gap? Perhaps the accuracies reported in Fig 2 come from training a linear classifier (e.g., in linear.py) rather than from the weighted kNN used in main.py:train?
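For reference, a minimal sketch of the two evaluation protocols being compared, using random tensors in place of real encoder outputs (all shapes and hyperparameters here are illustrative, not the repo's actual settings):

```python
import torch
from torch import nn

torch.manual_seed(0)
feat_dim, num_classes, n = 128, 10, 256
features = torch.randn(n, feat_dim)            # stand-in for frozen encoder outputs
labels = torch.randint(0, num_classes, (n,))

# Linear probe (linear.py style): train only a linear layer on frozen features.
probe = nn.Linear(feat_dim, num_classes)
opt = torch.optim.Adam(probe.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
for _ in range(10):
    opt.zero_grad()
    loss_fn(probe(features), labels).backward()
    opt.step()

# Weighted kNN (main.py:train style): no trained head, just cosine similarity
# against a feature bank, with class votes weighted by similarity.
query = torch.randn(1, feat_dim)
sim = torch.mm(nn.functional.normalize(query, dim=-1),
               nn.functional.normalize(features, dim=-1).t())
weights, idx = sim.topk(k=5, dim=-1)
votes = torch.zeros(num_classes).scatter_add_(0, labels[idx[0]], weights[0])
pred = votes.argmax()
```

The two protocols can legitimately report different numbers on the same checkpoint, which would be one plausible explanation for the gap.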

Thanks!

Speed Up Model by Using cudnn benchmark

Not really an issue, but adding the following made training significantly faster (+29% on a Titan X Pascal). cudnn.benchmark lets cuDNN auto-tune its convolution algorithms, which pays off when input sizes stay constant across iterations:

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

I thought I would mention it as others might benefit from this as well.

Thanks for making your code available on GitHub!

Question Regarding Transform

The implementation is simple and easy to use; thank you for that. I have one question.

Given a mini-batch input x of size B×C×H×W, we apply transformations to get

y1 = self.transform(x)
y2 = self.transform(x)

Is this a batch-wise transformation or an image-wise transformation?

As per the paper, "More specifically, it produces two distorted views for all images of a batch X sampled from a dataset", there are only two distorted views. I interpret this as: for one distorted view, the same transformation is applied to all images in the batch.
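As an illustration of the distinction being asked about, here is a toy sketch (random_flip is a hypothetical stand-in, not the repo's actual augmentation pipeline) of an image-wise transform: each image in the batch draws its own random parameters, so the distortion is not shared batch-wide.

```python
import torch

def random_flip(x, p=0.5):
    """Flip each image in a B x C x H x W batch independently."""
    mask = torch.rand(x.size(0)) < p   # one coin flip per image
    out = x.clone()
    out[mask] = out[mask].flip(-1)     # horizontal flip of the selected images
    return out

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)
y1 = random_flip(x)  # one independent draw of per-image parameters
y2 = random_flip(x)  # another independent draw
```

A batch-wise transform would instead draw the random parameters once and apply the same distortion to every image in the batch; which of the two the repo does is what this issue asks the author to clarify.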

Inaccuracy in the cross correlation for small batch sizes (<32)

Hey,

first, thanks for sharing your research and code!

TL;DR

Your code uses torch.std, which applies Bessel's correction by default, preventing the values on the diagonal from reaching 1.

While working with it, I noticed a small inaccuracy in the calculation of the cross-correlation matrix.

As opposed to the original implementation, which uses BatchNorm1d, you implemented the normalization with:

# normalize the representations along the batch dimension
out_1_norm = (out_1 - out_1.mean(dim=0)) / out_1.std(dim=0)
out_2_norm = (out_2 - out_2.mean(dim=0)) / out_2.std(dim=0)

I implemented a small test with two identical vectors coming from the projection head, and was therefore expecting ones on the diagonal. But as you can see from my attached code, for a batch size < 32 (here 4), the values on the diagonal cannot exceed 0.75; in general, with the unbiased estimator the diagonal is capped at (n-1)/n for batch size n. I found that torch.std applies Bessel's correction by default. When this is disabled (unbiased=False), the numbers match the original implementation.

I think there is no practical difference for batch sizes above 32, which, if I remember correctly, is also the smallest batch size presented in the paper.

import torch
from torch import nn

batch_size = 4
size_z = 128

torch.manual_seed(1234)
z1 = torch.randn(batch_size, size_z)
z2 = z1.clone()

# your implementation
z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)
cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size

print( cross_corr[0:5,0:5] )

# tensor([[ 0.7500, -0.1065, -0.0837,  0.0630,  0.3664],
#        [-0.1065,  0.7500, -0.2283, -0.3708, -0.5607],
#        [-0.0837, -0.2283,  0.7500, -0.5013, -0.2554],
#        [ 0.0630, -0.3708, -0.5013,  0.7500,  0.6334],
#        [ 0.3664, -0.5607, -0.2554,  0.6334,  0.7500]])

# original implementation
bn = nn.BatchNorm1d(size_z, affine=False)
z1_norm = bn(z1)
z2_norm = bn(z2)
cross_corr = z1_norm.T @ z2_norm / batch_size

print( cross_corr[0:5,0:5] )

# tensor([[ 1.0000, -0.1420, -0.1116,  0.0840,  0.4885],
#        [-0.1420,  1.0000, -0.3043, -0.4944, -0.7476],
#        [-0.1116, -0.3043,  1.0000, -0.6683, -0.3405],
#        [ 0.0840, -0.4944, -0.6683,  1.0000,  0.8445],
#        [ 0.4885, -0.7476, -0.3405,  0.8445,  1.0000]])

# corrected code (without Bessel’s correction for the calculation of the standard deviation)
z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0, unbiased=False)
z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0, unbiased=False)
cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size

print( cross_corr[0:5,0:5] )

# tensor([[ 1.0000, -0.1421, -0.1116,  0.0840,  0.4885],
#         [-0.1421,  1.0000, -0.3043, -0.4944, -0.7476],
#         [-0.1116, -0.3043,  1.0000, -0.6683, -0.3405],
#         [ 0.0840, -0.4944, -0.6683,  1.0000,  0.8446],
#         [ 0.4885, -0.7476, -0.3405,  0.8446,  1.0000]])
