xxxnell / how-do-vits-work
(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
Home Page: https://arxiv.org/abs/2202.06709
License: Apache License 2.0
Thank you for the amazing work! The summary organizes the main points of the paper really well and helps facilitate future research.
I would like to visualize the loss landscape of my model, and I am trying the Python notebook in this repository on Colab. It appears to be a long process, so I would like a rough estimate of how long it will take for my model.
Could you kindly tell me how long it took you to run the notebook on Colab or on your machine?
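Hello. I enjoyed reading your paper. Thank you for the great work.
I would like to ask a question about ∆ Log amplitude.
In the graph of Figure 2(a), I am a little confused about what exactly the ∆ Log amplitude at a particular frequency means.
For example, for ResNet the ∆ Log amplitude at 0.5π is about -6, and I understood this -6 to be the relative difference in magnitude between the amplitude at 0.0π and the amplitude at 0.5π.
However, reading the sentence in Figure 2, "∆ Log amplitude is the difference between the log amplitude at normalized frequency 0.0π (center) and at 1.0π (boundary).", I thought that ∆ Log amplitude simply means, everywhere in the graph, the relative difference in magnitude between the amplitude at 0.0π and the amplitude at 1.0π.
So my understanding of ∆ Log amplitude seems to differ from what is written in the paper, and I would like to ask whether I am misunderstanding it.
Thank you in advance.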
Hi, awesome work and really good points about MSAs! I'm very interested in the AlterNet mentioned in the paper (based on ResNet-50 and SwinTBlock), but I can't find its implementation in this repo. Did I miss it? If not, could you release the code?
Thanks a lot!
https://user-images.githubusercontent.com/930317/158025258-e9a5a454-99de-4d22-bc93-b217cdf06abb.jpeg
Where can I find other pictures?
Looking at this figure, I see that the early layers of ResNet have many low-frequency components, and the deeper the layer, the more high-frequency components it contains. Am I interpreting this figure correctly?
If so, doesn't this somewhat contradict the popular belief (and visualizations) that early layers in a ConvNet tend to learn high-frequency components?
Hi, thank you for the great paper. Could you please release the code, or give an implementation example, for plotting the "Relative log amplitudes of Fourier transformed feature maps"? Thanks!
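A minimal NumPy sketch of one way to compute relative log amplitudes of Fourier-transformed feature maps, assuming the plotted quantity is the log amplitude at each frequency along the half-diagonal of the centered 2D spectrum minus the value at 0.0π, averaged over batch and channels. This is an illustration under that assumption, not the authors' code, and `relative_log_amplitude` is a hypothetical name. For ViT, token features of shape (batch, tokens, channels) would first be reshaped to (batch, channels, height, width), with the class token dropped.

```python
import numpy as np


def relative_log_amplitude(feature_map: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Log amplitude along the half-diagonal of a feature map's 2D spectrum,
    relative to the zero-frequency (0.0π) value.

    feature_map: (batch, channels, height, width), height == width (assumed even).
    Entry k is log|F| k diagonal steps away from 0.0π minus log|F| at 0.0π;
    the last entry corresponds to 1.0π (the Nyquist frequency).
    """
    h = feature_map.shape[-2]
    spectrum = np.fft.fft2(feature_map, axes=(-2, -1))    # 2D FFT of every channel
    log_amp = np.log(np.abs(spectrum) + eps)               # log amplitude
    log_amp = np.fft.fftshift(log_amp, axes=(-2, -1))      # zero frequency moved to index h // 2
    log_amp = log_amp.mean(axis=(0, 1))                    # average over batch and channels
    # Walk from the center back to index 0, which holds the Nyquist bin after fftshift;
    # for real inputs |F(-k)| = |F(k)|, so these entries cover 0.0π to 1.0π.
    diag = np.diagonal(log_amp)[h // 2::-1]
    return diag - diag[0]


# Usage: a constant (purely low-frequency) map has all its energy at 0.0π,
# so every higher frequency shows a large negative relative log amplitude.
curve = relative_log_amplitude(np.ones((2, 3, 8, 8)))
delta_log_amplitude = curve[-1]   # value at 1.0π minus value at 0.0π
```

Under this reading, each point on the curve is measured relative to 0.0π, and a single "∆ Log amplitude" number would be the last entry of the curve.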
In Figure 1 of the paper, the authors state that MSAs flatten the loss landscape. However, in "When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations", the authors state that ViTs converge to sharp local minima. Doesn't this contradict your findings?
Furthermore, the authors claim that "The magnitude of the Hessian eigenvalues of ViT is smaller than that of ResNet during training phase" (also Fig. 1). However, in the paper above, the "Hessian dominate eigenvalue" of ViT is "orders of magnitude larger than that of ResNet" (Table 1).
Hello! How is the relative log magnitude calculated? Is the first layer subtracted from the feature map of each layer?
Hello, I have a question about why you use ViT-S and ViT-Ti when the counterpart is ResNet-50, since these models are not the same size. I know you have explained this on OpenReview, but I want to know whether ViT-B's eigenvalue spectrum looks like ViT-Ti's in your paper, just stretched to the right.
I understood that in the loss landscape visualization the z-axis is NLL. I'm curious what the x-axis and y-axis mean. Of course, we can see in loss_landscapes.py how the x and y values participate in the calculation, but I don't have an intuitive understanding of it.
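import copy

import numpy as np
import torch

# Fragment of the grid loop in loss_landscapes.py; the enclosing function, model,
# ws0 (trained weights), bases (two direction dicts), tests and norm are defined elsewhere.
# Each grid point ratio = (x, y) loads the weights ws0 + x * bases[0] + y * bases[1].
xs = np.linspace(x_min, x_max, n_x)
ys = np.linspace(y_min, y_max, n_y)
ratio_grid = np.stack(np.meshgrid(xs, ys), axis=0).transpose((1, 2, 0))  # (n_y, n_x, 2)
print(ratio_grid)

metrics_grid = {}
for ratio in ratio_grid.reshape([-1, 2]):
    print(ratio)
    ws = copy.deepcopy(ws0)
    # Scale each direction by its coefficient, sum them, and add the trained weights.
    gs = [{k: r * bs[k] for k in bs} for r, bs in zip(ratio, bases)]
    gs = {k: torch.sum(torch.stack([g[k] for g in gs]), dim=0) + ws[k] for k in gs[0]}
    print(gs)
    model.load_state_dict(gs)
    print("Grid: ", ratio, end=", ")
    *metrics, cal_diag = tests.test(model, n_ff, dataset, transform=transform,
                                    cutoffs=cutoffs, bins=bins, verbose=verbose, period=period, gpu=gpu)
    l1, l2 = norm.l1(model, gpu).item(), norm.l2(model, gpu).item()
    metrics_grid[tuple(ratio)] = (l1, l2, *metrics)
return metrics_grid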
Thank you sincerely.
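A toy sketch of what the x- and y-axes mean, using a small logistic-regression model instead of a ViT: each axis is the coefficient of one direction vector in weight space, and the z value at (x, y) is the NLL evaluated at w0 + x * d1 + y * d2, the same combination the loop above builds from `ratio` and `bases`. The model, the data, and the way the directions are normalized here are assumptions for illustration only.

```python
import numpy as np


def nll(w: np.ndarray, X: np.ndarray, t: np.ndarray) -> float:
    """Mean negative log-likelihood of logistic regression with weights w and 0/1 labels t."""
    z = X @ w
    # -log sigmoid(z) = log(1 + e^-z) for t = 1; -log(1 - sigmoid(z)) = log(1 + e^z) for t = 0
    return float(np.mean(t * np.logaddexp(0.0, -z) + (1 - t) * np.logaddexp(0.0, z)))


def loss_grid(w0, d1, d2, X, t, xs, ys):
    """NLL at w0 + x * d1 + y * d2 for each x in xs (columns) and y in ys (rows)."""
    grid = np.empty((len(ys), len(xs)))
    for i, cy in enumerate(ys):
        for j, cx in enumerate(xs):
            grid[i, j] = nll(w0 + cx * d1 + cy * d2, X, t)
    return grid


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
t = (X @ rng.normal(size=5) > 0).astype(float)
w0 = rng.normal(size=5)                     # stands in for the trained weights ws0

# Two random directions, each rescaled to the norm of w0 (a simplification;
# the repo's bases may be normalized differently, e.g. per filter).
d1, d2 = rng.normal(size=(2, 5))
d1 *= np.linalg.norm(w0) / np.linalg.norm(d1)
d2 *= np.linalg.norm(w0) / np.linalg.norm(d2)

xs = ys = np.linspace(-1.0, 1.0, 11)        # the x- and y-axis values of the plot
grid = loss_grid(w0, d1, d2, X, t, xs, ys)  # the z-axis values (NLL)
```

The center of the grid, (0, 0), is the unperturbed model; moving along an axis means moving the weights along that axis's direction vector.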
Hello,
Thank you for your great work!
I wonder how you obtain the feature map variances. As I understand it, you first extract the representations of all samples, which gives a vector of length D per sample (let's just flatten the 2D tensor or concatenate all tokens). Then you compute the variance of each element of this vector over all samples, which gives D variances. Finally, you take the mean of these D variances and report it as the variance.
Did I understand correctly? Sorry if I missed this in your existing documentation or description.
Thank you and I'm looking forward to your reply.
Best,
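The procedure described in this question can be written as a short NumPy sketch. The helper name is hypothetical, and whether the paper uses exactly this definition (and population rather than sample variance) is an assumption.

```python
import numpy as np


def feature_map_variance(features: np.ndarray) -> float:
    """Mean over feature dimensions of the per-dimension variance across samples.

    features: array of shape (num_samples, ...), one representation per sample;
    everything after the first axis is flattened into a length-D vector.
    """
    flat = features.reshape(features.shape[0], -1)   # (N, D)
    per_dim_var = flat.var(axis=0)                   # D variances (population variance, ddof=0)
    return float(per_dim_var.mean())                 # average of the D variances


# Usage: 4 samples of a (channels=2, 3 x 3) feature map.
example = np.random.default_rng(0).normal(size=(4, 2, 3, 3))
print(feature_map_variance(example))
```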
Hi, while trying to set up alternet_18 to train on CIFAR10, I used the default config in models/alternet.py, which is the following:
AlterNet(preresnet_dnn.BasicBlock, AttentionBasicBlockB, stem=partial(StemB, pool=stem),
         num_blocks=(2, 2, 2, 2), num_blocks2=(0, 1, 1, 1), heads=(3, 6, 12, 24),
         num_classes=num_classes, name=name, **block_kwargs)
Doing so gives the following error:
Input tensor shape: torch.Size([128, 128, 4, 4]). Additional info: {'p1': 7, 'p2': 7}.
Shape mismatch, can't divide axis of length 4 in chunks of 7
which is thrown by
x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
in the LocalAttention class.
This happens because the default window size is 7, which does not divide the 4 x 4 feature map that 3 x 32 x 32 CIFAR10 images produce at this stage. Could you point me to a setup used to train AlterNet on CIFAR10/100 images?
Thank you
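One possible workaround, sketched under the assumption that a smaller window is acceptable for small inputs, is to pick the largest window size no bigger than the default that divides the feature map. `largest_valid_window` is a hypothetical helper, not part of the repo; the error message only shows that the failing map is 4 x 4, so the sizes at other stages are not known here.

```python
def largest_valid_window(h: int, w: int, preferred: int = 7) -> int:
    """Largest window size p <= preferred that divides both h and w."""
    for p in range(min(preferred, h, w), 0, -1):
        if h % p == 0 and w % p == 0:
            return p
    return 1   # only reached for empty maps; p = 1 divides every size


# Usage: the 4 x 4 map from the error above, with the default window size of 7.
window = largest_valid_window(4, 4, preferred=7)
```

With a 4 x 4 map the chosen window covers the whole map, so local attention at that stage becomes global attention over the map; whether this affects accuracy has not been checked here.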
Hello, I enjoyed reading your paper very much.
Relatedly, I would like to reproduce the Hessian max eigenvalue spectra as in the paper. Where can I find the code for this?
Thank you in advance for your answer!
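A minimal sketch of the usual building block for such plots: estimating the largest Hessian eigenvalue by power iteration using only Hessian-vector products. In PyTorch, a Hessian-vector product is typically obtained by differentiating the dot product of the gradient with v a second time via `torch.autograd.grad(..., create_graph=True)`; here a small explicit matrix stands in for it so the example runs on its own. Whether the paper used this exact routine is not stated in this thread, and `top_hessian_eigenvalue` is a hypothetical name.

```python
import numpy as np
from typing import Callable


def top_hessian_eigenvalue(hvp: Callable[[np.ndarray], np.ndarray], dim: int,
                           iters: int = 200, tol: float = 1e-10, seed: int = 0) -> float:
    """Power iteration: largest-magnitude eigenvalue of a symmetric H, given only hvp(v) = H @ v."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    eig = 0.0
    for _ in range(iters):
        hv = hvp(v)
        new_eig = float(v @ hv)             # Rayleigh quotient v^T H v (v has unit norm)
        v = hv / np.linalg.norm(hv)         # next direction
        converged = abs(new_eig - eig) <= tol * max(1.0, abs(new_eig))
        eig = new_eig
        if converged:
            break
    return eig


# Usage with an explicit symmetric matrix standing in for a network's Hessian.
H = np.diag([5.0, 2.0, -1.0])
estimate = top_hessian_eigenvalue(lambda v: H @ v, dim=3)
```

Power iteration converges to the eigenvalue of largest magnitude, which for a non-convex loss can be negative.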
Could you provide the pretrained AlterNet model for ImageNet1k-C? Thanks!
Thanks for your great work!
I noticed that you drop some values during training with sd=0.1.
Did you run any experiments analyzing the influence of the drop ratio?
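Assuming `sd` is the stochastic-depth (drop-path) rate, a common convention in ViT training code, the mechanism it controls can be sketched as below. This NumPy version is illustrative and not the repo's implementation; an ablation of the drop ratio would simply sweep `drop_prob`.

```python
import numpy as np


def drop_path(x: np.ndarray, drop_prob: float, training: bool,
              rng: np.random.Generator) -> np.ndarray:
    """Stochastic depth for a residual branch output x of shape (batch, ...).

    In training, each sample's branch is zeroed with probability drop_prob and the
    kept ones are scaled by 1 / (1 - drop_prob), so the expected output is unchanged.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)   # one Bernoulli draw per sample
    keep = rng.random(mask_shape) < keep_prob
    return x * keep / keep_prob


# Usage inside a residual block: out = x + drop_path(branch(x), 0.1, training, rng)
rng = np.random.default_rng(0)
branch_out = np.ones((1000, 4))
dropped = drop_path(branch_out, 0.1, True, rng)
```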
Is there a TF implementation of AlterNet?
It would be a great contribution to the field, since so many people use TF, me included.
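Hello,
Your paper got me interested in neural network visualization. Thank you for writing a great paper.
The other visualization tutorials are well documented on GitHub and I was able to go through them, but for the Hessian max eigenvalue code it seems you send the material by email on request, so I am opening this issue. Could you send the related material to the email address below?
My email address is [email protected]. Thank you.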
The original ViT and many ViT variants include feed-forward layers in their architectures. I noticed that feed-forward layers are neither mentioned in the paper nor implemented in the AlterNet code. It would be interesting to learn the intuition behind this design choice.
When I run the forward function of the LocalAttention class, some errors occur.
x.shape = [1,128,84,64] and self.window_size=8.
The rearrange function cannot run correctly because 84 is not divisible by 8 (84 = 10 * 8 + 4).
If I change window_size to 7, 6, or 5, the height or width of other images may not be divisible either.
I also tried setting window_size dynamically, but it didn't work.
The images come from the COCO dataset.
Do you have any good suggestions?
The code is:
b, c, h, w = x.shape
p = self.window_size
n1 = h // p
n2 = w // p
mask = torch.zeros(p ** 2, p ** 2, device=x.device) if mask is None else mask
mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]
x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
x, attn = self.attn(x, mask)
x = rearrange(x, "(b n1 n2) c p1 p2 -> b c (n1 p1) (n2 p2)", n1=n1, n2=n2, p1=p, p2=p)
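A common workaround, also used in Swin Transformer implementations, is to zero-pad the feature map up to a multiple of the window size, run windowed attention, and crop the padding off afterwards. The NumPy sketch below reproduces the two `rearrange` patterns above with reshape/transpose (the same steps map onto `torch.nn.functional.pad`, `Tensor.reshape`, and `Tensor.permute`). The helper names are hypothetical; the relative position embedding is defined per p x p window and is unaffected, but the padded zeros take part in attention unless they are masked, which is left open here.

```python
import numpy as np
from typing import Tuple


def pad_to_multiple(x: np.ndarray, p: int) -> Tuple[np.ndarray, int, int]:
    """Zero-pad the height and width of x (b, c, h, w) up to the next multiples of p."""
    h, w = x.shape[-2:]
    pad_h, pad_w = (-h) % p, (-w) % p
    x = np.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)))
    return x, pad_h, pad_w


def window_partition(x: np.ndarray, p: int) -> np.ndarray:
    """'b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2' with p1 = p2 = p."""
    b, c, hh, ww = x.shape
    x = x.reshape(b, c, hh // p, p, ww // p, p)          # b c n1 p1 n2 p2
    x = x.transpose(0, 2, 4, 1, 3, 5)                    # b n1 n2 c p1 p2
    return x.reshape(-1, c, p, p)


def window_merge(x: np.ndarray, b: int, n1: int, n2: int) -> np.ndarray:
    """'(b n1 n2) c p1 p2 -> b c (n1 p1) (n2 p2)', the inverse of window_partition."""
    _, c, p, _ = x.shape
    x = x.reshape(b, n1, n2, c, p, p)                    # b n1 n2 c p1 p2
    x = x.transpose(0, 3, 1, 4, 2, 5)                    # b c n1 p1 n2 p2
    return x.reshape(b, c, n1 * p, n2 * p)


# Usage with the shape from the question: 84 is padded to 88 so that 8 divides it.
x = np.random.default_rng(0).normal(size=(1, 128, 84, 64))
p = 8
xp, pad_h, pad_w = pad_to_multiple(x, p)
windows = window_partition(xp, p)                        # attention would run per window here
merged = window_merge(windows, xp.shape[0], xp.shape[2] // p, xp.shape[3] // p)
out = merged[:, :, :x.shape[2], :x.shape[3]]             # crop the padding off again
```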
In the paper, the authors state that "MSAs are low-pass filters, but Convs are high-pass filters". The authors also propose how to harmonize Convs with MSAs: by replacing Convs with MSAs, preferably at the end of a stage. They also describe a model that "uses Convs in early stages and MSAs in late stages".
Sorry in advance if the following questions are naive.
In the late stages, adding Convs after MSAs should decrease the performance of a model, right? Since the late stages produce low-frequency features, adding Convs there would suppress those features? I ran an experiment: I trained a hierarchical ViT, Segformer, and then replaced the last-stage 1x1 Conv in the decoder with a 3x3 Conv.
I trained the model on a polyp segmentation dataset and report the results below:
| Model | Dice Score |
|---|---|
| Segformer | 84.95 |
| Modified Segformer | 84.61 |
I haven't tested whether replacing the 1x1 Conv in stages 1-2 with a 3x3 Conv increases performance, but is the conclusion I drew above correct?
Hello.
Thanks for contributing to this project.
Could you share the Hessian max eigenvalue spectrum code?
Thank you
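One reading of a "Hessian max eigenvalue spectrum" is the distribution of the largest Hessian eigenvalue computed on many mini-batches. The toy sketch below does this for logistic regression, where the Hessian has the closed form X^T diag(p(1 - p)) X / n and `np.linalg.eigvalsh` can be used directly; for a neural network, the exact eigendecomposition would be replaced by power iteration on Hessian-vector products. The interpretation and the helper names are assumptions, not the authors' code.

```python
import numpy as np


def logistic_hessian(w: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Hessian of the mean logistic-regression NLL at w (independent of the labels)."""
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    return (X * (p * (1.0 - p))[:, None]).T @ X / X.shape[0]


def max_eigenvalue_spectrum(w, X, batch_size, rng):
    """Largest Hessian eigenvalue on each full mini-batch of one shuffled pass over X."""
    idx = rng.permutation(len(X))
    eigs = []
    for start in range(0, len(X) - batch_size + 1, batch_size):
        batch = X[idx[start:start + batch_size]]
        eigs.append(np.linalg.eigvalsh(logistic_hessian(w, batch))[-1])  # eigvalsh sorts ascending
    return np.array(eigs)


rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 10))
w = rng.normal(size=10)
eigs = max_eigenvalue_spectrum(w, X, batch_size=64, rng=rng)
counts, edges = np.histogram(eigs, bins=8)   # the "spectrum" to plot
```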
The authors state in the paper that "Contrary to popular belief, the long-range dependency hinders NN optimization." However, recent models that adopt long-range dependencies, such as VAN, ConvNeXt, and RepLKNet, achieve really strong results.
Doesn't this make the statement above seem a little wrong? I know there is an issue discussing large-kernel Convs, but that issue does not address this statement.
Moreover, in the experiments of Fig. 7 you use Convolutional SANs. This model has two variants: 1D-CSANs and 2D-CSANs. The one you experiment on is 2D-CSANs, right? It considers not only the interaction among tokens within a single head but also the interaction among different heads. In 1D-CSANs, the "long-range dependency" is still very beneficial, and that is what I consider the true "long-range dependency" in self-attention.
When using 2D-CSANs, the model considers both aspects (interaction among heads and among tokens), which leads to worse performance when the window size is scaled up. These results are consistent with the Convolutional SANs paper.
However, I don't think the worse performance of 2D-CSANs at larger window sizes shows that "long-range dependency hinders NN optimization", since the model involves both aspects. Sorry for the long post; if any part of my question is unclear, I can clarify it.
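A generic way to restrict self-attention to a k x k neighborhood, for readers who want to vary the window size in the spirit of the Fig. 7 experiments, is an additive attention mask. This is a sketch under that assumption, not the Convolutional SANs implementation: it restricts token positions only and does not model interactions between heads. The helper names are hypothetical.

```python
import numpy as np


def local_attention_mask(h: int, w: int, k: int) -> np.ndarray:
    """Additive (hw, hw) mask: 0 where token j lies in the k x k window centered on token i, -inf elsewhere."""
    rows, cols = np.divmod(np.arange(h * w), w)          # 2D coordinates of each token
    r = k // 2                                           # k is assumed odd
    near = (np.abs(rows[:, None] - rows[None, :]) <= r) & (np.abs(cols[:, None] - cols[None, :]) <= r)
    return np.where(near, 0.0, -np.inf)


def masked_softmax_attention(q, kmat, v, mask):
    """Single-head attention softmax(q k^T / sqrt(d) + mask) v; returns output and weights."""
    scores = q @ kmat.T / np.sqrt(q.shape[-1]) + mask
    scores -= scores.max(axis=-1, keepdims=True)         # numerical stability
    weights = np.exp(scores)                             # exp(-inf) = 0 outside the window
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
h = w = 4
mask = local_attention_mask(h, w, k=3)                   # a larger k approaches global attention
q, kmat, v = rng.normal(size=(3, h * w, 8))
out, weights = masked_softmax_attention(q, kmat, v, mask)
```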
Regarding how-do-vits-work/models/convit.py (lines 1 to 6 at commit 8752f4e):
You said it is not the same as ConViT by d'Ascoli, Stéphane, et al. Then where does this ConViT come from? I ask because if I reuse this code, I want to know whom I should cite.
In the paper, the authors state that "A key feature of MSAs is data specificity (not long-range dependency)".
Can you explain the "data specificity" part? What is it, and how does it behave?
Furthermore, can you elaborate on how MSAs achieve data specificity (through visualizations, formulas, etc.)?
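A toy NumPy illustration of one common reading of "data specificity": the token-mixing weights of an MSA are computed from the input itself, whereas a Conv mixes neighboring tokens with the same learned kernel for every input. This illustrates that reading only and is not the authors' formal definition; all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4                                  # number of tokens, channels
Wq, Wk = rng.normal(size=(2, d, d))          # learned projections, fixed after training


def attention_weights(x: np.ndarray) -> np.ndarray:
    """Token-mixing matrix of one self-attention head: softmax(x Wq (x Wk)^T / sqrt(d))."""
    scores = (x @ Wq) @ (x @ Wk).T / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)


def conv_mixing_matrix(kernel: np.ndarray, n: int) -> np.ndarray:
    """Token-mixing matrix of a 1D convolution with zero padding; it never depends on the input."""
    m = np.zeros((n, n))
    r = len(kernel) // 2
    for i in range(n):
        for t, wt in enumerate(kernel):
            j = i + t - r
            if 0 <= j < n:
                m[i, j] = wt
    return m


x1, x2 = rng.normal(size=(2, n, d))          # two different inputs
A1, A2 = attention_weights(x1), attention_weights(x2)   # differ because they depend on x
C = conv_mixing_matrix(np.array([0.25, 0.5, 0.25]), n)  # the same matrix for x1 and x2
y1, y2 = C @ x1, C @ x2
```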
It seems that the files below are not available:
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar"
path = "checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar"
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar"
path = "checkpoints/vit_ti_cifar100_9857b21357.pth.tar"
Hi @xxxnell,
I find it hard to understand the conclusions of the lesion study. For example, ViT does not satisfy your conclusion (i.e., that later MSAs are more important).
What are the total parameter count and the total FLOPs of AlterNet for the CIFAR-100 and ImageNet datasets?
Hi, thanks for your great work. I'd like to discuss the L2 loss problem in the loss landscape visualization. I found that your computed L2 loss is significantly larger (10x) than the classification loss, so the landscape visualization is basically a visualization of the L2 loss.
In fact, "weight decay" is implemented slightly differently from an "L2 loss" in PyTorch. Simply adding the sum of weight norms as an L2 loss is not the same as applying weight decay in Adam-like optimizers in PyTorch. See the blog post at https://bbabenko.github.io/weight-decay/.
One might find that the L2 loss is significantly larger than the classification loss, but in ViT training practice the weight decay term does not dominate the classification loss, because of how weight decay is implemented in PyTorch.
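The distinction raised above can be seen in a single optimizer step. The NumPy sketch below compares (a) adding an L2 penalty 0.5 * wd * ||w||^2 to the loss, so that its gradient wd * w is rescaled by Adam (this is what `torch.optim.Adam(weight_decay=...)` does), with (b) decoupled weight decay applied directly to the weights, as `torch.optim.AdamW` does. The task gradient is set to zero to isolate the regularizer, and the numbers are illustrative only.

```python
import numpy as np


def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update with bias correction; returns the new (w, m, v)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v


w = np.array([10.0, 0.01])         # one large and one small weight
grad_loss = np.zeros_like(w)       # zero task gradient to isolate the regularizer
wd, lr = 0.01, 1e-3
zeros = np.zeros_like(w)

# (a) L2 penalty added to the loss: its gradient wd * w is normalized by Adam.
w_l2, _, _ = adam_step(w, grad_loss + wd * w, zeros, zeros, t=1, lr=lr)

# (b) Decoupled weight decay: Adam on the task gradient, then shrink w directly.
w_adam, _, _ = adam_step(w, grad_loss, zeros, zeros, t=1, lr=lr)
w_decoupled = w_adam - lr * wd * w
```

With Adam, the L2 gradient is normalized, so in this first step both weights shrink by roughly `lr` regardless of their size, while decoupled decay shrinks each weight in proportion to its value.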
Hello author, I have learned a lot from your work. When reading the paper, I became very interested in how Figure 1(b) [Trajectories in polar coordinates] was drawn. Could you open-source the code for it?
Best regards
- Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
- If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA.
- Use more heads and higher hidden dimensions for MSA blocks in late stages.
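The first two rules can be read as a greedy search. The sketch below encodes that reading in plain Python; `build_up` and the `evaluate` callback (for example, validation accuracy after training the candidate layout) are hypothetical, the third rule about heads and hidden dimensions is not modeled, and this is not the authors' procedure.

```python
from typing import Callable, List

Layout = List[List[str]]   # one list of block types ("conv" or "msa") per stage


def build_up(stages: Layout, evaluate: Callable[[Layout], float], max_msa: int) -> Layout:
    """Greedily replace Conv blocks with MSA blocks, starting from the end of the model."""
    layout = [list(stage) for stage in stages]
    best = evaluate(layout)
    added = 0
    s = len(layout) - 1                     # start with the last stage
    while s >= 0 and added < max_msa:
        # Alternate from the end of the stage: the last block, then every second block before it.
        slots = [i for i in range(len(layout[s]) - 1, -1, -2) if layout[s][i] == "conv"]
        if not slots:
            s -= 1
            continue
        trial = [list(stage) for stage in layout]
        trial[s][slots[0]] = "msa"
        score = evaluate(trial)
        if score > best:                    # keep the MSA block if it helps
            layout, best, added = trial, score, added + 1
        else:                               # otherwise move to the end of an earlier stage
            s -= 1
    return layout


# Usage with a made-up score that rewards MSAs in later stages more.
def toy_score(layout: Layout) -> float:
    return sum((si + 1) * stage.count("msa") for si, stage in enumerate(layout)) / 10


result = build_up([["conv"] * 2, ["conv"] * 2, ["conv"] * 4], toy_score, max_msa=3)
```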
I suppose the above rules apply to high-level computer vision tasks such as classifications that involve only downsampling. I wonder how these rules differ for tasks involving upsampling stages such as image generation from latent or segmentation with U-Net. In particular, I am interested in (1) the ordering of Conv and MSA blocks and (2) the number of heads and hidden dimensions in upsampling stages.
Based on your findings that Convs are high-pass and MSAs are low-pass filters, I suppose the Conv-MSA ordering should hold for both downsampling and upsampling stages, rather than using MSA-Conv blocks in upsampling.
Since downsampling stages usually reduce the spatial resolution and increase the channel dimension, the third rule makes sense.
However, upsampling stages usually increase the spatial resolution and reduce the channel dimension. Does the third rule still hold for upsampling, or should it be flipped to fewer heads and lower hidden dimensions in late stages?
I would appreciate your valuable insights on applying these build-up rules to upsampling stages.
Hi,
I am Ross.
Excellent work!
The experiments are essentially classification problems.
Do the analysis results change much if the task is switched to object detection?
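Hello, author.
First of all, thank you for a great paper that offers many insights.
After reading your paper, I have been running various analyses using your code.
Among them, I would like to run the analysis of robustness to noise frequency in your Fig. 2b.
However, this part does not seem to be in the code, so I am asking about it.
It seems to use the FreqAttack class;
could you share the experiment code that applies random noise at each frequency, so that I can reproduce this experiment?
Thank you.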
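A sketch of one way to generate random noise concentrated at a chosen frequency, which could then be added to inputs to measure accuracy per frequency band. It is not the repo's `FreqAttack` class; the isotropic ring-shaped band, the bandwidth, and the unit-variance normalization are assumptions, since the exact settings of Fig. 2b are not given here.

```python
import numpy as np


def frequency_band_mask(h: int, w: int, freq: float, bandwidth: float) -> np.ndarray:
    """Boolean FFT-domain mask selecting frequencies with |f| in [freq - bw/2, freq + bw/2].

    Frequencies are in units of pi: 0.0 is DC and 1.0 is the Nyquist frequency along an axis.
    """
    fy = np.fft.fftfreq(h)[:, None] * 2      # cycles/sample -> units of pi, in [-1, 1)
    fx = np.fft.fftfreq(w)[None, :] * 2
    radius = np.sqrt(fx ** 2 + fy ** 2)
    return np.abs(radius - freq) <= bandwidth / 2


def band_limited_noise(h, w, freq, bandwidth, rng):
    """Gaussian noise filtered to one frequency band and rescaled to unit variance."""
    noise = rng.normal(size=(h, w))
    band = frequency_band_mask(h, w, freq, bandwidth)
    filtered = np.fft.ifft2(np.fft.fft2(noise) * band).real   # mask is symmetric, so output is real
    return filtered / (filtered.std() + 1e-12)


rng = np.random.default_rng(0)
band = frequency_band_mask(32, 32, 0.5, 0.1)
noise = band_limited_noise(32, 32, 0.5, 0.1, rng)
image = rng.random((32, 32))                # hypothetical input image
noisy = image + 0.1 * noise                 # evaluate the model on this, per frequency
```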
Could you make the already trained models from this work accessible? Thank you very much in advance.
Hi,
thank you for this wonderful work on vision transformers and how to understand them. I have some simple questions, for which I apologize in advance.
I tried to reproduce Figure 12 independently of your code base, and I struggle a bit to understand the code. Is it correct that you define robustness as robustness = mean(accuracy(y_val_true, y_val_pred))?
Related to this, do I understand correctly that you compute this accuracy on batches of the validation dataset, and that these batches are of size 256?
Thanks.
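If robustness is computed as a mean of per-batch accuracies, it equals the overall accuracy only when all batches have the same size; a smaller final batch (for example, when the validation set is not a multiple of 256) is over-weighted. The toy NumPy sketch below shows the difference; it does not claim which of the two the repository computes, and the helper name is hypothetical.

```python
import numpy as np


def batch_accuracies(y_true: np.ndarray, y_pred: np.ndarray, batch_size: int) -> np.ndarray:
    """Accuracy of each consecutive batch (the last batch may be smaller)."""
    return np.array([
        np.mean(y_true[i:i + batch_size] == y_pred[i:i + batch_size])
        for i in range(0, len(y_true), batch_size)
    ])


y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 0])
accs = batch_accuracies(y_true, y_pred, batch_size=4)   # batches of 4 and 1 samples
mean_of_batches = accs.mean()                           # unweighted mean of batch accuracies
overall = np.mean(y_true == y_pred)                     # accuracy over all samples
```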
Great analysis! I wonder about the properties of large-kernel CNNs. In your paper, the basic 3x3 ResNet is fully explored. 3x3 convs extract detailed local patterns and thus may contribute to high-pass filtering. However, recent works investigate the effect of larger kernels. Might the properties of a ResNet change with larger kernels and become similar to those of ViT?
Hi, I am an integrated M.S./Ph.D. student at Yonsei University. I am very interested in your research and have been looking into your code. However, I couldn't find the code for the Hessian eigenvalues and realized that you don't share it at the moment.
I would be very grateful if you could give me the code or guidelines for writing the Hessian eigenvalue visualization.
my email: [email protected]
Thank you.
Your paper reports that generally MSAs behave like low-pass filters (shape-biased) and Convs behave like high-pass filters (texture-biased). Recently I came across papers that report shape bias in their findings and I wonder about your thoughts on them.
Low-pass filters (shape-biased)
High-pass filters (texture-biased)
These findings suggest that the factors affecting this behavior could be spatial aggregation, kernel size, training data, or training procedures. It seems that only 3x3 Convs behave like high-pass filters, or I may be missing something. In another thread you mentioned that group size also makes a difference. I wonder how ResNet and ResNeXt differ; I suppose ResNeXt is also texture-biased.
I would appreciate your insights on what factors determine whether a model or a layer behaves like a low- or high-pass filter.
I read your paper and learned a lot.
I would also like to see the code for plotting the Hessian max eigenvalue spectra.
Do you have any plans to release it?
Best,