yaohungt / barlow-twins-hsic
License: MIT License
Hi, have you run the model on Tiny-ImageNet? Could you share your results? I tried running the model on Tiny-ImageNet myself, but the accuracy seems too low.
Hi, your code is very helpful, and I first want to thank you for sharing it.
I have a question about whether the feature normalization (Line 31 in a30baba) is necessary to bring the CIFAR-10 performance to about 92% accuracy.
The original Barlow Twins does not contain this step; instead, they define all linear layers in the projector with no bias.
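For reference, a bias-free projector in the style the original Barlow Twins paper describes might look like the sketch below (the layer widths are illustrative assumptions, not this repo's actual configuration):
import torch.nn as nn

# Sketch of a Barlow Twins-style projector: all linear layers are bias-free,
# with BatchNorm + ReLU between them (widths here are illustrative only).
projector = nn.Sequential(
    nn.Linear(512, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
)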
Hello --
I'm interested in trying to reproduce the Barlow Twins curve from Fig 2 in the paper.
I'm running:
python main.py --lmbda 0.0078125 --corr_zero --batch_size 128 --feature_dim 128 --dataset cifar10
and getting:
Test Epoch: [5/1000] Acc@1:47.33% Acc@5:92.49%
Test Epoch: [10/1000] Acc@1:53.80% Acc@5:94.87%
Test Epoch: [15/1000] Acc@1:58.44% Acc@5:96.68%
Test Epoch: [20/1000] Acc@1:63.11% Acc@5:96.86%
Test Epoch: [25/1000] Acc@1:65.55% Acc@5:97.33%
Test Epoch: [30/1000] Acc@1:66.59% Acc@5:97.61%
Test Epoch: [35/1000] Acc@1:68.85% Acc@5:97.87%
Test Epoch: [40/1000] Acc@1:69.17% Acc@5:97.75%
Test Epoch: [45/1000] Acc@1:71.24% Acc@5:98.16%
Test Epoch: [50/1000] Acc@1:72.38% Acc@5:98.26%
In Fig 2, it looks like accuracy after 50 epochs should be ~79%, but I'm only getting to ~72%.
Any ideas why there might be a gap? Perhaps the accuracies reported in Fig 2 come from training a linear classifier (e.g., in linear.py) rather than from the weighted kNN in main.py:train?
Thanks!
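For readers comparing the two protocols, here is a minimal sketch of SimCLR-style weighted-kNN classification over a feature bank (the function name and defaults below are my assumptions, not this repo's exact code):
import torch

# Sketch: weighted-kNN prediction over a memory bank of training features.
# feature:        [B, D] L2-normalized query features
# feature_bank:   [D, N] L2-normalized training-set features
# feature_labels: [N] integer class labels (dtype long)
def weighted_knn_predict(feature, feature_bank, feature_labels,
                         n_classes, k=200, temperature=0.5):
    sim = torch.mm(feature, feature_bank)            # [B, N] cosine similarities
    sim_weight, sim_indices = sim.topk(k=k, dim=-1)  # keep k nearest neighbours
    sim_weight = (sim_weight / temperature).exp()    # temperature-scaled weights
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    one_hot = torch.zeros(feature.size(0) * k, n_classes,
                          device=sim_labels.device)
    one_hot.scatter_(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    scores = (one_hot.view(feature.size(0), -1, n_classes)
              * sim_weight.unsqueeze(dim=-1)).sum(dim=1)
    return scores.argsort(dim=-1, descending=True)   # classes ranked by score
A linear probe, by contrast, trains an extra classifier on frozen features, so the two protocols generally report different accuracies; that alone could explain part of the gap.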
Not really an issue, but adding the following made training significantly faster (+29% on a Titan X Pascal). It lets cuDNN auto-tune its convolution algorithms, which pays off when input shapes stay constant across iterations:
if torch.cuda.is_available():
    # let cuDNN benchmark and cache the fastest conv algorithms
    torch.backends.cudnn.benchmark = True
I thought I would mention it as others might benefit from this as well.
Thanks for making your code available on GitHub!
The implementation is simple and easy to use. Thank you for that. I have one question.
Given a mini-batch with input x of size BxCxHxW, we apply transformations to get
y1 = self.transform(x)
y2 = self.transform(x)
Is this a batch-wise transformation or an image-wise transformation?
The paper says: "More specifically, it produces two distorted views for all images of a batch X sampled from a dataset." Since there are only two distorted views, I interpret this as applying the same transformation to all images in a batch for each view.
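For context, repositories in this family typically apply the transform per image inside the Dataset, so the random augmentation parameters are re-sampled for every image and for each of the two views. A minimal sketch, assuming a torchvision-style pair dataset (the class name PairCIFAR10 is hypothetical, not necessarily this repo's exact code):
from PIL import Image
from torchvision import transforms
from torchvision.datasets import CIFAR10

class PairCIFAR10(CIFAR10):
    # Returns two independently augmented views of the same image.
    def __getitem__(self, index):
        img = Image.fromarray(self.data[index])
        y1 = self.transform(img)  # random crop/flip parameters sampled here
        y2 = self.transform(img)  # and re-sampled independently here
        return y1, y2, self.targets[index]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])
dataset = PairCIFAR10(root='data', train=True, transform=train_transform, download=True)
Because __getitem__ runs once per image, each image in the batch receives its own random distortion, even though all images share the same transform pipeline.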
Hey,
first, thanks for sharing your research and code!
While working with it, I noticed a small inaccuracy in the calculation of the cross-correlation matrix.
As opposed to the original implementation, which uses BatchNorm1d, you implement the normalization with torch.std (Lines 36 to 38 in a30baba). torch.std applies Bessel's correction by default, which prevents the values on the diagonal from reaching 1.
I wrote a small test with two identical vectors coming from the projection head and was therefore expecting exact ones on the diagonal. But as my attached code shows, with Bessel's correction the diagonal equals (n-1)/n, so for a batch size of 4 the values cannot exceed 0.75. When torch.std is called with unbiased=False, the numbers match the original implementation.
I think there is no practical difference for batch sizes ≥ 32 (where (n-1)/n ≈ 0.97), which is also the smallest batch size presented in your paper.
import torch
from torch import nn
batch_size = 4
size_z = 128
torch.manual_seed(1234)
z1 = torch.randn(batch_size, size_z)
z2 = z1.clone()
# your implementation
z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)
cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size
print( cross_corr[0:5,0:5] )
# tensor([[ 0.7500, -0.1065, -0.0837, 0.0630, 0.3664],
# [-0.1065, 0.7500, -0.2283, -0.3708, -0.5607],
# [-0.0837, -0.2283, 0.7500, -0.5013, -0.2554],
# [ 0.0630, -0.3708, -0.5013, 0.7500, 0.6334],
# [ 0.3664, -0.5607, -0.2554, 0.6334, 0.7500]])
# original implementation
bn = nn.BatchNorm1d(size_z, affine=False)
z1_norm = bn(z1)
z2_norm = bn(z2)
cross_corr = z1_norm.T @ z2_norm / batch_size
print( cross_corr[0:5,0:5] )
# tensor([[ 1.0000, -0.1420, -0.1116, 0.0840, 0.4885],
# [-0.1420, 1.0000, -0.3043, -0.4944, -0.7476],
# [-0.1116, -0.3043, 1.0000, -0.6683, -0.3405],
# [ 0.0840, -0.4944, -0.6683, 1.0000, 0.8445],
# [ 0.4885, -0.7476, -0.3405, 0.8445, 1.0000]])
# corrected code (without Bessel’s correction for the calculation of the standard deviation)
z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0, unbiased=False)
z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0, unbiased=False)
cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size
print( cross_corr[0:5,0:5] )
# tensor([[ 1.0000, -0.1421, -0.1116, 0.0840, 0.4885],
# [-0.1421, 1.0000, -0.3043, -0.4944, -0.7476],
# [-0.1116, -0.3043, 1.0000, -0.6683, -0.3405],
# [ 0.0840, -0.4944, -0.6683, 1.0000, 0.8446],
# [ 0.4885, -0.7476, -0.3405, 0.8446, 1.0000]])
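For completeness: nn.BatchNorm1d normalizes with the biased variance (dividing by n rather than n-1), which is why calling torch.std with unbiased=False reproduces its output exactly.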