Implementation of Vision Mamba from the paper "Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model". According to the paper, Vim is 2.8x faster than DeiT and saves 86.8% GPU memory when performing batch inference to extract features on high-resolution images.
import torch
from vision_mamba import Vim

# Input tensor with shape (batch_size, channels, height, width)
x = torch.randn(1, 3, 224, 224)

# Model
model = Vim(
    dim=256,  # Model (embedding) dimension
    heads=8,  # Number of attention heads
    dt_rank=32,  # Rank of the Δ (timestep) projection in the SSM
    dim_inner=256,  # Inner dimension of each Mamba block
    d_state=256,  # Dimension of the SSM state vector
    num_classes=1000,  # Number of output classes
    image_size=224,  # Size of the input image
    patch_size=16,  # Size of each image patch
    channels=3,  # Number of input channels
    dropout=0.1,  # Dropout rate
    depth=12,  # Number of Vim blocks (depth of the model)
)

# Forward pass
out = model(x)  # Output tensor from the model
print(out.shape)  # Print the shape of the output tensor
print(out)  # Print the output tensor
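With these settings the image is cut into non-overlapping patches, so the token sequence length the Mamba blocks see follows directly from image_size and patch_size. A minimal sketch of that arithmetic, assuming square images whose size divides evenly by the patch size; `num_patches` is a hypothetical helper, not part of vision_mamba, and it ignores any extra tokens (such as a class token) the model may add:

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size={image_size} is not divisible by patch_size={patch_size}"
        )
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


# README settings: 224 / 16 = 14 patches per side, so 14 * 14 tokens.
print(num_patches(224, 16))  # 196
# The sequence grows quadratically with the image side length.
print(num_patches(512, 16))  # 1024
```

Because Mamba scales linearly in sequence length, this quadratic growth in tokens is where the memory savings at high resolution come from.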
Citation
@misc{zhu2024vision,
title={Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model},
author={Lianghui Zhu and Bencheng Liao and Qian Zhang and Xinlong Wang and Wenyu Liu and Xinggang Wang},
year={2024},
eprint={2401.09417},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
It appears that, per Algorithm 1 of the Vision Mamba paper, the state-space model runs bidirectionally along the sequence. In this implementation, however, both the forward and backward convolutions are standard 1D convolutions over the same token order, rather than one scanning forward and the other backward in space. Can you explain the rationale behind this?
Otherwise there seems to be no difference between the backward and forward operations, even though together they are supposed to give the Mamba block a bidirectional way of applying the selective state-space operation.
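The difference the question points at can be shown with a toy scalar recurrence standing in for the SSM (a simplification: real Vim blocks use a selective, input-dependent SSM on vectors). The forward branch scans tokens left to right; a backward branch reverses the token order before scanning and reverses its output back so positions line up with the input. Both functions are hypothetical illustrations, not code from this repository:

```python
from typing import List, Tuple


def causal_scan(xs: List[float], a: float = 0.5, b: float = 1.0, c: float = 1.0) -> List[float]:
    """Toy scalar linear recurrence h_t = a*h_{t-1} + b*x_t, y_t = c*h_t, left to right."""
    h = 0.0
    ys: List[float] = []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys


def bidirectional_scan(xs: List[float]) -> Tuple[List[float], List[float]]:
    """Run the toy scan forward and backward over the same token sequence."""
    forward = causal_scan(xs)
    # Backward branch: reverse the token order, scan, then reverse the outputs
    # back so that position t still corresponds to input token t.
    backward = causal_scan(xs[::-1])[::-1]
    return forward, backward


tokens = [1.0, 0.0, 0.0, 0.0]  # a single impulse at position 0
fwd, bwd = bidirectional_scan(tokens)
print(fwd)  # [1.0, 0.5, 0.25, 0.125]
print(bwd)  # [1.0, 0.0, 0.0, 0.0]
```

Token 0 influences every later position in the forward output but only itself in the backward output. Without the reversal, two branches differ only in their weights, not in which context they see. In PyTorch the reversal would typically be `torch.flip(x, dims=[1])` on a `(batch, seq_len, dim)` tensor.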
Upvote & Fund
We're using Polar.sh so you can upvote and help fund this issue.
We receive the funding once the issue is completed & confirmed by you.
Thank you in advance for helping prioritize & fund our backlog.
Describe the bug
The requirements include a package that this package does not use.
To Reproduce
Install the package from pip.
Expected behavior
Only the necessary packages are installed.
Additional context
I see that the author also maintains the swarms package. If possible, could you clarify why this requirement is needed? I checked that package too, but it is unrelated.
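One way to check a claim like this is to compare the declared requirements with the top-level modules the source actually imports. A rough sketch using the standard-library `ast` module; it assumes a distribution's import name equals its requirement name lowercased with "-" replaced by "_", which does not always hold, so any hit is a lead to verify by hand. The helper names and "some-unused-pkg" are hypothetical placeholders:

```python
import ast
from typing import Iterable, Set


def imported_top_level_modules(source: str) -> Set[str]:
    """Top-level names of absolute imports in a Python source string."""
    names: Set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                names.add(alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            # Relative imports (level > 0) refer to the package itself; skip them.
            names.add(node.module.split(".")[0])
    return names


def unused_requirements(requirements: Iterable[str], sources: Iterable[str]) -> Set[str]:
    """Requirements whose (assumed) import name never appears in `sources`."""
    used: Set[str] = set()
    for source in sources:
        used |= imported_top_level_modules(source)
    # Assumption: import name == requirement name, lowercased, with '-' -> '_'.
    return {req for req in requirements if req.lower().replace("-", "_") not in used}


example_source = "import torch\nfrom einops import rearrange\nfrom .model import Vim\n"
print(unused_requirements(["torch", "einops", "some-unused-pkg"], [example_source]))
# {'some-unused-pkg'}
```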
When running example.py, I get the following TypeError:
/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/torch/onnx/_internal/_beartype.py:35: UserWarning: unhashable type: 'list'
  warnings.warn(f"{e}")
Traceback (most recent call last):
  File "/home/disk2_12t/lhl/VisionMamba-main/example.py", line 2, in <module>
    from vision_mamba.model import Vim
  File "/home/disk2_12t/lhl/VisionMamba-main/vision_mamba/__init__.py", line 1, in <module>
    from vision_mamba.model import Vim
  File "/home/disk2_12t/lhl/VisionMamba-main/vision_mamba/model.py", line 7, in <module>
    from zeta.nn import SSM
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/__init__.py", line 28, in <module>
    from zeta.nn import *
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/__init__.py", line 1, in <module>
    from zeta.nn.attention import *
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/attention/__init__.py", line 14, in <module>
    from zeta.nn.attention.mixture_attention import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/attention/mixture_attention.py", line 8, in <module>
    from zeta.models.vit import exists
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/models/__init__.py", line 3, in <module>
    from zeta.models.andromeda import Andromeda
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/models/andromeda.py", line 4, in <module>
    from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/structs/__init__.py", line 4, in <module>
    from zeta.structs.local_transformer import LocalTransformer
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/structs/local_transformer.py", line 8, in <module>
    from zeta.nn.modules import feedforward_network
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/modules/__init__.py", line 27, in <module>
    from zeta.nn.modules.adaptive_conv import AdaptiveConv3DMod
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/zeta/nn/modules/adaptive_conv.py", line 3, in <module>
    from beartype import beartype
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/__init__.py", line 58, in <module>
    from beartype._decor.decormain import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_decor/decormain.py", line 23, in <module>
    from beartype._conf.confcls import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_conf/confcls.py", line 46, in <module>
    from beartype._conf.confoverrides import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_conf/confoverrides.py", line 15, in <module>
    from beartype._data.hint.datahinttyping import (
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/site-packages/beartype/_data/hint/datahinttyping.py", line 212, in <module>
    BeartypeReturn = Union[BeartypeableT, BeartypeConfedDecorator]
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 243, in inner
    return func(*args, **kwds)
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 316, in __getitem__
    return self._getitem(self, parameters)
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 421, in Union
    parameters = _remove_dups_flatten(parameters)
  File "/home/disk2_12t/lhl/anaconda3/envs/vmamba/lib/python3.9/typing.py", line 215, in _remove_dups_flatten
    all_params = set(params)
TypeError: unhashable type: 'list'
Describe the bug
The zeta module, which is a required import, does not have the attribute SSM.
To Reproduce
Steps to reproduce the behavior:
Go to 'vision_mamba/model.py'
Click on 'model.py'
If you run example.py from the repository root, you get an ImportError:
ImportError: cannot import name 'SSM' from 'zeta.nn' (C:\Users\mseb\miniconda3\envs\py38\lib\site-packages\zeta\nn\__init__.py)
Expected behavior
The import should succeed.
Additional context
I also tried 'from zeta.nn.modules import SSM', but this does not work either, even though the documentation at https://zeta.apac.ai/en/latest/zeta/nn/modules/ssm/ suggests that it should.
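Two different import paths for SSM appear in this thread (`zeta.nn` here, `zeta.nn.modules.ssm` in another version of vision_mamba/model.py), which suggests the location differs between zeta releases. A hedged sketch of a fallback importer that tries candidate paths in order; `import_first` is a hypothetical helper, and it cannot help when importing zeta itself fails with a non-import error, as in the TypeError and RuntimeError tracebacks in this thread:

```python
import importlib
from typing import Any, List, Sequence


def import_first(attr: str, module_paths: Sequence[str]) -> Any:
    """Return `attr` from the first module in `module_paths` that provides it.

    Only ImportError (including ModuleNotFoundError) and AttributeError are
    caught; any other exception raised while importing a module propagates.
    """
    errors: List[str] = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{path}: {exc}")
    raise ImportError(f"could not import {attr!r}; tried:\n" + "\n".join(errors))


# Stdlib demo: the first path does not exist, the second one does.
sqrt = import_first("sqrt", ["no_such_module_xyz", "math"])
print(sqrt(16.0))  # 4.0

# Candidate locations of SSM seen in this thread; whether either works
# depends on the installed zeta version (not verified here):
# SSM = import_first("SSM", ["zeta.nn", "zeta.nn.modules.ssm"])
```

Pinning the zeta version that a given vision_mamba release was written against is the more robust fix; the fallback only papers over a moved symbol.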
Hi! Thanks for your work! I have a question about a matrix operation.
When the input size in example.py is set to (1, 512, 1024), it seems impossible to compute "x2 = x2 @ z" in model.py (line 91), because the matrix dimensions do not match.
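Assuming x2 and z both have shape (batch, seq_len, dim), `@` is a batched matrix product, which requires the last dimension of x2 to equal the second-to-last dimension of z, i.e. dim == seq_len. That only holds by coincidence for particular sizes. The gating step in Algorithm 1 of the Vim paper is written as an elementwise product with SiLU(z), which only needs matching shapes. A shape-only sketch (hypothetical helpers, no broadcasting) contrasting the two:

```python
from typing import Tuple

Shape = Tuple[int, int, int]  # (batch, rows, cols)


def matmul_shape(a: Shape, b: Shape) -> Shape:
    """Shape of a @ b for two 3-D operands with the same batch size (no broadcasting)."""
    if a[0] != b[0]:
        raise ValueError(f"batch sizes differ: {a[0]} vs {b[0]}")
    if a[2] != b[1]:
        raise ValueError(f"cannot matmul {a} @ {b}: {a[2]} != {b[1]}")
    return (a[0], a[1], b[2])


def elementwise_shape(a: Shape, b: Shape) -> Shape:
    """Shape of a * b when both operands have identical shapes (no broadcasting)."""
    if a != b:
        raise ValueError(f"shapes differ: {a} vs {b}")
    return a


x2_shape = z_shape = (1, 512, 1024)  # assumed (batch, seq_len, dim)
print(elementwise_shape(x2_shape, z_shape))  # (1, 512, 1024)
try:
    matmul_shape(x2_shape, z_shape)
except ValueError as err:
    print(err)  # cannot matmul (1, 512, 1024) @ (1, 512, 1024): 1024 != 512
```

In PyTorch the elementwise version is simply `x2 * z`.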
Describe the bug
After successfully installing the vision-mamba package in my environment, importing it with 'from vision_mamba.model import Vim' fails. The traceback ends in a RuntimeError rather than an ImportError: a shape mismatch in a matrix multiplication inside one of the package's dependencies (zeta), triggered while that dependency is being imported.
To Reproduce
Steps to reproduce the behavior:
Attempt to import the Vim class with 'from vision_mamba.model import Vim'.
Additional context
Traceback (most recent call last):
  File "", line 1, in <module>
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/vision_mamba/model.py", line 4, in <module>
    from zeta.nn.modules.ssm import SSM
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/__init__.py", line 28, in <module>
    from zeta.nn import *
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/__init__.py", line 3, in <module>
    from zeta.nn.modules import *
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/__init__.py", line 47, in <module>
    from zeta.nn.modules.mlp_mixer import MLPMixer
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 145, in <module>
    output = mlp_mixer(example_input)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 125, in forward
    x = mixer_block(x)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 63, in forward
    y = self.tokens_mlp(y)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py", line 30, in forward
    y = self.dense1(x)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
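The traceback shows the failure does not come from vision_mamba itself: importing zeta runs zeta/nn/modules/mlp_mixer.py, and line 145 of that file (`output = mlp_mixer(example_input)`) executes a forward pass at module level, so a shape bug in that demo breaks every import. The standard Python idiom for demo code is an `if __name__ == "__main__":` guard. A self-contained illustration; the module text and the `load_as` helper are hypothetical stand-ins, not zeta's code:

```python
import types

# Stand-in for a module that runs a demo at import time.
# With the guard, the demo only runs when the file is executed as a script.
MODULE_SOURCE = '''
calls = []

def demo():
    calls.append("demo ran")

if __name__ == "__main__":
    demo()
'''


def load_as(name: str) -> types.ModuleType:
    """Execute MODULE_SOURCE in a fresh module object whose __name__ is `name`."""
    module = types.ModuleType(name)
    exec(MODULE_SOURCE, module.__dict__)
    return module


print(load_as("mlp_mixer_demo").calls)  # [] (imported: the guard skips the demo)
print(load_as("__main__").calls)        # ['demo ran'] (run as a script)
```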
Hello, I am truly impressed by what you have done, and I am trying to build my own ideas on top of your code.
While following through your code, I came up with a minor question about vision_mamba/model.py, line 96.
On page 4 of the original paper, line 3 of Algorithm 1 uses two different linear layers to process x and z, whereas your code forwards both through the same layer.
It may be a trivial point, but could you explain the reasoning behind it? Thanks
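A small illustration of the difference, using plain lists as vectors and matrices with made-up weights. Algorithm 1 projects each token with two separate weights for x and z; one layer with twice the output width whose output is split in half is equivalent, while reusing a single layer for both makes x and z identical, so z is no longer an independent gate. `linear`, `W_x` and `W_z` are hypothetical names for this sketch:

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one row per output feature


def linear(weight: Matrix, v: Vector) -> Vector:
    """Bias-free linear layer: weight @ v."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in weight]


token = [2.0, 3.0]
W_x = [[1.0, 0.0], [0.0, 1.0]]  # made-up weights for the x branch
W_z = [[0.0, 1.0], [1.0, 0.0]]  # made-up weights for the z branch

# Two separate projections, as in Algorithm 1, line 3.
x = linear(W_x, token)  # [2.0, 3.0]
z = linear(W_z, token)  # [3.0, 2.0]

# Fused equivalent: stack the weights row-wise, project once, split the output.
xz = linear(W_x + W_z, token)
x_fused, z_fused = xz[:2], xz[2:]
print(x_fused == x and z_fused == z)  # True

# Shared layer: reusing W_x for both branches makes x and z identical.
x_shared, z_shared = linear(W_x, token), linear(W_x, token)
print(x_shared == z_shared)  # True
```

In PyTorch the fused form would typically be `nn.Linear(dim, 2 * dim_inner)` followed by `.chunk(2, dim=-1)`, which keeps separate weights for x and z inside one layer.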
I am currently facing a problem. I modified the Transformer part of my code using the code you provided, but the dimensions of the matrix multiplication do not match: there is a shape mismatch when x1_ssm and x2_ssm are multiplied with z. I hope you can help me resolve this. I am also confused about something else: doesn't Vim say that each Mamba block processes the sequence in two directions, forward and backward? Why don't I see this in the code you provided?
def forward_temporal(self, x, F):
    B, J, C = x.shape
    # Skip connection
    skip = x
    # Normalization
    x = self.norm(x)

    # Split x into x1 and x2 with linears
    z1 = self.proj_x(x)
    x1 = self.proj_z(x)

    # forward
    x1 = x1.reshape(B, C, J)
    x1_rearranged = self.softplus(x1)
    forward_conv_output = self.forward_conv1d(x1_rearranged)
    forward_conv_output = forward_conv_output.reshape(B, J, C)
    x1_ssm = self.forward_ssm(forward_conv_output)

    # backward
    x2 = x1.reshape(B, C, J)
    x2_rearranged = self.softplus(x2)
    backward_conv_output = self.backward_conv1d(x2_rearranged)
    backward_conv_output = backward_conv_output.reshape(B, J, C)
    x2_ssm = self.backward_ssm(backward_conv_output)

    # Activation
    z = self.activation(z1)

    # matmul with z + backward ssm
    x2 = x2_ssm * z
    # Matmul with z and x1
    x1 = x1_ssm * z
    # Add both matmuls
    x = x1 + x2
    # Add skip connection
    return x + skip
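One likely source of trouble in this snippet (a guess from reading it, not a confirmed diagnosis): `x1.reshape(B, C, J)` on a `(B, J, C)` tensor is not a transpose. A row-major reshape keeps the flat element order and only regroups it, so values from neighbouring tokens get mixed, whereas a Conv1d expects each row to be one channel across all tokens. A pure-Python illustration on a single (J=2, C=3) example; `reshape` and `transpose` here are hypothetical helpers mirroring the tensor methods:

```python
from typing import List

Matrix = List[List[int]]


def reshape(m: Matrix, rows: int, cols: int) -> Matrix:
    """Row-major reshape, mirroring Tensor.reshape on a contiguous tensor."""
    flat = [v for row in m for v in row]
    if len(flat) != rows * cols:
        raise ValueError(f"cannot reshape {len(flat)} elements into ({rows}, {cols})")
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def transpose(m: Matrix) -> Matrix:
    """Swap the two axes, mirroring Tensor.transpose(0, 1)."""
    return [list(col) for col in zip(*m)]


# Two tokens (J=2), each with three channels (C=3): shape (J, C).
tokens = [[1, 2, 3], [4, 5, 6]]

print(reshape(tokens, 3, 2))  # [[1, 2], [3, 4], [5, 6]]  rows mix tokens and channels
print(transpose(tokens))      # [[1, 4], [2, 5], [3, 6]]  row c = channel c over all tokens
```

For a batched `(B, J, C)` tensor the corresponding PyTorch call would be `x1.transpose(1, 2)`, and again after the convolution instead of `reshape(B, J, C)`. On the second question, both branches here read `x1` in the same token order; a backward branch would normally reverse the sequence first, e.g. with `torch.flip(x, dims=[1])`, and flip its output back.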