
kyegomez / visionmamba

310 stars, 6 watchers, 15 forks, 2.23 MB

Implementation of Vision Mamba from the paper "Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model". According to the paper, Vim is 2.8x faster than DeiT and saves 86.8% GPU memory when performing batch inference to extract features on high-resolution (1248×1248) images.

Home Page: https://discord.gg/GYbXvDGevY

License: MIT License

Languages: Shell 17.33%, Python 82.67%
Topics: ai, machine-learning, mamba, pytorch, recurrent-neural-network, ssm

visionmamba's Introduction


Vision Mamba


Paper: https://arxiv.org/abs/2401.09417

Installation

pip install vision-mamba

Usage

import torch
from vision_mamba import Vim

# Example input tensor with shape (batch_size, channels, height, width)
x = torch.randn(1, 3, 224, 224)

# Model
model = Vim(
    dim=256,  # Embedding dimension of each patch token
    heads=8,  # Number of heads (accepted by the constructor but reportedly unused; see the issues below)
    dt_rank=32,  # Rank of the Δ (timestep) projection in the SSM
    dim_inner=256,  # Inner dimension used inside each Vim block
    d_state=256,  # Dimension of the SSM hidden state
    num_classes=1000,  # Number of output classes
    image_size=224,  # Size of the input image
    patch_size=16,  # Size of each image patch
    channels=3,  # Number of input channels
    dropout=0.1,  # Dropout rate
    depth=12,  # Number of stacked Vim blocks
)

# Forward pass
out = model(x)  # Output tensor from the model
print(out.shape)  # Print the shape of the output tensor
print(out)  # Print the output tensor

Citation

@misc{zhu2024vision,
    title={Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model}, 
    author={Lianghui Zhu and Bencheng Liao and Qian Zhang and Xinlong Wang and Wenyu Liu and Xinggang Wang},
    year={2024},
    eprint={2401.09417},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

License

MIT

Todo

  • Create a training script for ImageNet
  • Create a Vision Mamba model for facial recognition

visionmamba's People

Contributors

dependabot[bot], evelynmitchell, jungwon-choi, kyegomez


visionmamba's Issues

[BUG] Matrix size is not right

Describe the bug
I just ran example.py and got the error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)

To Reproduce
Steps to reproduce the behavior:

  1. Run example.py

Expected behavior
An output tensor of shape 1×512×512.


Upvote & Fund

  • We're using Polar.sh so you can upvote and help fund this issue.
  • We receive the funding once the issue is completed & confirmed by you.
  • Thank you in advance for helping prioritize & fund our backlog.
Fund with Polar

[BUG] `heads` parameter is defined but not used in `VisionEncoderMambaBlock`

heads parameter is defined but not used in VisionEncoderMambaBlock.

This line uses it to initialize VisionEncoderMambaBlock:

heads=heads,

But it is never used inside VisionEncoderMambaBlock, as can be seen here:

https://github.com/kyegomez/VisionMamba/blob/68e6447d9e3a7e7bb369c227d667bbb88275ce2e/vision_mamba/model.py#L32C1-L135C17


readme

Bro, can you provide the appropriate environment configuration?
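Whatever configuration the maintainer recommends, it helps to attach one's own environment to a request like this. A minimal sketch using the standard-library importlib.metadata; the package list (torch, vision-mamba, and zetascale as the assumed PyPI name of the zeta dependency) is an assumption, not an official requirements list.

```python
import platform
from importlib import metadata

# Assumed package names; adjust to whatever the project actually depends on.
DEFAULT_PACKAGES = ("torch", "vision-mamba", "zetascale")


def report_environment(packages=DEFAULT_PACKAGES):
    """Collect the Python version and installed package versions for a bug report."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in report_environment().items():
        print(f"{key}: {value}")
```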


forward-backward ssm

It appears that, per Algorithm 1 of the Vision Mamba paper, the state space model runs bidirectionally along the token sequence. In this implementation, however, the forward and backward branches both apply a standard 1D convolution to the sequence in the same order, instead of processing it forward and backward. Can you explain your rationale behind this?

Otherwise, there seems to be no difference between the backward and forward operations, even though together they are supposed to give the Mamba block a bidirectional way of applying the selective state space operation.
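A toy illustration of the distinction (not the repository's code): a simple linear recurrence stands in for the SSM scan. Scanned forward, each output depends on earlier tokens; scanned over the reversed sequence and flipped back, each output depends on later tokens. The function names and the decay value a are hypothetical.

```python
def linear_scan(xs, a):
    """Toy recurrence h_t = a * h_{t-1} + x_t, a stand-in for an SSM scan."""
    h, out = 0.0, []
    for x in xs:
        h = a * h + x
        out.append(h)
    return out


def bidirectional_scan(xs, a):
    """Forward scan plus a backward scan over the reversed sequence."""
    forward = linear_scan(xs, a)
    # Reverse, scan, then reverse again so index t lines up with the input.
    backward = linear_scan(xs[::-1], a)[::-1]
    return forward, backward


fwd, bwd = bidirectional_scan([1.0, 2.0, 3.0], 0.5)
print(fwd)  # [1.0, 2.5, 4.25]
print(bwd)  # [2.75, 3.5, 3.0]
```

If both branches scanned in the same order, they would differ only in their learned weights, which is the concern raised above.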


[BUG] Swarms is required to install but not used

Describe the bug
The requirements include another package, swarms, that is not used anywhere in this package.

To Reproduce
Install the package with pip.

Expected behavior
Only the necessary packages are installed.

Additional context
I see that you are also the author of the swarms package. If possible, can you clarify why this requirement is needed? I checked that package too, and it appears unrelated.
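One way to check a claim like this is to list what the package's source actually imports. A minimal sketch using the standard-library ast module; the directory name vision_mamba is taken from the tracebacks on this page, and static parsing will miss dynamic imports such as importlib.import_module calls.

```python
import ast
from pathlib import Path


def imported_top_level_modules(root):
    """Return the top-level module names imported by .py files under root."""
    found = set()
    for path in Path(root).rglob("*.py"):
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                found.update(alias.name.split(".")[0] for alias in node.names)
            # Skip relative imports (level > 0): they refer to the package itself.
            elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
                found.add(node.module.split(".")[0])
    return found


if __name__ == "__main__":
    package_dir = Path("vision_mamba")  # assumed location of the package source
    if package_dir.is_dir():
        print("swarms imported:", "swarms" in imported_top_level_modules(package_dir))
```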


[BUG] TypeError: unhashable type: 'list'

When running example.py, I get the following TypeError:

/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/torch/onnx/_internal/_beartype.py:35: UserWarning: unhashable type: 'list'
  warnings.warn(f"{e}")
Traceback (most recent call last):
  File "/home/disk2_12t/lhl/VisionMamba-main/example.py", line 2, in <module>
    from vision_mamba.model import Vim
  File "/home/disk2_12t/lhl/VisionMamba-main/vision_mamba/__init__.py", line 1, in <module>
    from vision_mamba.model import Vim
  File "/home/disk2_12t/lhl/VisionMamba-main/vision_mamba/model.py", line 7, in <module>
    from zeta.nn import SSM
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/__init__.py", line 28, in <module>
    from zeta.nn import *
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/__init__.py", line 1, in <module>
    from zeta.nn.attention import *
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/attention/__init__.py", line 14, in <module>
    from zeta.nn.attention.mixture_attention import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/attention/mixture_attention.py", line 8, in <module>
    from zeta.models.vit import exists
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/models/__init__.py", line 3, in <module>
    from zeta.models.andromeda import Andromeda
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/models/andromeda.py", line 4, in <module>
    from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/structs/__init__.py", line 4, in <module>
    from zeta.structs.local_transformer import LocalTransformer
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/structs/local_transformer.py", line 8, in <module>
    from zeta.nn.modules import feedforward_network
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/modules/__init__.py", line 27, in <module>
    from zeta.nn.modules.adaptive_conv import AdaptiveConv3DMod
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/modules/adaptive_conv.py", line 3, in <module>
    from beartype import beartype
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/__init__.py", line 58, in <module>
    from beartype._decor.decormain import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_decor/decormain.py", line 23, in <module>
    from beartype._conf.confcls import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_conf/confcls.py", line 46, in <module>
    from beartype._conf.confoverrides import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_conf/confoverrides.py", line 15, in <module>
    from beartype._data.hint.datahinttyping import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_data/hint/datahinttyping.py", line 212, in <module>
    BeartypeReturn = Union[BeartypeableT, BeartypeConfedDecorator]
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 243, in inner
    return func(*args, **kwds)
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 316, in __getitem__
    return self._getitem(self, parameters)
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 421, in Union
    parameters = _remove_dups_flatten(parameters)
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 215, in _remove_dups_flatten
    all_params = set(params)
TypeError: unhashable type: 'list'


[BUG] Cannot import name 'SSM' from 'zeta.nn'

Describe the bug
The zeta package, a required dependency, does not provide SSM.

To Reproduce
Steps to reproduce the behavior:

  1. Go to 'vision_mamba/model.py'.
  2. Open 'model.py'.
  3. Run example.py from the root directory; it fails with an import error:
    ImportError: cannot import name 'SSM' from 'zeta.nn' (C:\Users\mseb\miniconda3\envs\py38\lib\site-packages\zeta\nn\__init__.py)

Expected behavior
We would expect the import to work

Additional context
I also tried 'from zeta.nn.modules import SSM'. This does not work either, although the documentation here [https://zeta.apac.ai/en/latest/zeta/nn/modules/ssm/] suggests that it should.
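The tracebacks on this page show SSM being imported from two different locations across versions (from zeta.nn import SSM and from zeta.nn.modules.ssm import SSM). A hedged workaround sketch that tries candidate locations in order; import_first is a hypothetical helper, and it only helps when the name has moved. It does not help when importing zeta itself fails, as in the other reports on this page.

```python
import importlib


def import_first(candidates):
    """Return the first attribute importable from a list of (module, name) pairs."""
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("none of the candidate imports worked:\n" + "\n".join(errors))


# Hypothetical use inside vision_mamba/model.py:
# SSM = import_first([("zeta.nn", "SSM"), ("zeta.nn.modules.ssm", "SSM")])
```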


Matrix operation

Hi! Thanks for your work! I have a question about a matrix operation.
When the input size in example.py is set to (1, 512, 1024), it seems impossible to compute "x2 = x2 @ z" in model.py (line 91).
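Assuming x2 and z both have shape (batch, tokens, dim) at that line, the batched matrix product x2 @ z needs dim == tokens, which (1, 512, 1024) violates; with a square shape it would run but compute something different. The gating in the paper (y' = y ⊙ SiLU(z), as I read Algorithm 1) is elementwise. A small check of both behaviors, with random tensors standing in for the real activations:

```python
import torch
import torch.nn.functional as F

batch, tokens, dim = 1, 512, 1024
x2 = torch.randn(batch, tokens, dim)  # stand-in for the SSM branch output
z = torch.randn(batch, tokens, dim)   # stand-in for the gate branch

# Batched matrix product: (1, 512, 1024) @ (1, 512, 1024) needs 1024 == 512, so it fails.
try:
    _ = x2 @ z
    matmul_ok = True
except RuntimeError:
    matmul_ok = False

# Elementwise gating: the shapes only need to match (or broadcast).
gated = x2 * F.silu(z)
print(matmul_ok, tuple(gated.shape))  # False (1, 512, 1024)
```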


[BUG]



[BUG] Inconsistent Tensor Dimensions

Describe the bug
I saw this on the Discord and can reproduce it exactly. Copying the message:

I'm getting inconsistent tensor output shapes. Does anyone know anything about this? I'll provide some examples:

  1. Input tensor shape: torch.Size([1, 3, 224, 224])
    Output tensor shape: torch.Size([1, 196, 10])
    224 -> 196

  2. Input tensor shape: torch.Size([1, 3, 16, 16])
    Output tensor shape: torch.Size([1, 1, 10])
    16 -> 1

  3. Input tensor shape: torch.Size([1, 3, 256, 256])
    Output tensor shape: torch.Size([1, 256, 10])
    256 -> 256

The model used is:

model = Vim(
    dim=256,
    heads=8,
    dt_rank=32,
    dim_inner=256,
    d_state=256,
    num_classes=10,
    image_size=IMAGE_SIZE,
    patch_size=16,
    channels=IMAGE_CHANNELS,
    dropout=0.1,
    depth=12,
)

If this is consistent, what is the math behind it? Or is it a bug?
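The three reported shapes match the number of 16×16 patches: the middle dimension equals (image_size / patch_size)², which suggests the model returns one set of class logits per patch token. A small check of that arithmetic, plus one hypothetical way to get a single prediction per image. Mean pooling here is an assumption, not the repository's behavior; the paper itself classifies from a class token.

```python
import torch


def num_patches(image_size, patch_size):
    """Number of non-overlapping patch tokens for a square image."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


for size in (224, 16, 256):
    print(size, "->", num_patches(size, 16))
# 224 -> 196
# 16 -> 1
# 256 -> 256

# Hypothetical per-token logits shaped like the reported output: (batch, tokens, classes).
per_token_logits = torch.randn(1, num_patches(224, 16), 10)
# Assumed pooling choice: average over the token dimension.
image_logits = per_token_logits.mean(dim=1)
print(image_logits.shape)  # torch.Size([1, 10])
```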


The pretrained parameter file of the model

Could you provide the pretrained weights file for the model?


Thank you for your excellent work. I have a question regarding Equation 6 in the paper: \( T_l = \mathrm{Vim}(T_{l-1}) + T_{l-1} \). Why do you still add \( T_{l-1} \) when this addition is already included in \( \mathrm{Vim} \)?
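Taking the question's premise at face value, write the block as Vim(T) = g(T) + T, where g is a hypothetical name for everything in the block except its skip connection. Substituting into Equation 6 shows what a literal reading would give, with the previous tokens entering twice. Whether Equation 6 is meant as a restatement of that same residual rather than an extra one is exactly what the question asks.

```latex
\begin{aligned}
\mathrm{Vim}(T_{l-1}) &= g(T_{l-1}) + T_{l-1} \\
T_l &= \mathrm{Vim}(T_{l-1}) + T_{l-1} = g(T_{l-1}) + 2\,T_{l-1}
\end{aligned}
```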



[BUG] RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1, 256] but got: [1, 196].

Describe the bug
Just run example.py and you will see the error.

To Reproduce
Steps to reproduce the behavior:

  1. cd VisionMamba
  2. python example.py
  3. See error

Expected behavior
I haven't changed anything; I just ran the code provided by the author.

Screenshots

Screenshot from 2024-04-03 21-49-47
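One plausible reading of the message, assuming the README defaults: 224 / 16 = 14 patches per side gives 196 tokens, and dim=256. A batched product of two (1, 196, 256) tensors needs the second tensor's first two dimensions to be [1, 256], but they are [1, 196], which matches the numbers in the error. A sketch that reproduces this kind of failure with random tensors; which line of model.py actually triggers it is not confirmed here.

```python
import torch

image_size, patch_size, dim = 224, 16, 256   # defaults from the README example
tokens = (image_size // patch_size) ** 2     # 14 * 14 = 196 patch tokens

a = torch.randn(1, tokens, dim)              # hypothetical activations, (batch, tokens, dim)
b = torch.randn(1, tokens, dim)              # hypothetical second operand, same shape

try:
    torch.bmm(a, b)                          # needs b.shape[1] == a.shape[2], i.e. 196 == 256
    bmm_ok = True
except RuntimeError as err:
    bmm_ok = False
    print(err)

print(tokens, bmm_ok)                        # 196 False
```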


[BUG] ImportError When Trying to Import Vision Mamba After Installation

Describe the bug
After successfully installing the vision-mamba package in my environment, attempting to import it with from vision_mamba.model import Vim fails during the import. The traceback ends in a RuntimeError caused by a shape mismatch during a matrix multiplication in one of the package's dependencies.

To Reproduce
Steps to reproduce the behavior:

  1. Attempt to import the Vim class from the package using from vision_mamba.model import Vim.

Additional context
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/vision_mamba/model.py", line 4, in <module>
    from zeta.nn.modules.ssm import SSM
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/__init__.py", line 28, in <module>
    from zeta.nn import *
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/__init__.py", line 3, in <module>
    from zeta.nn.modules import *
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/__init__.py", line 47, in <module>
    from zeta.nn.modules.mlp_mixer import MLPMixer
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 145, in <module>
    output = mlp_mixer(example_input)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 125, in forward
    x = mixer_block(x)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 63, in forward
    y = self.tokens_mlp(y)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 30, in forward
    y = self.dense1(x)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
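The traceback shows that the failure happens while importing zeta: mlp_mixer.py calls mlp_mixer(example_input) at module level (line 145), so the shape bug in that example runs on every import. A minimal reproduction of the same kind of message, assuming it comes from an nn.Linear(512, 512) applied to an input whose last dimension is 4:

```python
import torch
import torch.nn as nn

dense = nn.Linear(512, 512)   # expects the last input dimension to be 512
x = torch.randn(512, 4)       # last dimension is 4, so this input is rejected

try:
    dense(x)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# Applying the layer along the size-512 axis instead (transpose, project, transpose back).
y = dense(x.transpose(0, 1)).transpose(0, 1)
print(y.shape)                # torch.Size([512, 4])
```

Since the failing call sits at module level, the usual library-side remedy would be to move such demo code under an `if __name__ == "__main__":` guard so that importing the module does not execute it.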


Issues regarding the initial linear layer for x and z

Hello, I am truly impressed by the work you have done, and I am trying to build my own ideas on top of your code.

While reading through your code, I came up with a minor question about vision_mamba/model.py, line 96.

According to Algorithm 1, line 3 (page 4 of the original paper), two different linear layers process x and z, whereas your code passes both through the same layer.

It may be a trivial point, but could you explain the reasoning behind this? Thanks.
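If the same layer, with the same weights, is applied to the same input for both, x and z come out identical and the gate carries no independent information. A common pattern in Mamba implementations is a single fused projection of width 2 × inner_dim that is split in half, which is equivalent to two separate linear layers. A sketch showing that equivalence; the class names and sizes are hypothetical.

```python
import torch
import torch.nn as nn


class XZProjection(nn.Module):
    """Two independent linear layers, one for x and one for z."""

    def __init__(self, dim, inner_dim):
        super().__init__()
        self.proj_x = nn.Linear(dim, inner_dim)
        self.proj_z = nn.Linear(dim, inner_dim)

    def forward(self, t):
        return self.proj_x(t), self.proj_z(t)


class FusedXZProjection(nn.Module):
    """One linear layer of width 2 * inner_dim, split into x and z."""

    def __init__(self, dim, inner_dim):
        super().__init__()
        self.in_proj = nn.Linear(dim, 2 * inner_dim)

    def forward(self, t):
        x, z = self.in_proj(t).chunk(2, dim=-1)
        return x, z


separate = XZProjection(dim=8, inner_dim=16)
fused = FusedXZProjection(dim=8, inner_dim=16)

# Copy the separate weights into the fused layer (x rows first, then z rows).
with torch.no_grad():
    fused.in_proj.weight.copy_(torch.cat([separate.proj_x.weight, separate.proj_z.weight], dim=0))
    fused.in_proj.bias.copy_(torch.cat([separate.proj_x.bias, separate.proj_z.bias], dim=0))

t = torch.randn(2, 5, 8)  # (batch, tokens, dim)
x_sep, z_sep = separate(t)
x_fused, z_fused = fused(t)
print(torch.allclose(x_sep, x_fused, atol=1e-6), torch.allclose(z_sep, z_fused, atol=1e-6))
```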



Matrix dimensions do not match

I am currently facing a problem. I modified the Transformer part of my code using the code you provided, but the dimensions in my matrix multiplication do not match: there is a dimension mismatch when multiplying x1_ssm and x2_ssm with z. I hope you can clear up my doubts. I am also confused about something else: doesn't Vim describe the Mamba block as having two directions, forward and backward? Why don't I see this in the code you provided?

    def forward_temporal(self, x, F):
        B, J, C = x.shape

        # Skip connection
        skip = x

        # Normalization
        x = self.norm(x)

        # Project x into two branches: z1 (gate) and x1 (SSM input)
        z1 = self.proj_x(x)
        x1 = self.proj_z(x)

        # Forward branch
        x1 = x1.reshape(B, C, J)
        x1_rearranged = self.softplus(x1)
        forward_conv_output = self.forward_conv1d(x1_rearranged)
        forward_conv_output = forward_conv_output.reshape(B, J, C)
        x1_ssm = self.forward_ssm(forward_conv_output)

        # Backward branch
        x2 = x1.reshape(B, C, J)
        x2_rearranged = self.softplus(x2)
        backward_conv_output = self.backward_conv1d(x2_rearranged)
        backward_conv_output = backward_conv_output.reshape(B, J, C)
        x2_ssm = self.backward_ssm(backward_conv_output)

        # Activation
        z = self.activation(z1)

        # Elementwise multiply the backward SSM output with z
        x2 = x2_ssm * z

        # Elementwise multiply the forward SSM output with z
        x1 = x1_ssm * z

        # Add both branches
        x = x1 + x2

        # Add skip connection
        return x + skip
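For comparison, a minimal sketch of a bidirectional block in the spirit of the question, assuming input of shape (B, J, C). It differs from the snippet above in three ways: it swaps axes with transpose instead of reshape (reshape reinterprets memory and mixes tokens with channels), it reverses the token order for the backward branch, and it gates with an elementwise product. SiLU after the convolution is assumed to follow the paper's Algorithm 1. The SSMs are nn.Identity placeholders and all module names are hypothetical; this is not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BidirectionalBlockSketch(nn.Module):
    """Hypothetical Vim-style block; the SSMs are placeholders (nn.Identity)."""

    def __init__(self, dim, inner_dim, kernel_size=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj_x = nn.Linear(dim, inner_dim)  # SSM input branch
        self.proj_z = nn.Linear(dim, inner_dim)  # gate branch
        # Depthwise convs padded on both sides; the right overhang is trimmed later.
        self.forward_conv1d = nn.Conv1d(inner_dim, inner_dim, kernel_size,
                                        padding=kernel_size - 1, groups=inner_dim)
        self.backward_conv1d = nn.Conv1d(inner_dim, inner_dim, kernel_size,
                                         padding=kernel_size - 1, groups=inner_dim)
        self.forward_ssm = nn.Identity()   # placeholder for a selective SSM
        self.backward_ssm = nn.Identity()  # placeholder for a selective SSM
        self.out_proj = nn.Linear(inner_dim, dim)

    def _branch(self, x, conv, ssm):
        # x: (B, J, E). Conv1d expects (B, E, J), so swap axes with transpose.
        J = x.shape[1]
        h = conv(x.transpose(1, 2))[..., :J]  # keep the first J outputs (causal part)
        h = F.silu(h).transpose(1, 2)         # back to (B, J, E)
        return ssm(h)

    def forward(self, x):
        skip = x
        x = self.norm(x)
        x_in = self.proj_x(x)
        z = F.silu(self.proj_z(x))
        y_fwd = self._branch(x_in, self.forward_conv1d, self.forward_ssm)
        # Backward direction: reverse the tokens, process, then reverse back.
        y_bwd = self._branch(x_in.flip(1), self.backward_conv1d, self.backward_ssm).flip(1)
        # Elementwise gating, so only matching shapes are required.
        y = y_fwd * z + y_bwd * z
        return self.out_proj(y) + skip


block = BidirectionalBlockSketch(dim=256, inner_dim=512)
tokens = torch.randn(1, 196, 256)  # (B, J, C), e.g. 196 patch tokens
print(block(tokens).shape)         # torch.Size([1, 196, 256])
```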

