
mamba-minimal's Introduction

mamba-minimal

Simple, minimal implementation of Mamba in one file of PyTorch.

Featuring:

  • Equivalent numerical output to the official implementation for both the forward and backward pass
  • Simplified, readable, annotated code

Does NOT include:

  • Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most implementations simple for readability.
  • Proper parameter initialization (though this could be added without sacrificing readability)

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

generate(model, tokenizer, 'Mamba is the')

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
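
The generate helper used above is defined in demo.ipynb. As a rough, hedged sketch of what such a loop looks like (greedy decoding here for brevity, whereas the notebook samples; the function name is illustrative, and it assumes the model returns logits of shape (batch, seq_len, vocab_size)):

import torch

def generate_greedy(model, tokenizer, prompt, n_tokens=50):
    # Encode the prompt, then append one greedily chosen token at a time.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    for _ in range(n_tokens):
        with torch.no_grad():
            logits = model(input_ids)                      # (batch, seq_len, vocab_size)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return tokenizer.decode(input_ids[0])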

References

The Mamba architecture was introduced in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.

The official implementation is here: https://github.com/state-spaces/mamba/tree/main

mamba-minimal's People

Contributors

eltociear, johnma2006


mamba-minimal's Issues

About selective_scan

Hi, great work!
Could you please explain why selective_scan initializes x with "x = torch.zeros((b, d_in, n), device=deltaA.device)"?
In addition, I am confused about the roles of u and x.

Thanks.
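
Not the author, but a simplified sketch of the sequential scan may help (names mirror model.py, though the code below is illustrative): u is the input sequence entering the SSM, while x is the SSM hidden state, which starts at zero and is updated once per timestep.

import torch

def selective_scan_sketch(u, deltaA, deltaB_u, C):
    # u:        (b, l, d_in)    input sequence entering the SSM
    # deltaA:   (b, l, d_in, n) discretized state transition exp(delta * A)
    # deltaB_u: (b, l, d_in, n) discretized input term delta * B * u
    # C:        (b, l, n)       output matrix
    b, l, d_in = u.shape
    x = torch.zeros((b, d_in, deltaA.shape[-1]), device=deltaA.device)  # state, zero before the scan
    ys = []
    for i in range(l):
        x = deltaA[:, i] * x + deltaB_u[:, i]              # x_t = A_bar * x_{t-1} + B_bar * u_t
        ys.append((x * C[:, i, None, :]).sum(-1))          # y_t = C_t * x_t, shape (b, d_in)
    return torch.stack(ys, dim=1)                          # (b, l, d_in)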

about causal conv

I note that the convolution used in the original paper is a causal convolution, but I don't see an implementation of causal convolution in this project.
An ordinary grouped convolution is used in model.py; I wonder if this is correct.
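
Not the author, but as I read model.py the causality comes from padding plus slicing rather than a dedicated causal-conv layer: the grouped (depthwise) Conv1d is created with padding = d_conv - 1, and its output is trimmed back to the original length, so position t never sees inputs beyond t. A small hedged sketch of that pattern:

import torch
import torch.nn as nn

d_inner, d_conv, l = 8, 4, 16
x = torch.randn(2, d_inner, l)                 # (batch, channels, seq_len)
conv = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                 groups=d_inner, padding=d_conv - 1)
y = conv(x)[..., :l]                           # trimming the right side makes it causal
assert y.shape == x.shape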

Missing LICENSE

Nice work! I'm considering porting this to MLX, but I wonder if I can copy parts of your code. Could you please add a license to this repo?

Thanks

Memory issue due to A and B matrix computation

Hi,
Thanks for providing the Mamba implementation. I would like to know if there is any workaround for the computation of deltaA and deltaB_u that avoids running out of GPU memory. The following are the parameters I used to create the Mamba instance:

d_model: 1024
n_layer: 4
d_state: 1024
expand: 2

The other parameters are set to their default values.

It results in a model of ~60M parameters. However, I run out of memory (24 GB of GPU memory) with a batch size of 256, and even as low as 64; this probably happens due to the large intermediate tensors for deltaA and deltaB_u.
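
Not the author, but a back-of-envelope count suggests why: deltaA and deltaB_u are materialized with shape (batch, seq_len, d_inner, d_state), which explodes when d_state = 1024 and expand = 2. A hedged estimate (the sequence length below is an assumption):

batch, seq_len = 64, 512                       # hypothetical training setup
d_inner = 2 * 1024                             # expand * d_model
d_state = 1024
elems = batch * seq_len * d_inner * d_state
print(f"{elems * 4 / 1e9:.1f} GB")             # ~274.9 GB for a single fp32 (b, l, d_in, n) tensor

Reducing d_state, the sequence length, or the batch size (or chunking the scan over the sequence) seems to be the only workaround within this implementation.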

Discretization of `B`

Thanks for the clear implementation!

Can you explain the discretization of $B$ in selective scan?

Equation 4 in section 2 of the paper states $$\overline{B} = (\Delta A)^{-1} (\exp(\Delta A) - I) \cdot \Delta B$$

In your implementation, the input is mapped into the hidden state by the following:

 deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n')

which if I understand correctly, implies that $\overline{B} = \Delta B$?
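
Not the author, but my understanding (worth double-checking against the official code) is that this is intentional: the implementation uses a simplified Euler discretization for $B$ rather than full ZOH, which is the first-order approximation of the ZOH formula for small $\Delta A$:

$$\overline{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \approx (\Delta A)^{-1}(\Delta A) \cdot \Delta B = \Delta B,$$

using $\exp(\Delta A) \approx I + \Delta A$. Only $A$ keeps the exact ZOH discretization $\overline{A} = \exp(\Delta A)$.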

Network Breakdown

Sorry, I can't connect to Hugging Face directly.

Could I load the pretrained weights locally instead? (If someone could teach me how, please!)
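
Not the author, but here is a hedged sketch of one way to do this, assuming you have manually downloaded config.json and pytorch_model.bin for the checkpoint and that model.py exposes Mamba and ModelArgs as in this repo (paths and key names below may need adjusting):

import json
import torch
from model import Mamba, ModelArgs

with open('mamba-370m/config.json') as f:       # placeholder local path
    config = json.load(f)

args = ModelArgs(d_model=config['d_model'],
                 n_layer=config['n_layer'],
                 vocab_size=config['vocab_size'])
model = Mamba(args)

state_dict = torch.load('mamba-370m/pytorch_model.bin', map_location='cpu')
# The official checkpoints prefix keys with 'backbone.'; strip it if present.
model.load_state_dict({k.replace('backbone.', ''): v for k, v in state_dict.items()})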

General Question: Why does self.in_proj have an expansion again?

I understand that d_inner is d_model * expansion (E=2). But why is self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, ...)?

Why is the in projection expanded by a further factor of 2?

I can't seem to find the answer in the referenced section 3.4 of the paper.

Any clarification would be appreciated.
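
Not the author, but as I read the block, the extra factor of 2 is not a second width expansion: in_proj produces two d_inner-sized tensors in one matmul, one feeding the conv/SSM path and one feeding the gating ("res") path, and they are split apart immediately. A hedged toy sketch of the shapes:

import torch
import torch.nn as nn

d_model, expand = 8, 2
d_inner = expand * d_model
in_proj = nn.Linear(d_model, d_inner * 2, bias=False)    # one matmul, two branches

x = torch.randn(1, 5, d_model)                           # (batch, seq_len, d_model)
x_branch, res = in_proj(x).split([d_inner, d_inner], dim=-1)
# x_branch -> conv1d -> SiLU -> selective SSM
# res      -> SiLU, then multiplies (gates) the SSM output before out_proj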

Setting MPS backend causes mamba-minimal to generate gibberish and crash on M1 Mac

Trying to run it on the MPS backend on a Mac M1 Max (64 GB), modifying demo.ipynb like this:


Changes to notebook:

device = 'mps'

model = Mamba.from_pretrained(pretrained_model_name).to(device)

input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)


Getting gibberish and a dead Python kernel within a dozen tokens:

for t in generate(model, tokenizer, 'Mamba is the'): print(t, end='')

ocardial goal fibrobl ( triglycer ા judgement extraordinacebook CURIAM

and then it crashes.
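
Not a fix, but one hedged way to confirm it is an MPS numerics problem rather than a weight-loading problem is to compare a single forward pass against the CPU (continuing the notebook variables above):

import torch

with torch.no_grad():
    logits_cpu = model.to('cpu')(input_ids.to('cpu'))
    logits_mps = model.to('mps')(input_ids.to('mps')).to('cpu')
print((logits_cpu - logits_mps).abs().max())    # a large difference points at MPS numerics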

Parameter Initialization

Hey! Thanks so much for putting this together. You mention that the parameter initialization isn't correct; can you point me towards what it should be? I've tried looking at the official repo and didn't see anything out of the ordinary.
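
Not the author, but from my reading of the official repo the main missing pieces are the special dt_proj initialization (the bias is set so that softplus(bias) lands in [dt_min, dt_max]) and a depth-dependent rescaling of the output projections. A hedged sketch of the dt part, paraphrased from state-spaces/mamba (double-check it there before relying on it):

import math
import torch

def init_dt_bias(dt_proj, d_inner, dt_min=1e-3, dt_max=1e-1, dt_init_floor=1e-4):
    # Sample dt log-uniformly in [dt_min, dt_max], floored at dt_init_floor ...
    dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
                   + math.log(dt_min)).clamp(min=dt_init_floor)
    # ... then invert softplus so that softplus(bias) == dt at initialization.
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        dt_proj.bias.copy_(inv_dt)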

version

I have tried several versions of Transformers, but none of them work. May I ask what versions of Torch and Transformers the author has installed?

Is it recurrent?

Nice implementation!

I thought that Mamba was somewhat recurrent, keeping an internal state and outputting one token at a time. But your code shows that for each new output token, the entire sequence must be passed to the model, including the last generated token.

Is this the only mode of operation?

It also has the l parameter (sequence length). Does that mean it has a maximum sequence length? The paper shows up to 1 million tokens, so I was expecting it to be recurrent and without a limit.

From the paper:

Table 2: (Induction Heads.) Models are trained on sequence length 2^8 = 256, and tested on increasing sequence lengths of 2^6 = 64 up to 2^20 = 1048576

It is not clear how they do that
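
Not the author, but my understanding is that the architecture itself is recurrent: the official repo carries a per-layer state (plus a small buffer for the causal conv) and emits one token at a time, while this repo recomputes the full sequence each step purely for readability; l is just the length of the current input, not a hard maximum. A hedged sketch of the per-token SSM update such a recurrent mode would use (names illustrative, conv buffer omitted):

import torch

def ssm_step(x_state, deltaA_t, deltaB_u_t, C_t):
    # x_state:              (b, d_in, n) hidden state carried across tokens
    # deltaA_t, deltaB_u_t: (b, d_in, n) discretized terms for the current token
    # C_t:                  (b, n)       output matrix for the current token
    x_state = deltaA_t * x_state + deltaB_u_t    # x_t = A_bar * x_{t-1} + B_bar * u_t
    y_t = (x_state * C_t[:, None, :]).sum(-1)    # y_t = C_t * x_t, shape (b, d_in)
    return x_state, y_t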

Why is the implementation of Mamba so slow?

I haven't run the official version of Mamba, but I've run your implementation, and its training speed seems much slower than that of a Transformer.

Using cumsum instead of a for loop

There is a way to perform the selective scan with two cumulative sums (torch.cumsum), which is effectively a parallel scan but supported natively by PyTorch.

I made a minimal commit in my fork here: PeaBrane@2908f50. The correctness and functionality are tested, and I observed an inference speed-up of ~14x on an A30, but I'm not sure how close it is to the original implementation with its parallel scan. More details are here.

If interested, it would be nice if someone could review this change and discuss whether it could be merged here, albeit the explicitness of the code may suffer (as I understand it, the repo is meant to be pedagogical).
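
For readers who want the gist: the linear recurrence x_t = a_t * x_{t-1} + b_t has a closed form x_t = P_t * sum_{s<=t} (b_s / P_s) with P_t = prod_{s<=t} a_s, which two cumulative sums in log space can evaluate in parallel. A hedged, numerically naive sketch of that idea (not PeaBrane's actual commit, which presumably handles under/overflow of the cumulative product more carefully):

import torch

def scan_with_cumsum(deltaA, deltaB_u):
    # deltaA, deltaB_u: (b, l, d_in, n); the scan runs over dim=1 (the sequence).
    log_a_cum = torch.cumsum(torch.log(deltaA), dim=1)                  # log P_t
    x = torch.exp(log_a_cum) * torch.cumsum(deltaB_u * torch.exp(-log_a_cum), dim=1)
    return x                                                            # states for all t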

Is use of LogA intentional?

Nice implementation! I noticed a small detour below and wonder if it's necessary:
In this line you define self.A_log = nn.Parameter(torch.log(A)), which is only used here, A = -torch.exp(self.A_log.float()), to exp it back. What's the reason for not defining A directly as the class parameter?

Does mamba-minimal have GPU op optimization?

Hi, thanks for the great work! I would like to know whether mamba-minimal has GPU op optimization; it seems to me that it doesn't. I want to train a large-scale Mamba and am currently deciding between the original mamba and mamba-minimal.
