alxndrtl / mamba.py
A simple and efficient Mamba implementation in pure PyTorch and MLX.
License: MIT License
For running extremely large sequence lengths, I can break up the sequence among batches and use the inference step to save the hidden states between passes. However, step does not work during training as there seems to be a problem with overwriting the variables before loss.backward() can be run.
How might you modify the forward pass to allow running the same parallel scan as used during training, but connect previous hidden states from other batches?
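One way to approach this, sketched in plain Python with scalar states (an assumption made for readability; real code would operate on tensors): since H[t] = A[t] * H[t-1] + X[t], a non-zero initial state only enters through the first step. Adding A[0] * h0 to X[0] before running a zero-initialised scan therefore gives the same result as starting from h0. The names `scan` and `scan_with_initial_state` are hypothetical, and the sketch does not address how gradients should flow across chunks.

```python
def scan(A, X, h0=0.0):
    """Reference recurrence: H[t] = A[t] * H[t-1] + X[t], starting from h0."""
    H, h = [], h0
    for a, x in zip(A, X):
        h = a * h + x
        H.append(h)
    return H


def scan_with_initial_state(A, X, h0):
    """Fold h0 into the first input, then run a scan that assumes H[-1] = 0.

    `scan` stands in for any zero-initialised scan (for example a parallel one).
    """
    X_shifted = list(X)                  # copy: do not overwrite the caller's data
    X_shifted[0] = X[0] + A[0] * h0      # only the first step sees h0
    return scan(A, X_shifted, h0=0.0)


# Process one long sequence as two chunks, carrying the last state across.
A = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
X = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

full = scan(A, X)
first = scan_with_initial_state(A[:3], X[:3], h0=0.0)
second = scan_with_initial_state(A[3:], X[3:], h0=first[-1])
chunked = first + second
```

Because the copy is made before the first element is modified, nothing is overwritten in place, which may also help with the backward-pass problem described above.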
Hi, I really liked this project.
I was hoping you could also finish the pscan documentation in the notebook. It has already given me a lot of clarity, and the rest would help further.
Hi,
Firstly, thanks for making this repo. I found it very useful in understanding the scan algorithm. However, discretization in this repo seems to be different from the Eq 4 of the paper. Do you have any comments on this? Also, I wonder why the original paper needs the discretization step in the first place since it is possible to make the discrete versions of A and B conditioned on the input directly. I imagine that it must have something to do with the initialization, but I am not sure.
Line 248 in b9f315d
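For reference, the relationship between the two discretizations can be written out. Eq. 4 of the paper is the zero-order hold (ZOH) rule; the simplified form $\bar{B} \approx \Delta B$ below is an assumption about what the referenced line computes, since that line is not reproduced here.

```latex
\bar{A} = \exp(\Delta A), \qquad
\bar{B} = (\Delta A)^{-1}\bigl(\exp(\Delta A) - I\bigr)\,\Delta B .

% For a diagonal A with entry a (so every quantity is a scalar per state):
\bar{B} = \frac{\exp(\Delta a) - 1}{a}\, B
        = \Delta B \left(1 + \frac{\Delta a}{2} + O\bigl((\Delta a)^2\bigr)\right)
        \approx \Delta B \quad \text{when } |\Delta a| \ll 1 .
```

Under that reading, the code would keep the exact ZOH rule for A and use a first-order approximation for B.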
Hi. First of all thanks for great work! I was wondering if it's possible or planned to add VideoMamba support. Have a great day!
Dear Alex,
Thank you for the repo. I noticed that the parallel and sequential versions of the scan produce different outputs. I compared the results of selective_scan and selective_scan_seq from here: https://github.com/alxndrTL/mamba.py/blob/main/mamba.py#L258
I know the parallel version of the scan uses the following:
https://github.com/alxndrTL/mamba.py/blob/main/pscan.py#L37
Do you have any suggestions on how to debug this?
Hi, great work!
How do I enable CUDA? I found the following:
    if self.config.use_cuda:
        try:
            from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # I did not find mamba_ssm in this repo
            self.selective_scan_cuda = selective_scan_fn
        except ImportError:
            print("Failed to import mamba_ssm. Falling back to mamba.py.")
            self.config.use_cuda = False
Looking forward to your response.
Thanks !
I am going through the code line by line and adding additional comments with shape information. Inside the MambaBlock's forward function, duplicated below:
    def forward(self, x):
        # input x : (B, L, D)
        # return : (B, L, D)
        _, L, _ = x.shape
        # in_proj: D --> 2*ED (two or three branches)
        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)  # (B, L, ED), (B, L, ED)
        # x branch
        # rearrange(x, "... L ED -> ... ED L")
        x = x.transpose(1, 2)  # (B, ED, L)
        # What is the point of convolution?
        x = self.conv1d(x)[:, :, :L]  # depthwise convolution over time, with a short filter
        # rearrange(x, "... ED L -> ... L ED")
        x = x.transpose(1, 2)  # (B, L, ED)
        # x --> conv1d --> silu --> ssm --> y ---> output (y*z) --> (B, L, ED)
        # ^ /
        # z -------------> silu ---- |---------/
        x = F.silu(x)
        y = self.ssm(x, z)  # (B, L, ED), (B, L, ED) --> (B, L, ED)
        print(f"Return from self.ssm, {y.shape=}")
        # GE: why the early exit if using CUDA?
        if self.config.use_cuda:  ######<<<<<<<<
            output = self.out_proj(y)  # (B, L, D)
            return output
        # z branch
        z = F.silu(z)  # (B, L, ED)
        # Why multiply y * z?
        output = y * z
        # ED -> D
        output = self.out_proj(output)  # (B, L, D)
        return output  # (B, L, D)
there is the conditional:
    if self.config.use_cuda:
Depending on whether CUDA is used or not, the output of the forward() method is different. If CUDA is not used, there is an additional F.silu applied to z, followed by output = y * z.
If this is not an error, could you please explain the reason for this apparent discrepancy? Thanks.
Gordon
Segmentation fault: 11 while inferencing mamba with mlx
https://github.com/alxndrTL/mamba.py/tree/main/mlx
python3 generate.py --prompt="Mamba is a type of" --hf_model_name="state-spaces/mamba-130m" --n_tokens=100
on an Apple M1 Pro
I found that the line that causes the error is
    mlx_weights = torch.zeros(channels, kernel_size, channels)
in the function torch_to_mlx_depthwise_weights, but I don't know how to fix it.
Hey! I ported mamba to transformers and think your approach to replace the naive scan would be great there!
Would you like to open a PR? 🤗 (to https://github.com/huggingface/transformers)
Hello. In the function selective_scan_seq, there are two points that confuse me:
    BX = deltaB * (x.unsqueeze(-1))  # (B, L, ED, N)
    h = deltaA[:, t] * h + BX[:, t]
However, in the paper, the equation is
$h_t = \bar{A} h_{t-1} + \bar{B} x_t$
Both terms on the right side of the equation are matrix multiplications.
Do these two lines of code use some trick to convert the matrix multiplications into elementwise ones?
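A small sketch of why the two forms can agree, assuming (as in the Mamba paper) that the state matrix is diagonal and that each channel has a scalar input: multiplying by a diagonal matrix is the same as multiplying elementwise by its diagonal, and a column vector times a scalar is also elementwise. All values below are hypothetical and use plain Python lists instead of tensors.

```python
def matvec(M, v):
    """Plain matrix-vector product."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


# Hypothetical values for one channel with state size N = 3.
a_bar = [0.9, 0.5, 0.1]        # diagonal of A_bar
b_bar = [0.2, 0.4, 0.6]        # B_bar, an (N x 1) column
h_prev = [1.0, 2.0, 3.0]       # previous state
x_t = 2.0                      # scalar input for this channel

# Matrix form: h = A_bar @ h_prev + B_bar * x_t, with A_bar written out densely.
A_dense = [[a_bar[i] if i == j else 0.0 for j in range(3)] for i in range(3)]
h_matrix = [ah + b * x_t for ah, b in zip(matvec(A_dense, h_prev), b_bar)]

# Elementwise form, as in `deltaA[:, t] * h + BX[:, t]`.
h_elementwise = [a * h + b * x_t for a, h, b in zip(a_bar, h_prev, b_bar)]
```

Storing only the diagonal is what lets the code keep deltaA with shape (B, L, ED, N) rather than materialising an N x N matrix per channel.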
The 1.4B model takes 10-11GB of RAM at inference. (own test, M2 Pro 16GB)
The 2.8B model takes around 50GB at inference. (https://twitter.com/awnihannun/status/1749515431336112275)
This is not due to loading the model from HF (same memory footprint if model initialized with random weights).
Nor is it due to ssm_step.
However, turning off the convolution at inference reduces the memory footprint (by 3GB for the 1.4B model: from 10GB to around 7GB). It also greatly speeds up inference (but of course, the forward pass is then no longer correct).
Files concerned: mamba_mlx.py (step functions), misc.py
The depthwise conv implemented in misc.py seems to be part of the problem.
As noted in the file, the PyTorch version uses groups=channels (true depthwise), while the MLX depthwise conv in misc.py uses groups=1 with some weights set to 0 (the only workaround found).
This results in a (d_model, 4, d_model) filter size, against (d_model, 4) for the "true" depthwise conv.
Either:
- wait for MLX to implement groups=channels for conv1d
- find another workaround (one possibility is to create d_model conv objects, each with 1 input and 1 output channel, but this results in a big for loop which is around 45x slower than the current workaround; on the other hand, memory usage is greatly reduced, by a factor of d_model=2560)
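The filter sizes quoted above can be checked with a few lines of arithmetic. This is only a sketch: the helper name is hypothetical, and 4 bytes per weight (float32) is an assumption.

```python
def conv_filter_sizes(channels, kernel_size):
    """Return (dense, depthwise) parameter counts for a 1D conv filter."""
    dense = channels * kernel_size * channels   # groups=1 workaround
    depthwise = channels * kernel_size          # groups=channels
    return dense, depthwise


dense, depthwise = conv_filter_sizes(channels=2560, kernel_size=4)
ratio = dense // depthwise
dense_mib_fp32 = dense * 4 / 2**20   # assuming 4 bytes per weight (float32)
```

With these numbers the dense filter holds 26,214,400 weights (100 MiB in float32) per convolution, against 10,240 for the true depthwise filter, a ratio of exactly d_model. Whether this accounts for the whole footprint reported above is not something the arithmetic alone can establish.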
I find that the pscan method used in this Mamba implementation uses a huge amount of memory! Any idea how to reduce memory consumption, or how to replace pscan with another implementation?
Thanks a lot!
Would it be possible to put an explicit OSS license on the codebase just to remove approval burden when experimenting based on this code? I want to prototype a bit on laptop to get the plumbing right before moving to cloud GPUs to actually train and this looks like the best CPU friendly implementation I have found for that purpose
Hi @alxndrTL ,
I am trying to generate an ONNX file for the forward pass here: https://github.com/alxndrTL/mamba.py/blob/main/mamba.py#L69
The issue is that the ONNX export cannot handle in-place computation, and thus will skip it or generate incorrect output whenever there is an in-place computation. How can I modify the pscan so that all the computations happen out of place?
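As an illustration of what an out-of-place scan could look like, here is a minimal pure-Python sketch. It uses a Hillis-Steele style doubling scan, which is a different algorithm from the repository's pscan, and the function name is hypothetical. Each round builds new lists from the previous ones instead of writing into them.

```python
def scan_out_of_place(A, X):
    """Inclusive scan of H[t] = A[t] * H[t-1] + X[t] (H[-1] = 0), no in-place writes.

    Hillis-Steele style: about log2(L) rounds, each building new lists from the
    previous ones. Works for any length, not only powers of two.
    """
    a, x = list(A), list(X)
    shift = 1
    while shift < len(x):
        # Element t combines with element t - shift; earlier elements are copied.
        new_x = [x[t] if t < shift else a[t] * x[t - shift] + x[t] for t in range(len(x))]
        new_a = [a[t] if t < shift else a[t] * a[t - shift] for t in range(len(a))]
        a, x = new_a, new_x
        shift *= 2
    return x


A = [0.9, 0.8, 0.7, 0.6, 0.5]
X = [1.0, 2.0, 3.0, 4.0, 5.0]
H = scan_out_of_place(A, X)
```

This variant does more total work than an up-sweep/down-sweep scan (roughly L * log2(L) instead of L). Whether a tensor version of it exports cleanly to ONNX is not something this sketch verifies.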
from commit hash 6a49341:
I ran generate.py with a fine-tuned Mamba model, kuotient/mamba-ko-2.8b, and the error below occurred. How can I deal with it?
(venv_mamba_py) ******@Mac-Studio-2022-01 scripts % python generate.py --prompt="Mamba is a type of" --hf_model_name="kuotient/mamba-ko-2.8b" --n_tokens=100
Traceback (most recent call last):
  File "/Users/******/test/mamba.py/mlx/scripts/generate.py", line 31, in <module>
    model = MambaLM.from_pretrained(args.hf_model_name)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/mamba_lm_mlx.py", line 150, in from_pretrained
    mlx_state_dict = map_mambassm_torch_to_mlx(state_dict)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/utils.py", line 53, in map_mambassm_torch_to_mlx
    return map_mambapy_torch_to_mlx(new_state_dict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/utils.py", line 37, in map_mambapy_torch_to_mlx
    new_state_dict[key] = value.numpy()
                          ^^^^^^^^^^^^^
TypeError: Got unsupported ScalarType BFloat16
My environments are:
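NumPy has no native bfloat16 dtype, which is why `.numpy()` raises on a BFloat16 tensor. A possible workaround, not confirmed anywhere in this thread, is to widen the tensor first, for example `value.to(torch.float32).numpy()`. The standard-library sketch below only illustrates why that widening loses no information: a bfloat16 value is the upper 16 bits of a float32. The helper names are hypothetical.

```python
import struct


def bfloat16_bits_to_float32(bits16):
    """Widen a bfloat16 bit pattern (int in [0, 65535]) to a Python float.

    bfloat16 is the upper 16 bits of an IEEE float32, so widening just
    appends 16 zero bits; no rounding is involved.
    """
    return struct.unpack(">f", struct.pack(">I", bits16 << 16))[0]


def float32_to_bfloat16_bits(value):
    """Truncate a float32 value to its upper 16 bits (round toward zero)."""
    return struct.unpack(">I", struct.pack(">f", value))[0] >> 16


one = bfloat16_bits_to_float32(0x3F80)          # bit pattern of 1.0
round_trip = float32_to_bfloat16_bits(one)      # back to the same 16 bits
```

The cost of widening is memory: the float32 copy takes twice the space of the bfloat16 checkpoint.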
Hello,
Thank you for the amazing work!
I had a question regarding the latest commit. If the following dot product was added to the ssm function, is there a specific reason why it was not added to the ssm_step function as well?
Moreover, the softplus in the ssm function uses the bias of dt_proj, while ssm_step uses the previous implementation, delta = F.softplus(self.dt_proj(delta)).
Lastly, should D here have a _no_weight_decay attribute similar to A_log?
Edit: If the new modifications to delta were intentionally left out of ssm_step, does that mean that during inference I have to use step and cannot use forward as well?
Edit2: Suppose that in the forward function I do a bidirectional forward (as in Vision Mamba), such as:
    output = self.mamba_block(self.norm(x))
    x_flip = x.flip([1])
    output_flip = self.mamba_block(self.norm(x_flip))
    output += output_flip
    return output + x
In the step, however, each directional forward will return its own cache, and I am not sure how to handle that, as unfortunately I do not fully understand the cache mechanism (h, input).
Apologies for the long question.
Thank you!
Support for batch size > 1, how is the progress of the work?
Hello, and thanks for sharing your code. I stumbled on your repo while looking for how to implement mup for Mamba. It seems like you implemented mup without scaling any attn-like matrices. Does that mean that mup works with Mamba out of the box as long as the right initializations (from the mup package, for example) are implemented?
Thanks for your help.
Hi,
Thanks for the repo! This is really useful!
Hey! Awesome work on this project! I know it's not technically vanilla Mamba but I've been trying to convert the new SSM-Transformers Jamba into MLX for more efficient training and usability but am having a difficult time. My specialty is in the training/datasets world and not the strongest in the core math behind the model architectures beyond the basic implementations.
Would somebody know of an easier way to get Jamba converted into MLX? I truly think Jamba has A LOT to offer and could do some awesome stuff in the MLX format and for local model training with Mac
I've provided the modeling script released by AI21 for quick reference. Is this feasible or just way too complicated at the moment?
I'm sorry, I'm new to this, and I can't get the model running smoothly.
When I run 'example_llm.ipynb', I get an error like this
and then, when I copied the same code to a .py file, I got an error like
The same problem has come up in 'mamba-minimal', and I don't know how to solve it. Any help would be appreciated, thank you!
I am looking at the file pscan.py, and see:
    class PScan(torch.autograd.Function):
        @staticmethod
        def pscan(A, X):
            # A : (B, D, L, N)
            # X : (B, D, L, N)
            # modifies X in place by doing a parallel scan.
            # more formally, X will be populated by these values :
            # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
            # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
            # only supports L that is a power of two (mainly for a clearer code)
            B, D, L, _ = A.size()
            num_steps = int(math.log2(L))
            # up sweep (last 2 steps unfolded)
            Aa = A
            Xa = X
            for _ in range(num_steps-2):
                T = Xa.size(2)
                Aa = Aa.view(B, D, T//2, 2, -1)
                Xa = Xa.view(B, D, T//2, 2, -1)
                Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
                Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
                Aa = Aa[:, :, :, 1]
                Xa = Xa[:, :, :, 1]
            ...
This function is not using MLX. Where exactly is the parallelism? I must not understand because the number of operations is O(T) and is fully sequential. I am sure I am missing something simple.
Gordon
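A plain-Python sketch of the up-sweep may make the structure easier to see. The `for` loop in the quoted code runs about log2(L) times, and inside each iteration every pair is updated by a single tensor operation with no dependency between pairs. The total work is O(L), but the number of sequential steps is O(log L). The function below is a hypothetical illustration that mirrors only the up-sweep, so it returns just the final state; the down-sweep that fills in the other positions is omitted.

```python
def up_sweep_last_state(A, X):
    """Return (H[L-1], number of sequential rounds) for a power-of-two L.

    Each round combines neighbouring pairs. The list comprehensions have no
    dependency between pairs, which is what a tensor op executes at once.
    """
    a, x = list(A), list(X)
    rounds = 0
    while len(x) > 1:
        pairs = range(0, len(x), 2)
        # Right element of each pair absorbs the left one (independent per pair).
        x = [a[i + 1] * x[i] + x[i + 1] for i in pairs]
        a = [a[i + 1] * a[i] for i in pairs]
        rounds += 1
    return x[0], rounds


A = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]
X = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
last, rounds = up_sweep_last_state(A, X)
```

For L = 8 the loop body runs three times, where a sequential scan would need eight dependent steps. On a GPU (or any backend that executes the per-pair update as one vectorized operation) that difference in sequential depth is where the speedup comes from.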
Thank you.
state-spaces/mamba is not working due to CUDA required.
But this one you created is working successfully. Thank you!!! 😄
Hi,
The value of A is very large after discretization.
deltaA = torch.exp(delta.unsqueeze(-1) * A)
The big value makes the loss NaN.
I also found a similar problem in the original mamba repo, but I can't find the solution.
I have tried the ZOH discretization to avoid the exp function, but the problem still exists.
Do you know how to solve it?
Thank you.
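One thing worth checking, sketched below with scalars: if A is parameterised as -exp(A_log) and delta comes out of a softplus (the parameterisation described in the Mamba paper; whether the code in question preserves it is an assumption), then delta * A is never positive and exp(delta * A) stays within [0, 1]. A large deltaA would then suggest that one of these sign constraints has been lost. The function names are hypothetical.

```python
import math


def softplus(z):
    """Numerically stable softplus: log(1 + exp(z)), never negative."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def delta_a(dt_raw, a_log):
    """exp(delta * A) with delta = softplus(dt_raw) >= 0 and A = -exp(a_log) < 0."""
    delta = softplus(dt_raw)
    A = -math.exp(a_log)
    return math.exp(delta * A)


values = [delta_a(dt, al) for dt in (-5.0, 0.0, 5.0, 50.0) for al in (-2.0, 0.0, 3.0)]
```

If the values stay bounded like this and the loss is still NaN, the cause is probably elsewhere, for example in the inputs or the learning rate, which this sketch does not examine.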
Here is a section of code in JambaLM
    class Jamba(nn.Module):
        def __init__(self, config: JambaLMConfig):
            super().__init__()
            self.config = config
            # init each model layer, decide if it's mamba/attention and has experts or not
            decoder_layers = []
            for i in range(config.n_layers):
                is_attn = (
                    True
                    if (i - self.config.attn_layer_offset) % self.config.attn_layer_period
                    == 0
                    else False
                )
                is_expert = (
                    True
                    if (i - self.config.expert_layer_offset)
                    % self.config.expert_layer_period
                    == 0
                    else False
                )
You'll notice that the structure of is_attn and is_expert is identical. Furthermore, in the default configuration provided, is_attn and is_expert are always equal: both False, or both True at the same time. As a result, the layers in this default Jamba architecture never combine attention and experts independently. Of course I can change that, but this is surely not intended, given that this code is didactic. Thanks.
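The effect of the offsets and periods can be checked outside the model with a small helper. This is a sketch: the function name is hypothetical, and the numeric values are illustrative, not the repository's defaults.

```python
def layer_kinds(n_layers, attn_offset, attn_period, expert_offset, expert_period):
    """Return one (is_attn, is_expert) pair per layer, using the rule quoted above."""
    kinds = []
    for i in range(n_layers):
        is_attn = (i - attn_offset) % attn_period == 0
        is_expert = (i - expert_offset) % expert_period == 0
        kinds.append((is_attn, is_expert))
    return kinds


# Hypothetical values, not the repository's defaults.
same = layer_kinds(8, attn_offset=0, attn_period=4, expert_offset=0, expert_period=4)
mixed = layer_kinds(8, attn_offset=4, attn_period=8, expert_offset=1, expert_period=2)
```

When both offsets and both periods are equal, as in `same`, the two flags coincide on every layer. With different values, as in `mixed`, layer 4 is attention without experts and the odd layers have experts without attention.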
Hello, does this project have similar training functions to llama2.c?
Please rectify the path
from example_src.tinyhome import TinyHomeEngineV1, print_grid, print_act
from example_src.buffer import ReplayBuffer
to
from examples.tinyhome import TinyHomeEngineV1, print_grid, print_act
from examples.buffer import ReplayBuffer
Thank you for providing such excellent code. I have a question: how is the cache of the step function used in the following code block? What does it do? Can you give an example of how it works? Thank you.
    class Mamba2(nn.Module):
        def __init__(self, config: Mamba2Config):
            super().__init__()
            self.config = config
            self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])

        def forward(self, x):
            # x : (B, L, D)
            # y : (B, L, D)
            for layer in self.layers:
                x = layer(x)
            return x

        def step(self, x, caches):
            # x : (B, L, D)
            # caches : [cache(layer) for all layers], cache : (h, inputs)
            # y : (B, L, D)
            # caches : [cache(layer) for all layers], cache : (h, inputs)
            for i, layer in enumerate(self.layers):
                x, caches[i] = layer.step(x, caches[i])
            return x, caches
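A toy version of the idea, in plain Python with one channel, may help. It assumes, from the comments in the quoted code, that each cache is a pair (h, inputs), and it reads h as the recurrent state and inputs as the most recent inputs needed by the short causal convolution. That reading, the constants, and the function names are all assumptions for illustration, not the repository's exact layout.

```python
D_CONV = 3
W = [0.5, 0.3, 0.2]   # hypothetical causal conv weights, oldest input first
A = 0.9               # hypothetical fixed decay (real Mamba makes this input-dependent)


def forward(xs):
    """Process a whole sequence: causal conv over the inputs, then the recurrence."""
    padded = [0.0] * (D_CONV - 1) + list(xs)
    h, ys = 0.0, []
    for t in range(len(xs)):
        u = sum(w * v for w, v in zip(W, padded[t:t + D_CONV]))
        h = A * h + u
        ys.append(h)
    return ys


def step(x, cache):
    """Process one token. cache = (h, inputs): state plus the last D_CONV-1 inputs."""
    h, inputs = cache
    window = inputs + [x]
    u = sum(w * v for w, v in zip(W, window))
    h = A * h + u
    return h, (h, window[1:])   # drop the oldest input, keep the newest


cache = (0.0, [0.0] * (D_CONV - 1))   # empty state, zero-padded input history
xs = [1.0, -2.0, 0.5, 3.0]
stepped = []
for x in xs:
    y, cache = step(x, cache)
    stepped.append(y)
```

Feeding the tokens one at a time through `step` reproduces what `forward` computes over the whole sequence, because the cache carries everything the next token needs: the state and the tail of the convolution window. In the quoted model there is one such cache per layer, which is why `caches` is a list.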
Thank you for sharing your fantastic work.
We noticed the figure showing that Mamba's running time does not increase as the dimension d_state increases.
However, we found the following in the code (selective_scan_fwd_kernel.cuh#163):
    for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
        ...
        if constexpr (kIsVariableB) {
            load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
        }
    }
which shows a for loop over state_idx that reads from HBM into shared memory.
I then tested the speed again and found that as d_state increases, Mamba's running time increases linearly, which is consistent with the code.
    import time

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # assumed source of selective_scan_fn

    device = torch.device("cuda")
    dtype = torch.float32
    B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
    xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Ai = torch.randn((G * D, N), device=device, dtype=dtype)
    Di = torch.randn((G * D), device=device, dtype=dtype)
    dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
    Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
    tpb = torch.randn((G * D), device=device, dtype=dtype)
    # Same shapes with a state dimension four times larger
    Ai2 = torch.randn((G * D, 4*N), device=device, dtype=dtype)
    Bi2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)
    Ci2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)
    tim0 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim1 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim2 = time.time()
    print(tim1-tim0, tim2-tim1, torch.cuda.max_memory_allocated())  # 0.7172577381134033 2.400775194168091 185063424
    time.sleep(100000)
So what did I miss?
I found an unnecessary check. In the get_data() function of mamba_lm, you clearly ensure that there are no partial batches.
In the training loop (below the get_batch() function), you check whether there are partial batches, and skip the batch if there is one. This check is unnecessary.
Here is the relevant code. First, get_batch():
    def get_batch(
        data: Float[T, " B Examples"], seq_len: int, idx: int
    ) -> tuple[Float[T, "B SeqLen"], Float[T, "B SeqLen"]]:
        """Retrieve a single batch"""
        src = data[:, idx : idx + seq_len]  # noqa: E203
        target = data[:, idx + 1 : idx + seq_len + 1]  # noqa: E203
        return src, target
where the batch size is always the same. There are no partial batches. Below is the extraneous check:
    # If the batch is not complete - skip
    ### The batch is always complete
    if logits.view(-1, logits.size(-1)).shape[0] != output.view(-1).shape[0]:  # <<<< UNNECESSARY
        print("skip")
    else:
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), output)
        avg_loss += loss.item()
        optim.zero_grad()
        loss.backward()
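One case may be worth ruling out before removing the check. The batch dimension is indeed constant, but the sequence dimension of the target can come out shorter than the source when the window touches the end of the data, because slicing truncates silently. The sketch below reproduces the two slices with a plain list standing in for one row of the tensor; whether the training loop ever produces such an idx depends on how the indices are generated, which is not shown here.

```python
def get_batch(data, seq_len, idx):
    """List-based stand-in for the quoted get_batch (one row instead of a batch)."""
    src = data[idx:idx + seq_len]
    target = data[idx + 1:idx + seq_len + 1]
    return src, target


data = list(range(10))

src_a, target_a = get_batch(data, seq_len=4, idx=0)   # fully inside the data
src_b, target_b = get_batch(data, seq_len=4, idx=6)   # window touches the end
```

In the second call the source has four elements and the target only three, which is exactly the shape mismatch the check guards against. If the loop stops at least one position earlier, the check never fires and can be removed.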
Thanks a lot for your work and minimal implementation!
For work, I need to implement some models to benchmark, and I really want to include mamba-related models.
To do so, I created Jimmy (for Jax Image Model :) really not prod-ready at all yet) https://github.com/clementpoiret/jimmy
But porting the CUDA code into something that can be compiled by XLA is just beyond what I can do rn.
With credits of course, may I port your pscan to Jimmy?
Thanks!
Hi Alex,
I tried to generate an ONNX file for inference by calling the ONNX export inside the generate function here: https://github.com/alxndrTL/mamba.py/blob/main/mamba_lm.py#L134
as follows:
    torch.onnx.export(model, inputids, "mamba.onnx", opset_version=12)
It throws the following error:
    TypeError: MambaLM.forward() missing 1 required positional argument: 'tokens'.
Any idea how I can generate the ONNX file? Is there a better way of generating an ONNX file for inference?