
pytorch-vit's Introduction

Vision Transformers

Implementation of the Vision Transformer (ViT) in PyTorch, a model that achieves SOTA in image classification using transformer-style encoders. Associated blog article.

Credits to Phil Wang for the ViT gif.

Features

  • ViT
  • ViT with convolutional patches
  • ViT with convolutional stems
    • Early Convolutional Stem
    • Scaled ReLU Stem
  • GAP Pooling
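
As a rough orientation, the pipeline the features above refer to can be sketched in plain PyTorch. This is an illustrative sketch, not this repository's actual API; all sizes and module choices here are assumptions:

```python
import torch
import torch.nn as nn

# Illustrative sizes, not this repo's defaults
img, patch, dim, channels = 32, 8, 128, 3
num_patches = (img // patch) ** 2  # 16

x = torch.randn(2, channels, img, img)

# 1. Split into patches and linearly project them; a strided conv is one
#    common way to implement the patch projection ("convolutional patches")
to_patches = nn.Conv2d(channels, dim, kernel_size=patch, stride=patch)
tokens = to_patches(x).flatten(2).transpose(1, 2)  # (batch, num_patches, dim)
assert tokens.shape == (2, num_patches, dim)

# 2. Run a standard transformer encoder over the token sequence
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True),
    num_layers=2,
)
tokens = encoder(tokens)

# 3. "GAP pooling": average over patch tokens instead of a class token
logits = nn.Linear(dim, 10)(tokens.mean(dim=1))
```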

Citations

@article{dosovitskiy2020image,
  title={An image is worth 16x16 words: Transformers for image recognition at scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
@article{xiao2021early,
  title={Early convolutions help transformers see better},
  author={Xiao, Tete and Singh, Mannat and Mintun, Eric and Darrell, Trevor and Doll{\'a}r, Piotr and Girshick, Ross},
  journal={arXiv preprint arXiv:2106.14881},
  year={2021}
}
@article{wang2021scaled,
  title={Scaled ReLU Matters for Training Vision Transformers},
  author={Wang, Pichao and Wang, Xue and Luo, Hao and Zhou, Jingkai and Zhou, Zhipeng and Wang, Fan and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2109.03810},
  year={2021}
}
@article{zhai2021scaling,
  title={Scaling vision transformers},
  author={Zhai, Xiaohua and Kolesnikov, Alexander and Houlsby, Neil and Beyer, Lucas},
  journal={arXiv preprint arXiv:2106.04560},
  year={2021}
}

pytorch-vit's People

Contributors

gupta-abhay


pytorch-vit's Issues

Using ViT

Hello Gupta!

Being new to vision tasks, could you share a small snippet showing how to use pytorch-vit in downstream vision tasks such as image retrieval? Thanks

Why only use the first patch? Thanks

I don't understand line 74 of ViT.py:
x = self.to_cls_token(x[:, 0])
If the first dimension of x is the batch, then index 0 along the second dimension should be a patch, since x has shape [batch, patch, feature]. Does this mean only the first patch is used? Could anybody help me with this? Thanks a lot.
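
For context, in the standard ViT design index 0 is not an image patch: a learnable class token is prepended to the patch sequence before the encoder, so `x[:, 0]` selects that token, which has attended to every patch. A minimal sketch (illustrative names and sizes, not taken from ViT.py):

```python
import torch
import torch.nn as nn

batch, num_patches, dim = 2, 64, 128

# A learnable class token (name and init are illustrative)
cls_token = nn.Parameter(torch.zeros(1, 1, dim))
patch_embeddings = torch.randn(batch, num_patches, dim)

# The class token is prepended, so the sequence length is num_patches + 1
x = torch.cat([cls_token.expand(batch, -1, -1), patch_embeddings], dim=1)
assert x.shape == (batch, num_patches + 1, dim)

# After the encoder, x[:, 0] picks out the class token, not the first patch;
# it acts as the global image summary fed to the classification head.
cls_out = x[:, 0]
```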

layout for vit: NCHW or NHWC? [bug]

hello,
Thanks a lot for this very interesting work!

when you unroll the tensor you use unfold and flatten like this:

x = (x.unfold(2, self.patch_dim, self.patch_dim)
      .unfold(3, self.patch_dim, self.patch_dim)
      .contiguous())
x = x.view(x.size(0), -1, self.flatten_dim)

but if x has shape N,C,H,W, unfolding yields N,C,H//P,W//P,P,P, so the flatten mixes data across channels: each of your "words" ends up drawing from different spatial blocks. It does not really matter when training the model at one specific size, but I think it will have a hard time transferring to a different size...

instead you could do like this:

self.flatten_dim_in = (patch_dim**2) * in_channels
...
x = (x.unfold(2, self.patch_dim, self.patch_dim)
      .unfold(3, self.patch_dim, self.patch_dim)
      .contiguous())
x = x.view(b, c, -1, self.patch_dim ** 2)
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(x.size(0), -1, self.flatten_dim_in)

Just to make sure the data at the end is really what you expect: all the RGB pixels of one patch together, not a mix of several patches.

Now I haven't tried your code yet, so perhaps you use a different layout than N,C,H,W for images?
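
The layout concern can be checked with a tiny self-contained demo (illustrative sizes, plain PyTorch, not this repo's code). With the direct flatten, the first row holds several patches from channel 0 only; with the permute fix, it holds all three channels of the top-left patch:

```python
import torch

b, c, h, w = 1, 3, 4, 4
p = 2  # patch size
x = torch.arange(b * c * h * w, dtype=torch.float32).view(b, c, h, w)

# Original path: flatten straight after unfold -> (b, c, h//p, w//p, p, p)
u = x.unfold(2, p, p).unfold(3, p, p).contiguous()
mixed = u.view(b, -1, c * p * p)

# Suggested fix: move channels last so each row is exactly one spatial patch
fixed = (u.view(b, c, -1, p * p)
          .permute(0, 2, 3, 1)
          .contiguous()
          .view(b, -1, c * p * p))

# mixed[0, 0] only contains channel-0 values (all < 16 here), spanning
# three different patches; fixed[0, 0] contains the top-left patch's
# pixels across all three channels.
print(mixed[0, 0])
print(fixed[0, 0])
```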

adjust_learning_rate can't import

train.py includes "from vit.utils import (adjust_learning_rate)", but there is no adjust_learning_rate in utils.py.

Patch the image

How do you split the image into patches? Any clues for the preprocessing and training steps?

Working of FixedPositionalEncoding and LearnedPositionalEncoding?

I get the part where the image is split into, say, 16x16 patches, and each 3D patch is flattened and passed into a linear layer to get what they call the linear projection. Can you please explain how the two types of embeddings work? I looked at your code too and it seemed like a maze to me. If you could explain it in layman's terms, I'll look at the code again and understand.
Thanks
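
For context, a hedged sketch of the two variants. The sizes are illustrative, and the sinusoidal table below follows the "Attention Is All You Need" formula, which fixed positional encodings in ViT implementations typically mirror; this is not necessarily this repo's exact code:

```python
import math
import torch
import torch.nn as nn

seq_len, dim = 65, 128  # e.g. 64 patches + 1 class token (illustrative)

# Learned: a trainable table with one dim-sized vector per position,
# updated by backprop like any other weight.
learned_pos = nn.Parameter(torch.zeros(1, seq_len, dim))

# Fixed: sinusoids at geometrically spaced frequencies; position i,
# channel 2k gets sin(i / 10000^(2k/dim)), channel 2k+1 gets the cosine.
position = torch.arange(seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
fixed_pos = torch.zeros(1, seq_len, dim)
fixed_pos[0, :, 0::2] = torch.sin(position * div_term)
fixed_pos[0, :, 1::2] = torch.cos(position * div_term)

# Either table is simply added to the patch embeddings before the encoder,
# so tokens at different positions become distinguishable.
tokens = torch.randn(2, seq_len, dim)
out = tokens + fixed_pos  # or: tokens + learned_pos
```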

The size of embedding dim

This work is very interesting and fascinating. I have a question: how did you decide the size of the embedding dimension?
Looking forward to the release of the pretrained model.
