nvlabs / gcvit

[ICML 2023] Official PyTorch implementation of Global Context Vision Transformers
Home Page: https://arxiv.org/abs/2206.09959
License: Other
Hi, thanks for sharing the great work. I encountered some problems during installation. The README.md says "This repository is compatible with NVIDIA PyTorch docker nvcr>=21.06". My confusion is whether the NVIDIA PyTorch docker is actually necessary, or whether it is possible to simply install a suitable PyTorch version instead.
Can you provide the complete requirements, including the PyTorch version and the other dependencies the project needs?
Looking forward to your reply, thank you!
Could you please publish the semantic segmentation code?
It is not clear which parameters depend on the pretrained image size and which on the actual one.
E.g., in the Swin transformer for classification there is no shift in the last (4th) stage (BasicLayer) because of https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192
But there is no such condition in the semantic segmentation backbone: https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py#L182
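For reference, a minimal paraphrase of that Swin classification condition (my reading of the linked line, not code from either repo):

def adapt_window(input_resolution, window_size, shift_size):
    # When the window already covers the whole feature map, Swin disables
    # shifting and shrinks the window to the feature size.
    if min(input_resolution) <= window_size:
        return min(input_resolution), 0
    return window_size, shift_size

print(adapt_window((7, 7), 7, 3))  # (7, 0): the last stage at 224 input never shifts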
So, it is unclear how to correctly apply GCViT for segmentation at an image size different from the pretrained one.
I am not using distributed training.
My python command is:
python3 train.py --model gcvit_xxtiny --pretrained --data_dir /gcvit/datasets --dataset May-25-2023 --num-classes 3 --batch-size 32 -c configs/gc_vit_xxtiny.yml --crop-pct 1.0
My dataset is an ImageNet-style dataset with "train" and "validation" folders, each containing 3 folders, one per class (hence num-classes 3).
There are also a couple of other errors that I had to work around (it worked in the end); I'll put in a pull request in the coming days to fix the train.py file.
Lastly, and curiously, changing num-classes to 4 started the training without any errors.
PS. I'm using --crop-pct 1.0 because I don't want to lose any of my image information during training or validation. I'm curious as to why a 0.875 crop is used for validation - why would you want to lose information?
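For context, here is my understanding of what the flag does, as a minimal sketch of the timm-style eval transform (my own example, not the repo's code):

from torchvision import transforms

# With crop_pct=0.875 and a 224 target, the image is first resized so the
# short side is int(224 / 0.875) = 256, then center-cropped to 224; the crop
# mirrors train-time random-crop statistics rather than simply discarding data.
img_size, crop_pct = 224, 0.875
scale_size = int(img_size / crop_pct)  # 256
eval_tf = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])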
Thank you,
Rohin Nanavati
Hello, I validated GCViT using your public checkpoints, and the accuracy for tiny, xtiny, and xxtiny is very low (top-1 error 99.9, top-5 error 99.75).
I wonder whether you published the wrong checkpoints or whether my validation script has something wrong:
python validate.py \
--model gc_vit_xxtiny \
--checkpoint ckpts/gcvit_xxtiny_best_1k.pth.tar \
--data_dir data/imagenet/ \
--batch-size 32
Will you provide examples for downstream tasks like object detection?
If I click the model link, it fails to connect due to a wrong address. Can you please check it?
Hello,
I am the developer of flaim, a library of pre-trained vision models for the JAX/Flax ecosystem, which includes GC ViT as one of its available architectures. I would appreciate it if you would consider mentioning flaim in the third-party implementations section of the README.
Best,
Borna
Hi,
Thank you.
Here, https://github.com/NVlabs/GCVit/blob/main/models/gc_vit.py#L382, "x" holds features in channel-last format [batch, height, width, channels].
You pass it through some Conv2D layers (FeatExtract), which expect features in channel-first format [batch, channels, height, width]. But instead of a transpose operation you apply a reshape, which mixes the channel and spatial axes.
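A minimal standalone sketch of the difference (not the repo's code):

import torch

# Channel-last toy tensor: 1 image, 2x2 spatial, 3 channels.
B, H, W, C = 1, 2, 2, 3
x = torch.arange(B * H * W * C).reshape(B, H, W, C)

correct = x.permute(0, 3, 1, 2)  # true NHWC -> NCHW transpose
wrong = x.reshape(B, C, H, W)    # reinterprets memory, scrambling channels into spatial positions

print(torch.equal(correct, wrong))  # False: the values land in different places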
I believe most architectural ideas in GCViT will boost performance, but I'd like to know the reliable differences in metrics.
At the moment that is not possible, because all models were trained with unusual window sizes.
More precisely, Swin V1 with 224x224 input was trained with window_size=7 for all stages.
GCViT, at the same time, uses [7, 7, 14, 7] window sizes, so at the third stage it has a twice larger receptive field.
Judging by this code branch, https://github.com/NVlabs/GCVit/blob/main/models/gc_vit.py#L503, I believe you made some experiments.
Could you please publish their results, so we can understand which part of the performance is gained by architectural changes and which by receptive-field expansion?
Dear all,
I have completed a new re-implementation of your work in TensorFlow2/Keras, including the timm modules used: GCViT-TensorFlow.
It would be a pleasure to be included in the "Third-party Implementations and Resources" section of your repo.
Anyway, congratulations on your work.
Thank you.
Best regards
I've been attempting to train the 'tiny' model from scratch for verification but am running into some problems. I hit an accuracy wall of sorts approaching 80%, and the model should pass well beyond that based on your results.
I am (obviously) familiar with the train scripts and hparams in general, and have trained quite a few related models to their expected accuracy.
A few questions
What is the input shape, B H W C or B C H W?
The dataloader seems to use only a single process for image processing, with the num_workers processes waiting on one process. The GPU is fast, but the CPU is slow. Why is that, and how can I make training faster?
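For reference, the knobs I would expect to matter here, as a generic PyTorch sketch (train_dataset is a hypothetical stand-in for the script's own Dataset):

from torch.utils.data import DataLoader

# More worker processes parallelize CPU-side decoding and augmentation so
# the GPU is not starved waiting for batches.
loader = DataLoader(
    train_dataset,            # hypothetical Dataset, stands in for the script's own
    batch_size=32,
    shuffle=True,
    num_workers=8,            # 0 means everything runs in the main process
    pin_memory=True,          # faster host-to-GPU copies
    persistent_workers=True,  # keep workers alive between epochs
)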
Dear all,
I am training a model for emotion recognition and ran into this error at line 521:
RuntimeError: shape '[32, 1, 49, 3, 32]' is invalid for input of size 49152.
The batch size is 32 and the model is GCViT-S.
Can someone help solve this issue?
Best regards
Hi!
I'm probably doing something fundamentally wrong when I try to train the model with a slightly different input size.
For example
python train.py --config configs/gc_vit_large.yml --data_dir '/mnt/…/my/custom/dataset/' --num-classes 17 --no-aug --crop-pct 1 --input-size 3 384 384 --experiment dif_InputSize
gives me the following error:
Traceback (most recent call last):
File "train.py", line 864, in
main()
File "train.py", line 648, in main
train_metrics = train_one_epoch(
File "train.py", line 722, in train_one_epoch
output = model(input)
File "/home/tan/env-GCVit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/tan/GCVit/models/gc_vit.py", line 705, in forward
x = self.forward_features(x)
File "/home/tan/GCVit/models/gc_vit.py", line 696, in forward_features
x = level(x)
File "/home/tan/env-GCVit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/tan/GCVit/models/gc_vit.py", line 600, in forward
q_global = self.q_global_gen(_to_channel_first(x))
File "/home/tan/env-GCVit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/tan/GCVit/models/gc_vit.py", line 536, in forward
x = x.reshape(B, 1, self.N, self.num_heads, self.dim_head).permute(0, 1, 3, 2, 4)
RuntimeError: shape '[32, 1, 49, 6, 32]' is invalid for input of size 884736
If you could explain what I'm doing wrong here and how I can fix it, that would be great.
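For what it's worth, my own arithmetic on the error (an assumption, not an official explanation): the config fixes N = 7*7 = 49 global tokens per head, but at 384 input the global query generator appears to emit a 12x12 = 144-token map, so the element counts cannot match.

B, N, num_heads, dim_head = 32, 49, 6, 32
expected = B * 1 * N * num_heads * dim_head  # 301056 elements for the target shape
actual = 884736                              # element count from the error message
print(actual // (B * num_heads * dim_head))  # 144 = 12 * 12 tokens actually present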
Is there a way to use this with arbitrary resolution?
model = timm.create_model('gcvit_tiny', pretrained=True)
sample = model(torch.randn(2, 3, 640, 640))
Error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_9221/897640580.py in <module>
----> 1 sample = model(torch.randn(2, 3, 640, 640))
~/anaconda3/envs/yl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/envs/yl/lib/python3.7/site-packages/timm/models/gcvit.py in forward(self, x)
524
525 def forward(self, x: torch.Tensor) -> torch.Tensor:
--> 526 x = self.forward_features(x)
527 x = self.forward_head(x)
528 return x
~/anaconda3/envs/yl/lib/python3.7/site-packages/timm/models/gcvit.py in forward_features(self, x)
517 def forward_features(self, x: torch.Tensor) -> torch.Tensor:
518 x = self.stem(x)
--> 519 x = self.stages(x)
520 return x
521
~/anaconda3/envs/yl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/envs/yl/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
143
~/anaconda3/envs/yl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/envs/yl/lib/python3.7/site-packages/timm/models/gcvit.py in forward(self, x)
388 x = checkpoint.checkpoint(blk, x)
389 else:
--> 390 x = blk(x, global_query)
391 x = self.norm(x)
392 x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW
~/anaconda3/envs/yl/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/envs/yl/lib/python3.7/site-packages/timm/models/gcvit.py in forward(self, x, q_global)
310
311 def forward(self, x, q_global: Optional[torch.Tensor] = None):
--> 312 x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global)))
313 x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
314 return x
~/anaconda3/envs/yl/lib/python3.7/site-packages/timm/models/gcvit.py in _window_attn(self, x, q_global)
303 def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
304 B, H, W, C = x.shape
--> 305 x_win = window_partition(x, self.window_size)
306 x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C)
307 attn_win = self.attn(x_win, q_global)
~/anaconda3/envs/yl/lib/python3.7/site-packages/timm/models/gcvit.py in window_partition(x, window_size)
235 def window_partition(x, window_size: Tuple[int, int]):
236 B, H, W, C = x.shape
--> 237 x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
238 windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
239 return windows
RuntimeError: shape '[2, 22, 7, 22, 7, 64]' is invalid for input of size 3276800
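My understanding of the constraint, sketched under my own assumptions (not an official API): each stage's feature map must be divisible by that stage's window size, which the pretrained 224 config fixes at (7, 7, 14, 7).

def is_valid(side, window_sizes=(7, 7, 14, 7), stem_stride=4):
    # Assumes a stride-4 stem and stride-2 downsampling between stages.
    feat = side // stem_stride
    for ws in window_sizes:
        if feat % ws != 0:
            return False
        feat //= 2
    return True

print([s for s in range(224, 1000, 32) if is_valid(s)])  # [224, 448, 672, 896]; 640 fails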
Hi,
When I try to train the model by myself, I have several questions about the ImageNet dataset. For ImageNet-1k, the validation data from the official website is not in the format:
imagenet
├── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── ...
├── class2
│ ├── img6.jpeg
│ └── ...
└── ...
but the validation tar is formatted as:
imagenet
├── val
├── img1.jpeg
├── img2.jpeg
├── img3.jpeg
└── ...
How do I prepare the validation set in the desired format?
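For context, the reorganization I have in mind, as a minimal sketch (assuming a hypothetical val_labels.txt that maps each image filename to its WordNet ID, e.g. derived from the devkit):

import os, shutil

val_dir = "imagenet/val"
with open("val_labels.txt") as f:  # hypothetical "filename wnid" pairs, one per line
    for line in f:
        fname, wnid = line.split()
        os.makedirs(os.path.join(val_dir, wnid), exist_ok=True)
        shutil.move(os.path.join(val_dir, fname),
                    os.path.join(val_dir, wnid, fname))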
In addition, if the train set contains class0~class900 and the val set contains class900~class1000, should I create empty folders in the train directory and test directory (so as to make each of them have 1000 folders)?
Thanks a lot!
Model code has an issue in adapting global query for attention estimation: https://github.com/NVlabs/GCVit/blob/main/models/gc_vit.py#L223
Repeating along the batch dimension results not in the [1, 1, 1, 2, 2, 2, 3, 3, 3] batched element order, but in [1, 2, 3, 1, 2, 3, 1, 2, 3].
At the same time, after window_partition the key and value "windows" are ordered as [1, 1, 1, 2, 2, 2, 3, 3, 3].
You can verify this by evaluating ImageNet with batch sizes 1 and 16; the results will be different.
In my tests, results for batch size 1 are always equal, and results for batch size 16 are always different if we add stochasticity (shuffling batch items and labels simultaneously).
During training this bug leads to a partly wrong attention matrix estimation (because around a "B_//B - 1" portion of the channels will try to attend to the wrong batch elements).
I suppose that fixing this bug and finetuning the existing checkpoints should give even better metrics.
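A minimal standalone illustration of the ordering difference (not the repo's code):

import torch

q = torch.tensor([1, 2, 3])    # stand-in for 3 batch elements
print(q.repeat(3))             # tensor([1, 2, 3, 1, 2, 3, 1, 2, 3]) -- interleaved
print(q.repeat_interleave(3))  # tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]) -- matches window order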
Hi!
Thank you for your great work.
The output shape for each stage of GC ViT looks incorrect to me. I am not sure where the discrepancy comes from, but I would appreciate your help in debugging it.
I wrote a print statement to print out the shape of x in the forward_features method of the GCViT class. The code is as follows:
def forward_features(self, x):
    x = self.patch_embed(x)
    x = self.pos_drop(x)
    for level in self.levels:
        x = level(x)
        print(x.shape)
    ...
Then I run the following code:
model = gc_vit_xxtiny()
inp = torch.rand(1, 3, 224, 224)
res = model(inp)
The results are as follows:
torch.Size([1, 28, 28, 128])
torch.Size([1, 14, 14, 256])
torch.Size([1, 7, 7, 512])
torch.Size([1, 7, 7, 512])
I am not sure if I made a mistake, but the results I am getting are different from what you mentioned in the paper.
Thanks
I was trying to use GCViT as the backbone for a YOLO one-stage detector. However, one thing I noticed is that YOLO architectures use only three feature maps from the backbone before passing them to the neck. ViTs like GCViT and Swin-T produce 4 feature maps. How do I properly use these feature maps without losing vital information? (See the sketch below.)
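One possible approach, sketched under my own assumptions (the shapes are stand-ins for GCViT stage outputs at 224 input, not values from the repo): drop the highest-resolution map and feed the stride-8/16/32 maps to the neck as P3/P4/P5.

import torch

stage_outputs = [
    torch.randn(1, 64, 56, 56),   # stride 4  (usually skipped by YOLO necks)
    torch.randn(1, 128, 28, 28),  # stride 8  -> P3
    torch.randn(1, 256, 14, 14),  # stride 16 -> P4
    torch.randn(1, 512, 7, 7),    # stride 32 -> P5
]
p3, p4, p5 = stage_outputs[1:]    # the three maps a YOLO neck expects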
Hi thanks for sharing the great work.
I am kindly wondering whether you are aware of our recent work, which also proposes serial global-local attention. I would be grateful if you could add a discussion of it if you happen to update the paper: MaxViT: Multi-Axis Vision Transformer.
Hello @ahatamiz
Thank you very much for your work.
I find that your visualization of global attention is very attractive in understanding your work. I would like to ask about the details of your visualization of global attention.
Hi,
Could you please share the pretrained GC_ViT-B_384 model on IN-1K?
Model code has an issue in adapting global query for attention estimation: https://github.com/NVlabs/GCVit/blob/main/models/gc_vit.py#L223
Heads in the keys, values, and local query are produced by splitting the channels and transposing the head dimension into the "batch" dimension (in terms of matmul).
Heads in the global query are produced by splitting the spatial size, https://github.com/NVlabs/GCVit/blob/main/models/gc_vit.py#L224, because of a missing transpose op that should move the head dimension.
As far as I understand, this results in incomplete query-to-key spatial interactions and adds some query-key channel shifting.
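For comparison, a standalone sketch of the standard head split being described (generic attention bookkeeping, not the repo's code):

import torch

B_, N, C, num_heads = 4, 49, 64, 2
q = torch.randn(B_, N, C)
# Split channels into heads, then move the head axis next to batch so each
# head participates in its own matmul.
q = q.reshape(B_, N, num_heads, C // num_heads).permute(0, 2, 1, 3)
print(q.shape)  # torch.Size([4, 2, 49, 32])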