
mamba's People

Contributors

albertfgu, deroholic, eltociear, epicfilemcnulty, gaxler, haileyschoelkopf, harboryuan, iamshubhamgupto, jason19970210, jmercat, luislechugaruiz, tridao, wongboo, yair-schiff


mamba's Issues

Finetuning example?

I see a good amount of focus on understanding how to perform full training of Mamba, but what about PEFT? Adapters/LoRA fine-tuning.

The base models are in fact "ready" for a fine-tune; however, given the architecture, I would prefer to minimize trial and error and ask: what can I use to fine-tune a pretrained Mamba model?

Do you have a minimal code example?
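To make the request concrete, here is the kind of sketch I'm imagining (assuming the Hugging Face peft library works here at all; the target_modules names are my guesses at Mamba's linear projections, not an official recommendation):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from peft import LoraConfig, get_peft_model

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float32)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["in_proj", "x_proj", "out_proj"],  # assumed projection names
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters should be trainable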

Query Regarding Mamba Model Performance Tuning

Dear Mamba Contributors,

I hope this message finds you well. I am in the process of utilising the Mamba state space architecture for a language modelling task and have been highly impressed with the innovative approach adopted in this project.

However, whilst implementing the pretrained Mamba models provided, I have observed anomalous behaviour concerning model stability during training. Despite following the recommended guidelines and ensuring that my system aligns with the specified prerequisites (i.e. PyTorch 1.12+, CUDA 11.6+, Linux-based system with NVIDIA GPU), the SSMs appear to be sensitive to the recurrent dynamics, leading to unpredicted fluctuations.

It is indicated within the 'Troubleshooting' section that mixed precision via PyTorch's AMP maintains model parameters in float32 and casts to half precision when necessary. Nevertheless, might there be an alternative approach, or further suggestions you could extend, to enhance the model's stability? It would be most beneficial if there were additional insight into configurations that might mitigate the aforementioned stability issues.
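For reference, the pattern I am currently following, as I understand that guidance (a minimal sketch of a standard PyTorch AMP step with float32 master parameters; model, batch, optimizer, and compute_loss are placeholders from my own code):

import torch

# Parameters stay in float32; autocast casts individual ops to bfloat16.
assert all(p.dtype == torch.float32 for p in model.parameters())
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = compute_loss(model, batch)  # placeholder objective
loss.backward()                        # gradients accumulate in float32
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad(set_to_none=True)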

Moreover, I am curious to enquire if there are any plans for updated releases or patches that could possibly offer improved robustness or address the described concerns.

I would like to express my gratitude for making such a groundbreaking model accessible to the public and for your commitment to advancing the field of machine learning. I look forward to your guidance on the matter and any subsequent versions of Mamba that may further polish its performance.

Thank you for your time and assistance.

Warm regards,

yihong1120

Wikitext pipeline

Hi, can you please share the pipeline for the WikiText dataset? I found results of 16.3 perplexity for Mamba and 18 for the transformer baseline (vs. 18.6 everywhere else) and cannot reproduce them. Maybe there is something different in the preprocessing, etc. Could you provide any details on the preprocessing steps or hyperparameters that may differ from the defaults? Understanding those differences would help me reproduce the results.

A poor man's Mamba code

I get a very bad training loss with a simple Mamba implementation. May I know why?

100%|██████████| 10/10 [03:03<00:00, 18.32s/it]

Epoch: 10, Training Loss: -6515147.2516, Validation Loss: -7471518.3141, Validation Perplexity: 0.0000
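For reference, my understanding is that a proper cross-entropy is non-negative and perplexity is exp(loss), never 0, so something in my objective must be off. The standard formulation I'm comparing against (a generic sketch; logits and targets are placeholders, not from my code above):

import torch.nn.functional as F

# logits: (batch, seqlen, vocab); targets: (batch, seqlen) integer token ids
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
perplexity = loss.exp()  # >= 1 for any proper cross-entropy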

using mamba as a headless "vanilla RNN"

Thank you for your amazing work,

I'm trying to use Mamba as a drop-in replacement for an RNN in an encoder/decoder architecture; for that, I've turned off the logits head. How should I properly get the hidden state during inference? I'm a bit lost in the decoding logic.
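For context, what I've pieced together so far from generation.py (a sketch based on my reading; the InferenceParams usage and the tokens tensor are my assumptions, so please correct me):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.generation import InferenceParams

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m").to("cuda").eval()
inference_params = InferenceParams(max_seqlen=512, max_batch_size=1)
with torch.no_grad():
    for t in range(tokens.shape[1]):  # tokens: (1, seqlen) assumed
        hidden = model.backbone(tokens[:, t : t + 1], inference_params=inference_params)
        inference_params.seqlen_offset += 1
# per-layer recurrent (conv/ssm) states live in inference_params.key_value_memory_dict;
# hidden is the last hidden state, with the logits head bypassed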

Compared to self-attention, does Mamba not need multiple heads per layer? Or positional embeddings?

Hi,

Looking through the MambaLMHeadModel I don't see any corresponding notion of the 'head' that exists with MultiHead Attention. I'm curious what is the status of this in the Mamba world?

In MHA, I remember the transformer circuits blog article explained that each head reads its keys and queries from a particular lower-dimensional subspace (mediated by W_k and W_q) of the previous output representation, and writes to a particular subspace (mediated by W_v).

Also, I didn't see any mention in the paper or the code of the use of positional embeddings - is this notion obsolete? At least with transformers, it seems that positional embeddings are essential for the lower head in the Induction Heads to be able to move information forward by one position in a content-agnostic way. How does Mamba achieve this if it doesn't use positional embeddings? Or, does it use a different mechanism to solve the in-context learning task?

Best,

Henry

Problems installing the packages

I am using an aws ec2 instance

Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04

When I tried to run the command pip install causal-conv1d, I got an error message saying a module called 'packaging' is missing.

The entire error message can be found in the attached screenshot (2023-12-07).

Then I tried to install from the repository, but got the same error.
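For what it's worth, my current workaround (an assumption about the cause: the build script imports packaging before pip has installed it) is to install the build prerequisites up front:

pip install packaging ninja
pip install causal-conv1d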

Error running benchmark script: Expected B.stride(-1) == 1 to be true, but got false

I get this error when running the benchmark script:

File ~/personal/transformer-experiments/env/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:37, in SelectiveScanFn.forward(ctx, u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
     35     C = rearrange(C, "b dstate l -> b 1 dstate l")
     36     ctx.squeeze_C = True
---> 37 out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
     38 ctx.delta_softplus = delta_softplus
     39 ctx.has_z = z is not None

RuntimeError: Expected B.stride(-1) == 1 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Env:

Linux, cuda 12.1

torch==2.1.0
triton==2.1.0
mamba-ssm==1.0.1
causal-conv1d==1.0.1
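In case it helps others hitting this: the message reads as the CUDA kernel requiring the last dimension of B to be contiguous in memory, so a workaround I'd try (my assumption, not a confirmed fix) is forcing contiguity before the scan:

B = B.contiguous()  # ensures B.stride(-1) == 1 before calling selective_scan_cuda.fwd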

deriving embeddings

What would be the best way to derive embeddings from mamba models? Is there a straightforward approach or would we need a new architecture?
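The approach I've been experimenting with, for what it's worth (a sketch; the mean-pooling choice is my own assumption, not an official recipe):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m").to("cuda").eval()
with torch.no_grad():
    hidden = model.backbone(input_ids)  # (batch, seqlen, d_model); input_ids assumed
embeddings = hidden.mean(dim=1)         # one fixed-size vector per sequence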

bfloat16 overflow during training session

  1. I tried a vanilla PyTorch training loop using bfloat16 and the loss overflowed: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-bf16.ipynb
  2. So I tried a vanilla PyTorch training loop using fp32, and the loss is OK: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-fp32.ipynb
  3. I thought maybe it was the lack of gradient clipping, etc., so I tried the HuggingFace trainer with DeepSpeed, and the loss overflowed: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-deepspeed-bf16.ipynb
  4. So I removed DeepSpeed and used fp32 in the HuggingFace trainer, and the loss is OK: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-fp32.ipynb

If bfloat16 is not working, DeepSpeed is not going to work.
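The next thing I plan to try (my own guess, not maintainer advice) is keeping the loss reduction in float32 even when the forward pass runs in bfloat16:

import torch
import torch.nn.functional as F

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(input_ids).logits  # model / input_ids as in the notebooks above
loss = F.cross_entropy(
    logits.float().reshape(-1, logits.size(-1)),  # upcast before the reduction
    labels.reshape(-1),
)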

selective_scan_ref deltaB_u doesn't follow Eqn 4 in paper

Hi,

It seems like the reference implementation doesn't correctly implement Equation 4, specifically deltaB.
While the computation for deltaA does seem to correspond with Eqn 4 in the paper:

deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))

Eqn 4:

$\overline{A} = \exp(\Delta A), \quad \overline{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$

The section for deltaB_u in the code is just:

    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]

Maybe I missed something more subtle but I couldn't find it.

Thanks in advance for any help you could provide. Mostly, I'm just trying to figure out the indexing, since the formula in Eqn 4 looks a lot like matmuls but it is not, so I am not quite sure what the exponentiation means.
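For concreteness, the gap as I read it, written element-wise for a single time step (everything is element-wise because A is diagonal, which I think also resolves my own confusion about the exponentiation; the shapes are my assumptions, and this is my sketch, not repo code):

import torch

# delta_t: (d, 1), A: (d, n), B_t: (1, n), for one time step t
dA = delta_t * A
A_bar = torch.exp(dA)                                   # matches deltaA in the code
B_bar_eqn4 = (torch.exp(dA) - 1) / dA * delta_t * B_t   # what Eqn 4 literally says (ZOH)
B_bar_code = delta_t * B_t                              # what selective_scan_ref uses (Euler)
# the code then multiplies by u to form deltaB_u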

Thanks very much,

Henry

Apptainer definition

Hi! Thanks for your great models. Since getting them running can be a challenge, I wrote an Apptainer definition file like so, which makes it very easy to run on HPC sites (the ones I use frequently support Singularity/Apptainer):

BootStrap: docker
From: pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel

%post
    pip install --upgrade pip
    pip install packaging matplotlib tqdm
    pip install 'causal-conv1d<=1.0.2'
    pip install mamba-ssm

save as mamba.def and run with

$ apptainer build -F mamba.sif mamba.def && apptainer run --nv ./mamba.sif python3 foobar.py

the only downside is that the resulting file is large, but I think this is unavoidable given the requirement to have the full CUDA toolchain.

I think a similar approach for Docker is straightforward

FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
RUN pip install --upgrade pip \
 && pip install packaging matplotlib tqdm \
 && pip install 'causal-conv1d<=1.0.2' \
 && pip install mamba-ssm

These might be useful to have in the repo for other users.

Minimal reference implementation

Thanks so much for providing this code; looks very useful and reproducible.

As I understand, the custom scan kernel can be quite important to performance considerations, so it is great to see it here as well.

However, as a suggestion, I think it'd be super neat to also see a minimal Mamba reference implementation with minimal dependencies, simply for clarity of exposition; something that could be unit-tested to behave the same as the custom kernel, at least on small datasets. Would that be a lot of work? Does it already exist somewhere? If a torch version exists, I'd be happy to port it to a JAX version as well.
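To make the request concrete, this is roughly the level I mean: a pure-PyTorch sequential scan (my own unoptimized sketch of the recurrence, following my reading of selective_scan_ref, and not verified against the kernel):

import torch

def selective_scan_naive(u, delta, A, B, C, D):
    # u, delta: (batch, dim, seqlen); A: (dim, dstate)
    # B, C: (batch, dstate, seqlen); D: (dim,)
    batch, dim, seqlen = u.shape
    x = u.new_zeros(batch, dim, A.shape[1])  # hidden state
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None] * A)                          # (b, d, n)
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]  # (b, d, n)
        x = dA * x + dBu                           # state recurrence
        ys.append((x * C[:, None, :, t]).sum(-1))  # readout, (b, d)
    return torch.stack(ys, dim=-1) + D[None, :, None] * u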

Import error after installing packages

Hello, after installing packages with the following script

pip install torch transformers datasets triton einops
git clone https://github.com/state-spaces/mamba.git
cd mamba && pip install causal-conv1d && pip install mamba-ssm

When the installation finished successfully, I ran the benchmark from the example to ensure everything is OK:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5

And get

ImportError: /usr/local/lib/python3.10/dist-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv

I am using nvcr.io/nvidia/pytorch:23.11-py3 docker image and a100-40gb gpu, more precisely:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 12.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   32C    P0    46W / 250W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Can you guide me on how to solve this issue?
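From other reports of undefined-symbol errors, my understanding (an assumption on my part) is that the prebuilt wheel was compiled against a different torch ABI than the one in the NGC image, so forcing a from-source build may help; the flag names below are taken from my reading of the two setup.py files, so please double-check them:

CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --no-cache-dir causal-conv1d
MAMBA_FORCE_BUILD=TRUE pip install --no-cache-dir mamba-ssm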

CUDA version

Thanks for the great work!

Is it possible to provide CUDA 11.6 support for future mamba versions?

I understand I can build from source with pip install . from this repo, but I'm worried it would be unstable.

fine_tune to a text classification task

I am trying to get mamba working for a text classification task by adding a classification head after the model.

For transformer models, people usually use the last_hidden_state as the input to the classification head; any suggestions for Mamba?
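The setup I'm experimenting with so far (a sketch under my own assumptions: last-token pooling and a plain linear head; d_model must match the checkpoint):

import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

class MambaClassifier(nn.Module):
    def __init__(self, pretrained_name, d_model, num_classes):
        super().__init__()
        self.backbone = MambaLMHeadModel.from_pretrained(pretrained_name).backbone
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, input_ids):
        hidden = self.backbone(input_ids)  # (batch, seqlen, d_model)
        return self.head(hidden[:, -1])    # last position as the sequence summary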

Minimalist version of Mamba

I'm looking for a very simple version of Mamba, and a teaching video of the implementation would be great. I know the professors are very busy, but are there any enthusiastic netizens who could do tutorials?

ImportError causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so undefined symbol

I am encountering a strange error while using Mamba with torch 2.1.1+cu118 on Linux.
In another environment, with torch 1.13.1+cu116, the same code works fine.

ImportError: /home/user1/.conda/envs/2prod/lib/python3.10/site-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda20CUDACachingAllocator12recordStreamERKNS_7DataPtrENS0_10CUDAStreamE

Package Version


absl-py 1.4.0
aiohttp 3.9.1
aiosignal 1.3.1
annotated-types 0.6.0
antlr4-python3-runtime 4.9.3
anyio 3.7.1
appdirs 1.4.4
arrow 1.2.3
async-timeout 4.0.3
attrs 23.1.0
backoff 2.2.1
beautifulsoup4 4.12.2
binaryornot 0.4.4
biopython 1.81
blessed 1.19.1
boto3 1.26.106
botocore 1.29.106
Brotli 1.1.0
brotlipy 0.7.0
build 1.0.3
CacheControl 0.13.1
cachetools 5.3.0
causal-conv1d 1.0.0
certifi 2023.11.17
cffi 1.16.0
chardet 5.1.0
charset-normalizer 3.3.2
cleo 2.1.0
click 8.1.3
cmake 3.25.0
colorama 0.4.6
contourpy 1.0.7
cookiecutter 2.1.1
crashtest 0.4.1
croniter 1.4.1
cryptography 41.0.7
cycler 0.11.0
dataclasses 0.8
datasets 2.8.0
dateutils 0.6.12
deepdiff 6.7.1
deepspeed 0.12.4
dill 0.3.6
distlib 0.3.7
docker-pycreds 0.4.0
dulwich 0.21.7
einops 0.6.0
exceptiongroup 1.2.0
fastapi 0.104.1
fastjsonschema 2.19.0
filelock 3.13.1
fonttools 4.38.0
frozenlist 1.4.0
fsspec 2023.12.1
gitdb 4.0.10
GitPython 3.1.30
gmpy2 2.1.2
google-auth 2.17.1
google-auth-oauthlib 1.0.0
GPUtil 1.4.0
grpcio 1.53.0
h11 0.14.0
hjson 3.1.0
huggingface-hub 0.16.4
hydra-core 1.3.1
idna 3.6
importlib-metadata 7.0.0
inquirer 3.1.4
installer 0.7.0
itsdangerous 2.1.2
jaraco.classes 3.3.0
jeepney 0.8.0
Jinja2 3.1.2
jinja2-time 0.2.0
jmespath 1.0.1
joblib 1.2.0
keopscore 2.1.1
keyring 24.3.0
kiwisolver 1.4.4
lightning 2.1.2
lightning-cloud 0.5.57
lightning-utilities 0.10.0
mamba-ssm 1.0.1
Markdown 3.4.3
markdown-it-py 2.1.0
MarkupSafe 2.1.3
matplotlib 3.6.3
mdurl 0.1.2
more-itertools 10.1.0
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.4
multiprocess 0.70.14
munkres 1.1.4
networkx 3.2.1
ninja 1.11.1.1
numpy 1.26.2
oauthlib 3.2.2
omegaconf 2.3.0
opt-einsum 3.3.0
ordered-set 4.1.0
orjson 3.9.10
packaging 23.2
pandas 1.5.3
pathtools 0.1.2
patsy 0.5.3
pexpect 4.8.0
Pillow 9.4.0
pip 23.3.1
pkginfo 1.9.6
platformdirs 3.11.0
poetry 1.7.1
poetry-core 1.8.1
poetry-plugin-export 1.6.0
prettytable 3.9.0
prettyTables 1.1.5
protobuf 3.20.3
psutil 5.9.4
ptyprocess 0.7.0
py-cpuinfo 9.0.0
pyahocorasick 2.0.0
pyarrow 10.0.1
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybedtools 0.9.0
pybind11 2.10.3
pycparser 2.21
pydantic 2.1.1
pydantic_core 2.4.0
Pygments 2.14.0
PyJWT 2.8.0
pykeops 2.1.1
pynvml 11.5.0
pyOpenSSL 23.3.0
pyparsing 3.0.9
pyproject_hooks 1.0.0
pysam 0.20.0
PySocks 1.7.1
python-dateutil 2.8.2
python-editor 1.0.4
python-multipart 0.0.6
python-slugify 8.0.1
pytorch-lightning 2.1.1
pytz 2022.7.1
PyYAML 6.0.1
rapidfuzz 3.5.2
readchar 4.0.5.dev0
regex 2022.10.31
requests 2.31.0
requests-oauthlib 1.3.1
requests-toolbelt 1.0.0
responses 0.18.0
rich 13.2.0
rsa 4.9
s3transfer 0.6.0
sacremoses 0.0.53
safetensors 0.3.3
scikit-learn 1.2.2
scipy 1.10.0
seaborn 0.13.0
SecretStorage 3.3.3
sentencepiece 0.1.97
sentry-sdk 1.14.0
seqeval 1.2.2
setproctitle 1.3.2
setuptools 68.2.2
shellingham 1.5.4
six 1.16.0
sklearn 0.0.post2
smmap 5.0.0
sniffio 1.3.0
soupsieve 2.5
starlette 0.27.0
starsessions 1.3.0
statsmodels 0.13.5
sympy 1.12
tensorboard 2.12.1
tensorboard-data-server 0.7.0
tensorboard-plugin-wit 1.8.1
tensorboardX 2.6
text-unidecode 1.3
threadpoolctl 3.1.0
tokenizers 0.13.3
tomli 2.0.1
tomlkit 0.12.3
torch 2.1.1+cu118
torchaudio 2.1.1+cu118
torchmetrics 1.2.1
torchvision 0.16.1+cu118
tqdm 4.66.1
traitlets 5.14.0
transformers 4.33.3
triton 2.1.0
trove-classifiers 2023.11.29
types-python-dateutil 2.8.19.14
typing_extensions 4.8.0
tzdata 2023.3
unicodedata2 15.1.0
urllib3 1.26.0
uvicorn 0.24.0.post1
virtualenv 20.25.0
wandb 0.13.9
wcwidth 0.2.12
websocket-client 1.7.0
websockets 12.0
Werkzeug 2.2.3
wheel 0.42.0
xxhash 3.2.0
yarl 1.9.3
zipp 3.17.0

I tried setting up the environment several times, and it would be great if I could use torch > 2.0.

Thanks for having a look!

Positional Encoding

Does the Mamba model need any kind of positional encoding? My understanding, based on the code and paper, is that no positional encoding is needed due to the recurrent nature. However, I tried adding positional encodings to a ViT-like model using Mamba-only blocks and accuracy improved by 3-4%; I'm still exploring exactly what is causing this.

Error in Running Benchmark Command

Greetings! Thanks for your great work! When I tried the benchmark code, I ran into the error below. Could you please share some possible solutions?

python benchmarks/benchmark_generation_mamba_simple.py --model-name "/home/x/VisionProjects/mamba/ckpts/mamba-130m" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
Loading model /home/x/VisionProjects/mamba/ckpts/mamba-130m
Number of parameters: 129135360
Traceback (most recent call last):
  File "<string>", line 21, in _layer_norm_fwd_1pass_kernel
KeyError: ('2-.-1-.-0-+-2-c-3-2-f-4-3-9-9-9-83ca8b715a9dc5f32dc1110973485f64-45375ed7aa3bacaed5f41dca33dc8ee0-6590aa19b3e9909e5c8a7254fb3b9328-e6da1445790e1250a9b68f17efc2dd18-7f2d2fed060f2e0fa46ef4e19e20c865-e1f133f98d04093da2078dfc51c36b72-056bca445a91d3175375bc8481ed1689-0db1785b8dc43452c61ef6d926ec11bb-6aff3b6e239e435b817994e60abc8cef', (torch.float16, torch.float16, torch.float16, None, None, torch.float32, None, torch.float32, 'i32', 'i32', 'i32', 'i32', 'i32', 'fp32'), (True, 1024, False, True, False), (True, True, True, (False,), (False,), True, (False,), True, (True, False), (True, False), (True, False), (True, False), (True, False), (False,)), 1, 2, False)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "benchmarks/benchmark_generation_mamba_simple.py", line 77, in <module>
    out = fn()
  File "benchmarks/benchmark_generation_mamba_simple.py", line 54, in <lambda>
    fn = lambda: model.generate(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 218, in generate
    output = decode(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 127, in decode
    model._decoding_cache = update_graph_cache(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 311, in update_graph_cache
    cache.callables[batch_size, decoding_seqlen] = capture_graph(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 345, in capture_graph
    logits = model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 77, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 77, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 65, in _bench
    return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 146, in do_bench
    fn()
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1687, in compile
    return CompiledKernel(fn, so_path, metadata, asm)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1700, in __init__
    mod = importlib.util.module_from_spec(spec)
  File "<frozen importlib._bootstrap>", line 556, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1101, in create_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /root/.triton/cache/767259c163b96d4d22c0eea24dd36494/_layer_norm_fwd_1pass_kernel.so: undefined symbol: cuLaunchKernel

The dependencies and libraries are shown below:

Package                      Version
---------------------------- ------------
absl-py                      1.4.0
accelerate                   0.21.0
aiofiles                     23.1.0
aiohttp                      3.8.4
aiosignal                    1.3.1
alembic                      1.10.3
altair                       4.2.2
anthropic                    0.3.2
anyio                        3.6.2
appdirs                      1.4.4
argon2-cffi                  23.1.0
argon2-cffi-bindings         21.2.0
arrow                        1.2.3
asttokens                    2.4.0
astunparse                   1.6.3
async-lru                    2.0.4
async-timeout                4.0.2
attrs                        22.2.0
Babel                        2.12.1
backcall                     0.2.0
beautifulsoup4               4.12.2
bert-score                   0.3.13
bleach                       6.0.0
blessed                      1.20.0
BLEURT                       0.0.2
cachetools                   5.3.1
causal-conv1d                1.0.0
certifi                      2022.12.7
cffi                         1.15.1
charset-normalizer           3.1.0
click                        8.1.3
cloudpickle                  2.2.1
cmake                        3.26.1
colorama                     0.4.6
comm                         0.1.4
contourpy                    1.0.7
cycler                       0.11.0
databricks-cli               0.17.6
datasets                     2.11.0
debugpy                      1.8.0
decorator                    5.1.1
deepspeed                    0.9.5
defusedxml                   0.7.1
dill                         0.3.6
distlib                      0.3.6
distro                       1.8.0
docker                       6.0.1
docker-pycreds               0.4.0
einops                       0.6.0
entrypoints                  0.4
exceptiongroup               1.1.3
executing                    1.2.0
fastapi                      0.95.0
fastjsonschema               2.18.0
ffmpy                        0.3.0
filelock                     3.10.7
fire                         0.5.0
flash-attn                   2.0.4
Flask                        2.2.3
flatbuffers                  23.5.26
fonttools                    4.39.3
fqdn                         1.5.1
frozenlist                   1.3.3
fsspec                       2023.10.0
gast                         0.5.4
gitdb                        4.0.10
GitPython                    3.1.31
google-auth                  2.23.1
google-auth-oauthlib         1.0.0
google-pasta                 0.2.0
gpustat                      1.1.1
gradio                       3.50.2
gradio_client                0.6.1
greenlet                     2.0.2
grpcio                       1.53.0
gunicorn                     20.1.0
h11                          0.14.0
h5py                         3.9.0
hjson                        3.1.0
httpcore                     0.17.0
httpx                        0.24.0
huggingface-hub              0.19.4
idna                         3.4
importlib-metadata           6.4.1
importlib-resources          6.1.0
ipdb                         0.13.13
ipykernel                    6.25.2
ipython                      8.15.0
ipython-genutils             0.2.0
ipywidgets                   8.1.1
isoduration                  20.11.0
itsdangerous                 2.1.2
jedi                         0.19.0
Jinja2                       3.1.2
joblib                       1.2.0
json5                        0.9.14
jsonpointer                  2.4
jsonschema                   4.19.0
jsonschema-specifications    2023.7.1
jupyter                      1.0.0
jupyter_client               8.3.1
jupyter-console              6.6.3
jupyter_core                 5.3.1
jupyter-events               0.7.0
jupyter-lsp                  2.2.0
jupyter_server               2.7.3
jupyter_server_terminals     0.4.4
jupyterlab                   4.0.6
jupyterlab-pygments          0.2.2
jupyterlab_server            2.25.0
jupyterlab-widgets           3.0.9
keras                        2.14.0
kiwisolver                   1.4.4
libclang                     16.0.6
libretranslatepy             2.1.1
linkify-it-py                2.0.0
lit                          16.0.0
llvmlite                     0.39.1
lxml                         4.9.2
Mako                         1.2.4
mamba-ssm                    1.0.1
Markdown                     3.4.3
markdown-it-py               2.2.0
markdown2                    2.4.8
MarkupSafe                   2.1.2
matplotlib                   3.7.1
matplotlib-inline            0.1.6
mdit-py-plugins              0.3.3
mdurl                        0.1.2
mistune                      3.0.1
ml-dtypes                    0.2.0
mlflow                       2.2.2
mpmath                       1.3.0
msgpack                      1.0.5
multidict                    6.0.4
multiprocess                 0.70.14
nbclient                     0.8.0
nbconvert                    7.8.0
nbformat                     5.9.2
nest-asyncio                 1.5.8
networkx                     3.1
ninja                        1.11.1
nltk                         3.8.1
notebook                     7.0.3
notebook_shim                0.2.3
numba                        0.56.4
numpy                        1.23.5
nvidia-cublas-cu11           11.10.3.66
nvidia-cuda-cupti-cu11       11.7.101
nvidia-cuda-nvrtc-cu11       11.7.99
nvidia-cuda-runtime-cu11     11.7.99
nvidia-cudnn-cu11            8.5.0.96
nvidia-cufft-cu11            10.9.0.58
nvidia-curand-cu11           10.2.10.91
nvidia-cusolver-cu11         11.4.0.1
nvidia-cusparse-cu11         11.7.4.91
nvidia-ml-py                 12.535.108
nvidia-nccl-cu11             2.14.3
nvidia-nvtx-cu11             11.7.91
oauthlib                     3.2.2
openai                       0.27.4
opt-einsum                   3.3.0
orjson                       3.8.10
overrides                    7.4.0
packaging                    23.0
pandas                       2.0.0
pandocfilters                1.5.0
parso                        0.8.3
pathtools                    0.1.2
pexpect                      4.8.0
pickleshare                  0.7.5
Pillow                       9.3.0
pip                          23.3.1
platformdirs                 3.2.0
portalocker                  2.8.2
prometheus-client            0.17.0
prompt-toolkit               3.0.39
protobuf                     4.24.3
psutil                       5.9.4
ptyprocess                   0.7.0
pure-eval                    0.2.2
py-cpuinfo                   9.0.0
pyarrow                      11.0.0
pyasn1                       0.5.0
pyasn1-modules               0.3.0
pycparser                    2.21
pydantic                     1.10.7
pydub                        0.25.1
Pygments                     2.15.0
PyJWT                        2.6.0
pyparsing                    3.0.9
pyrsistent                   0.19.3
python-dateutil              2.8.2
python-json-logger           2.0.7
python-multipart             0.0.6
pytz                         2022.7.1
PyYAML                       6.0
pyzmq                        25.1.1
qtconsole                    5.4.4
QtPy                         2.4.0
querystring-parser           1.2.4
ray                          2.3.1
referencing                  0.30.2
regex                        2023.3.23
requests                     2.31.0
requests-oauthlib            1.3.1
responses                    0.18.0
rfc3339-validator            0.1.4
rfc3986-validator            0.1.1
rouge-score                  0.1.2
rpds-py                      0.10.3
rsa                          4.9
sacrebleu                    2.3.1
safetensors                  0.3.1
scikit-learn                 1.2.2
scipy                        1.10.1
seaborn                      0.12.2
semantic-version             2.10.0
Send2Trash                   1.8.2
sentencepiece                0.1.97
sentry-sdk                   1.19.1
setproctitle                 1.3.2
setuptools                   65.6.3
shap                         0.41.0
shortuuid                    1.0.11
six                          1.16.0
slicer                       0.0.7
smmap                        5.0.0
sniffio                      1.3.0
soupsieve                    2.5
SQLAlchemy                   2.0.9
sqlparse                     0.4.3
stack-data                   0.6.2
starlette                    0.26.1
svgwrite                     1.4.3
sympy                        1.11.1
tabulate                     0.9.0
tensor-parallel              1.2.0
tensorboard                  2.14.0
tensorboard-data-server      0.7.1
tensorboardX                 2.6
tensorflow                   2.14.0
tensorflow-estimator         2.14.0
tensorflow-io-gcs-filesystem 0.34.0
termcolor                    2.2.0
terminado                    0.17.1
tf-slim                      1.1.0
threadpoolctl                3.1.0
tinycss2                     1.2.1
tokenizers                   0.15.0
tomli                        2.0.1
toolz                        0.12.0
torch                        2.0.1+cu118
torchaudio                   2.0.1+cu118
torchvision                  0.15.1+cu118
tornado                      6.3.3
tqdm                         4.65.0
traitlets                    5.10.0
transformers                 4.35.2
translate                    3.6.1
triton                       2.0.0
typing_extensions            4.5.0
tzdata                       2023.3
uc-micro-py                  1.0.1
uri-template                 1.3.0
urllib3                      2.0.5
uvicorn                      0.21.1
virtualenv                   20.21.0
wandb                        0.14.2
wavedrom                     2.0.3.post3
wcwidth                      0.2.6
webcolors                    1.13
webencodings                 0.5.1
websocket-client             1.5.1
websockets                   11.0.1
Werkzeug                     2.2.3
wheel                        0.38.4
widgetsnbextension           4.0.9
wrapt                        1.14.1
xxhash                       3.2.0
yarl                         1.8.2
zipp                         3.15.0

inference error

I found that when the prompt length is shorter than 4, the model can't run inference, because of an error at conv_state.copy_(x[:,:,-self.d_conv:]) (line 169 in mamba_simple.py).
Maybe a bug to fix? A possible workaround is sketched below.
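What I'm using locally (my assumption, not a confirmed fix): left-pad x with zeros up to d_conv before filling the cache, e.g.

import torch.nn.functional as F

# pad the time dimension of x (batch, dim, seqlen) on the left up to d_conv;
# F.pad trims when the pad width is negative, so this also covers longer prompts
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))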

Name conflicts for mamba

This looks like a cool project, but the name seems problematic because of the pre-existing package manager Mamba. The Mamba SSM installation instructions currently suggest installing with pip, but Mamba SSM could very soon be available as a conda package, and then you'll have instructions like mamba install mamba-ssm, which seems very, very confusing!

Decouple from Triton

Congratulations team! This is a seriously impressive model.

Thank you also for sharing the weights and supporting the open-source community!

FYI - I was able to successfully fine-tune the 2.8B model and completely demolish squad, showing strong reasoning abilities.

I want to see how far we can push the inference speed of this model, but the Triton dependency prevents loading it on CPU or doing an ONNX conversion.

Do you have any advice to decouple this dependency or alternative approaches?

Any help would be appreciated!

ModuleNotFoundError: No module named 'lm_eval.api'

Hi,

I tried pip install lm_eval, but that package has no api module?

Which lm_eval should I install then?

MODEL_NAME=mamba-370m
MODEL_PATH=/media/hangyu5/Home/Documents/Hugging-Face/$MODEL_NAME
python evals/lm_harness_eval.py \
    --model mamba \
    --model_args pretrained=$MODEL_PATH \
    --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande \
    --device cuda \
    --batch_size 64 | tee lm_eval.txt
Traceback (most recent call last):
  File "/home/hangyu5/Documents/Git-repoMy/AIResearchVault/repo/Architecture/mamba/evals/lm_harness_eval.py", line 8, in <module>
    from lm_eval.api.model import LM
ModuleNotFoundError: No module named 'lm_eval.api'

Thanks!
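Partially answering my own question: the api module only exists in the 0.4.x rewrite of lm-evaluation-harness (my understanding; please correct me), so something like this should work:

pip install "lm_eval>=0.4.0"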

optimal alphabet size for tokenization?

PS - is there a discord or discuss channel where this would be more appropriate?

In your paper you seem to use the same tokenizer (and hence alphabet size I assume) to compare Mamba models to Transformers. But, I wondered if you'd thought about implications of smaller or larger alphabet sizes?

Intuitively it seems like, given Mamba's ability to handle longer context, there is the possibility of a smaller alphabet. Obviously this is a tradeoff that slows down inference, but I wondered if it would provide a win in perplexity and generalization.

how to compare mamba with flashattention2

In your paper, you mentioned that the Mamba scan is faster than FlashAttention-2.
Does that mean comparing

class SelectiveScanFn(torch.autograd.Function):

with https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/flash_attn_interface.py#L432 ?
The inputs of these two modules are different; is this comparison fair? Or does the preprocessing (computing q, k, v in FlashAttention; computing A, B, C, D, delta in the Mamba scan) need to be taken into account?
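For context, this is how I would time each kernel in isolation (a generic sketch; fn stands for a zero-argument closure over either kernel call, with inputs pre-built outside the timer):

import torch

def time_cuda(fn, iters=100, warmup=10):
    for _ in range(warmup):
        fn()  # get compile/caching effects out of the way
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # mean milliseconds per call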

Training Script

It would be worthwhile to provide a training script, in order to train larger models (for instance 7B, 13B).

Token masking support

I'm going to train some Mamba variants from scratch on different tasks with both discrete and continuous input, but I'm not sure how the masking mechanism is implemented here (compared to transformer attention masking)

E.g., we may want the model not to take some tokens into account, because, as I understand it, inference would give different results depending on the number of padding tokens.
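The only workaround I can think of (purely my own assumption about approximating masking in a recurrent model, not a feature of this repo) is to zero the embeddings at padded positions so they contribute nothing through the input u, and to right-pad so pads never precede real tokens:

# mask: (batch, seqlen), 1.0 for real tokens, 0.0 for padding; embeddings assumed
hidden = embeddings * mask.unsqueeze(-1)  # zeroed inputs before the Mamba blocks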

Error during backprop when bias=True in Mamba block

Hi Albert & Tri,

Awesome work! Thank you for sharing!

I noticed a bug when using bias=True in a Mamba block. For example, using:

layer = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=512, # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to("cuda")

works perfectly fine, but if I do:

layer = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=512, # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
        bias=True,
    ).to("cuda")

I get the error:

Traceback (most recent call last):
  File "/home/dromerog/Projects/meshgen/train.py", line 79, in <module>
    main()
  File "/home/dromerog/Projects/meshgen/train.py", line 71, in main
    trainer.train_loop()
  File "/home/dromerog/Projects/meshgen/meshgen/trainers/autoregressive_mamba.py", line 94, in train_loop
    loss['total'].backward()
  File "/home/dromerog/anaconda3/envs/meshgen/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/dromerog/anaconda3/envs/meshgen/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function MambaInnerFnBackward returned an invalid gradient at index 6 - got [] but expected shape compatible with [512]

Hope this helps! :)

Best,
David

Parameter count

When counting the number of parameters of a loaded model (snippet below), I get very different results from what is written. For example, the 2.8b model shows only about 1.4b parameters, and the 130M shows as 81M. Why is there a difference?
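For reference, this is how I'm counting (a standard snippet; nothing Mamba-specific assumed):

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,}")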

RuntimeError: CUDA error: no kernel image is available for execution on the device

Is a P5200 enough for this?

Traceback (most recent call last):
  File "/home/user/mamba/simplermambassm.py", line 259, in <module>
    losses = estimate_loss()
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/mamba/simplermambassm.py", line 94, in estimate_loss
    logits, loss = model(X,Y)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/mamba/simplermambassm.py", line 210, in forward
    x = self.blocks(x) # (B,T,C_e)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/mamba/simplermambassm.py", line 188, in forward
    x = x + self.sa_head(self.ln1(x))
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 149, in forward
    out = mamba_inner_fn(
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 306, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 181, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Registering (forward/backward) hooks on intermediate Mamba layers

I'm trying to look at the internal activations of a MambaLMHeadModel (eg. the one you get when you load pretrained). If I look through the modules, I can find all the layers, and each of the modules of the Mamba block.

If I register a forward hook to "backbone.layers.0.mixer" (or any other layer) I can get an activation every forward pass. On the other hand, if I hook on "backbone.layers.0.mixer.in_proj" or "backbone.layers.0.mixer.conv1d", they don't get called during the forward pass of the model.

Is this the expected behaviour?

in-context learning over very long context - how is it possible?

Hi,

Looking at table 2 in the paper, I was astonished to see the result.

[Table 2 from the paper: induction heads extrapolation results]

As I understand it, the synthetic data task produces:

...  [S] [A] [B...] [S]

where:

[S] is a special symbol
[A] is the token to remember
[B...] are up to 1M arbitrary tokens

and the task is to predict [A] after the second occurrence of [S].

Given the finite size of hidden state, this implies Delta(x_i) should be ~ 0 for almost every token in [B...], otherwise the hidden state would "overflow" with information and drown out the original [A].

But the Delta function is just parameterized by a single N x 1 vector (plus possible bias). So it seems to me that because the Delta(*) function is content-specific, it can't really be used to truly facilitate unlimited in-context learning for any choice of [A] and intervening [B...]. Is that a correct intuition?

Thanks in advance,

Henry

Error when `fused_add_norm` is False.

llm-compute-14:   File "/home/azureuser/.conda/envs/megatron-deepspeed/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 494, in forward
llm-compute-14:     return rms_norm_fn(
llm-compute-14: TypeError: rms_norm_fn() got an unexpected keyword argument 'is_rms_norm'

The rms_norm_fn function doesn't have a parameter named is_rms_norm.
Rather, it's hard-coded, as follows.

https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layernorm.py#L478

def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)  <-- this

But, RMSNorm class in ops/triton/layernorm.py inputs is_rms_norm parameter to this function.

https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layernorm.py#L502

def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,  <-- this
        )
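A minimal patch sketch of what I assume was intended (accept the keyword instead of hard-coding it; my guess, not a confirmed fix):

def rms_norm_fn(x, weight, bias, residual=None, prenorm=False,
                residual_in_fp32=False, eps=1e-6, is_rms_norm=True):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm,
                             residual_in_fp32, is_rms_norm)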

Google Colab: AssertionError: libcuda.so cannot found!

I am using Google Colab Pro+ with a V100 GPU. I followed your example but couldn't get the output because of this error:
AssertionError: libcuda.so cannot found!
It seems that the Triton backend is causing the problem:

/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py in _layer_norm_fwd(x, weight, bias, eps, residual, out_dtype, residual_dtype, is_rms_norm)
153 # heuristics for number of warps
154 with torch.cuda.device(x.device.index):
--> 155 _layer_norm_fwd_1pass_kernel[(M,)](
156 x,
157 y,

/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py in run(self, *args, **kwargs)
98 pruned_configs = self.prune_configs(kwargs)
99 bench_start = time.time()
--> 100 timings = {config: self._bench(*args, config=config, **kwargs)
101 for config in pruned_configs}
102 bench_end = time.time()

/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py in <dictcomp>(.0)
98 pruned_configs = self.prune_configs(kwargs)
99 bench_start = time.time()
--> 100 timings = {config: self._bench(*args, config=config, **kwargs)
101 for config in pruned_configs}
102 bench_end = time.time()

/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py in _bench(self, config, *args, **meta)
81 self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
82 try:
---> 83 return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
84 except OutOfResources:
85 return [float('inf'), float('inf'), float('inf')]

/usr/local/lib/python3.10/dist-packages/triton/testing.py in do_bench(fn, warmup, rep, grad_to_none, quantiles, fast_flush, return_mode)
102 """
103
--> 104 fn()
105 torch.cuda.synchronize()
106

/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py in kernel_call()
79 config.pre_hook(full_nargs)
80 self.hook(args)
---> 81 self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
82 try:
83 return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))

in _layer_norm_fwd_1pass_kernel(X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, N, eps, IS_RMS_NORM, BLOCK_N, HAS_RESIDUAL, STORE_RESIDUAL_OUT, HAS_BIAS, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)

/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in compile(fn, **kwargs)
423 # cache manager
424 if is_cuda or is_hip:
--> 425 so_path = make_stub(name, signature, constants)
426 else:
427 so_path = _device_backend.make_launcher_stub(name, signature, constants)

/usr/local/lib/python3.10/dist-packages/triton/compiler/make_launcher.py in make_stub(name, signature, constants)
37 with open(src_path, "w") as f:
38 f.write(src)
---> 39 so = _build(name, src_path, tmpdir)
40 with open(so, "rb") as f:
41 return so_cache_manager.put(f.read(), so_name, binary=True)

/usr/local/lib/python3.10/dist-packages/triton/common/build.py in _build(name, src, srcdir)
59 hip_include_dir = os.path.join(rocm_path_dir(), "include")
60 else:
---> 61 cuda_lib_dirs = libcuda_dirs()
62 cu_include_dir = cuda_include_dir()
63 suffix = sysconfig.get_config_var('EXT_SUFFIX')

/usr/local/lib/python3.10/dist-packages/triton/common/build.py in libcuda_dirs()
28 msg += 'Possible files are located at %s.' % str(locs)
29 msg += 'Please create a symlink of libcuda.so to any of the file.'
---> 30 assert any(os.path.exists(os.path.join(path, 'libcuda.so')) for path in dirs), msg
31 return dirs
32

AssertionError: libcuda.so cannot found!

How can I solve this in the Google Colab environment?
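A workaround that has helped me with Triton on Colab before (an assumption that it applies here as well): refresh the linker cache so libcuda.so can be found, then restart the runtime:

!ldconfig /usr/lib64-nvidia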

DNA checkpoints

In your paper you say Mamba was trained on DNA (Hg38) and compared to HyenaDNA. Do you release the weights anywhere?

TypeError: expected string or bytes-like object

oracle linux 8.5
built from source
as well as installed from package
python 3.10
cuda 12.2
pytorch 2.01 for cuda 12.1

Building wheels for collected packages: mamba-ssm
Building wheel for mamba-ssm (setup.py) ... done
Created wheel for mamba-ssm: filename=mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl size=137567739 sha256=1775b610f76d6bc71ffaa72375df8c0afde52c1c14f1c788ad6afee4290adff2
Stored in directory: /tmp/pip-ephem-wheel-cache-bdzo446u/wheels/e4/00/6b/3fc67b42d0194f9c0988cf52351ff031c56e3607a95badb51c
Successfully built mamba-ssm
WARNING: Ignoring invalid distribution -lash-attn (/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages)
Installing collected packages: argparse, mamba-ssm
Attempting uninstall: mamba-ssm
Found existing installation: mamba-ssm 1.0.1
Uninstalling mamba-ssm-1.0.1:
Successfully uninstalled mamba-ssm-1.0.1
Successfully installed argparse-1.4.0 mamba-ssm-1.0.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
(textgen) [root@pve-m7330 mamba]# python3 simplermambassm.py
python3: can't open file '/home/user/mamba/mamba/simplermambassm.py': [Errno 2] No such file or directory
(textgen) [root@pve-m7330 mamba]# cd ..
(textgen) [root@pve-m7330 mamba]# python3 simplermambassm.py
Traceback (most recent call last):
  File "/home/user/mamba/simplermambassm.py", line 21, in <module>
    from mamba_ssm import Mamba
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/mamba_ssm/__init__.py", line 5, in <module>
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 12, in <module>
    from mamba_ssm.utils.generation import GenerationMixin
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 14, in <module>
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/__init__.py", line 26, in <module>
    from . import dependency_versions_check
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/dependency_versions_check.py", line 16, in <module>
    from .utils.versions import require_version, require_version_core
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/utils/__init__.py", line 31, in <module>
    from .generic import (
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/utils/generic.py", line 29, in <module>
    from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 74, in <module>
    _flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/packaging/version.py", line 54, in parse
    return Version(version)
  File "/home/user/miniconda3/envs/textgen/lib/python3.10/site-packages/packaging/version.py", line 198, in __init__
    match = self._regex.search(version)
TypeError: expected string or bytes-like object
(textgen) [root@pve-m7330 mamba]#
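Following up on my own report: reading the trace, the import dies while transformers probes the flash_attn version, and the "Ignoring invalid distribution -lash-attn" warning above suggests a half-uninstalled flash-attn whose version metadata comes back as None. My guess at a fix (not verified): remove the broken install and its leftovers before reinstalling:

pip uninstall -y flash-attn

and then delete any leftover ~lash_attn / flash_attn directories under site-packages.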

About max token length

What is the max token length that this model can support? Can it support more than 10k?

causal-conv1d installation error

Hi, I am trying to install and use Mamba, but I can't install causal-conv1d with pip; I then tried to build it from source but get the same error. Please help me.

Building wheel for causal-conv1d (setup.py) ... error
error: subprocess-exited-with-error

× python setup.py bdist_wheel did not run successfully.
│ exit code: 1
╰─> [113 lines of output]

  torch.__version__  = 1.13.1+cu117
  
  
  running bdist_wheel
  Guessing wheel URL:  https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.0.2/causal_conv1d-1.0.2+cu118torch1.13cxx11abiFALSE-cp37-cp37-linux_x86_64.whl
  Precompiled wheel not found. Building from source...
  running build
  running build_py
  creating build
  creating build/lib.linux-x86_64-cpython-37
  creating build/lib.linux-x86_64-cpython-37/causal_conv1d
  copying causal_conv1d/__init__.py -> build/lib.linux-x86_64-cpython-37/causal_conv1d
  copying causal_conv1d/causal_conv1d_interface.py -> build/lib.linux-x86_64-cpython-37/causal_conv1d
  running build_ext
  /root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py:387: UserWarning: The detected CUDA version (11.8) has a minor version mismatch with the version that was used to compile PyTorch (11.7). Most likely this shouldn't be a problem.
    warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))
  /root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py:397: UserWarning: There are no g++ version bounds defined for CUDA version 11.8
    warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
  building 'causal_conv1d_cuda' extension
  creating /tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37
  creating /tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37/csrc
  Emitting ninja build file /tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37/build.ninja...
  Compiling objects...
  Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
  ninja: error: '/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/csrc/causal_conv1d.cpp', needed by '/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37/csrc/causal_conv1d.o', missing and no known rule to make it
  Traceback (most recent call last):
    File "/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/setup.py", line 207, in run
      urllib.request.urlretrieve(wheel_url, wheel_filename)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 247, in urlretrieve
      with contextlib.closing(urlopen(url, data)) as fp:
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 222, in urlopen
      return opener.open(url, data, timeout)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 531, in open
      response = meth(req, response)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 641, in http_response
      'http', request, response, code, msg, hdrs)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 569, in error
      return self._call_chain(*args)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 503, in _call_chain
      result = func(*args)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/urllib/request.py", line 649, in http_error_default
      raise HTTPError(req.full_url, code, msg, hdrs, fp)
  urllib.error.HTTPError: HTTP Error 404: Not Found
  
  During handling of the above exception, another exception occurred:
  
  Traceback (most recent call last):
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1906, in _run_ninja_build
      env=env)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/subprocess.py", line 512, in run
      output=stdout, stderr=stderr)
  subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
  
  The above exception was the direct cause of the following exception:
  
  Traceback (most recent call last):
    File "<string>", line 36, in <module>
    File "<pip-setuptools-caller>", line 34, in <module>
    File "/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/setup.py", line 264, in <module>
      "ninja",
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/__init__.py", line 87, in setup
      return distutils.core.setup(**attrs)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/core.py", line 185, in setup
      return run_commands(dist)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
      dist.run_commands()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
      self.run_command(cmd)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/dist.py", line 1208, in run_command
      super().run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/setup.py", line 224, in run
      super().run()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/wheel/bdist_wheel.py", line 325, in run
      self.run_command("build")
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
      self.distribution.run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/dist.py", line 1208, in run_command
      super().run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build.py", line 132, in run
      self.run_command(cmd_name)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
      self.distribution.run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/dist.py", line 1208, in run_command
      super().run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/command/build_ext.py", line 84, in run
      _build_ext.run(self)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 346, in run
      self.build_extensions()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 843, in build_extensions
      build_ext.build_extensions(self)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 468, in build_extensions
      self._build_extensions_serial()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 494, in _build_extensions_serial
      self.build_extension(ext)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/command/build_ext.py", line 246, in build_extension
      _build_ext.build_extension(self, ext)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 556, in build_extension
      depends=ext.depends,
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 668, in unix_wrap_ninja_compile
      with_cuda=with_cuda)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1578, in _write_ninja_file_and_compile_objects
      error_prefix='Error compiling objects for extension')
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1916, in _run_ninja_build
      raise RuntimeError(message) from e
  RuntimeError: Error compiling objects for extension
  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for causal-conv1d
Running setup.py clean for causal-conv1d
Failed to build causal-conv1d
Installing collected packages: argparse, causal-conv1d
Running setup.py install for causal-conv1d ... error
error: subprocess-exited-with-error

× Running setup.py install for causal-conv1d did not run successfully.
│ exit code: 1
╰─> [92 lines of output]

  torch.__version__  = 1.13.1+cu117
  
  
  running install
  /root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/command/install.py:37: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
    setuptools.SetuptoolsDeprecationWarning,
  running build
  running build_py
  creating build
  creating build/lib.linux-x86_64-cpython-37
  creating build/lib.linux-x86_64-cpython-37/causal_conv1d
  copying causal_conv1d/__init__.py -> build/lib.linux-x86_64-cpython-37/causal_conv1d
  copying causal_conv1d/causal_conv1d_interface.py -> build/lib.linux-x86_64-cpython-37/causal_conv1d
  running build_ext
  /root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py:387: UserWarning: The detected CUDA version (11.8) has a minor version mismatch with the version that was used to compile PyTorch (11.7). Most likely this shouldn't be a problem.
    warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))
  /root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py:397: UserWarning: There are no g++ version bounds defined for CUDA version 11.8
    warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
  building 'causal_conv1d_cuda' extension
  creating /tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37
  creating /tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37/csrc
  Emitting ninja build file /tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37/build.ninja...
  Compiling objects...
  Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
  ninja: error: '/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/csrc/causal_conv1d.cpp', needed by '/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/build/temp.linux-x86_64-cpython-37/csrc/causal_conv1d.o', missing and no known rule to make it
  Traceback (most recent call last):
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1906, in _run_ninja_build
      env=env)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/subprocess.py", line 512, in run
      output=stdout, stderr=stderr)
  subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
  
  The above exception was the direct cause of the following exception:
  
  Traceback (most recent call last):
    File "<string>", line 36, in <module>
    File "<pip-setuptools-caller>", line 34, in <module>
    File "/tmp/pip-install-ite1402y/causal-conv1d_6898db0f5bcd4aa9b82b3dd0dca603f7/setup.py", line 264, in <module>
      "ninja",
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/__init__.py", line 87, in setup
      return distutils.core.setup(**attrs)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/core.py", line 185, in setup
      return run_commands(dist)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
      dist.run_commands()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
      self.run_command(cmd)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/dist.py", line 1208, in run_command
      super().run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/command/install.py", line 68, in run
      return orig.install.run(self)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/install.py", line 698, in run
      self.run_command('build')
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
      self.distribution.run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/dist.py", line 1208, in run_command
      super().run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build.py", line 132, in run
      self.run_command(cmd_name)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
      self.distribution.run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/dist.py", line 1208, in run_command
      super().run_command(command)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/command/build_ext.py", line 84, in run
      _build_ext.run(self)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 346, in run
      self.build_extensions()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 843, in build_extensions
      build_ext.build_extensions(self)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 468, in build_extensions
      self._build_extensions_serial()
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 494, in _build_extensions_serial
      self.build_extension(ext)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/command/build_ext.py", line 246, in build_extension
      _build_ext.build_extension(self, ext)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/setuptools/_distutils/command/build_ext.py", line 556, in build_extension
      depends=ext.depends,
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 668, in unix_wrap_ninja_compile
      with_cuda=with_cuda)
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1578, in _write_ninja_file_and_compile_objects
      error_prefix='Error compiling objects for extension')
    File "/root/anaconda3/envs/DiffGesture/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1916, in _run_ninja_build
      raise RuntimeError(message) from e
  RuntimeError: Error compiling objects for extension
  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
error: legacy-install-failure

× Encountered error while trying to install package.
╰─> causal-conv1d

note: This is an issue with the package mentioned above, not pip.
hint: See above for output from the failure.
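Two details in this log narrow the cause: the guessed prebuilt-wheel URL for the cp37 / torch 1.13 combination returns HTTP 404 (no wheel was published for that combination), and the source-build fallback then fails because ninja cannot find csrc/causal_conv1d.cpp in the unpacked package. A possible workaround, assuming network access to GitHub, is to build from the repository itself so the CUDA sources are present, e.g. pip install git+https://github.com/Dao-AILab/causal-conv1d, optionally with CAUSAL_CONV1D_FORCE_BUILD=TRUE set in the environment to skip the wheel lookup entirely. The CUDA 11.8 vs. 11.7 warning near the top is, as the message itself says, unlikely to be the problem here.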

Why isn't selective_scan implemented in Triton?

Reading Section D of the paper:

1. We read in 𝑂(𝐵𝐿𝐷 + 𝐷𝑁) bytes of memory (∆, A, B, C) from slow HBM to fast SRAM.
2. We discretize to produce A, B of size (𝐵, 𝐿, 𝐷, 𝑁) in SRAM.
3. We perform a parallel associative scan, yielding intermediate states of size (𝐵, 𝐿, 𝐷, 𝑁) in SRAM.
4. We multiply and sum with C, producing outputs of size (𝐵, 𝐿, 𝐷), and write them to HBM.

I am trying to understand why selective_scan wasn't implemented in Triton, and where the performance benefit of implementing it in CUDA C++ comes from (kernel fusion, parallel scan, and recomputation).

I ask because I'm working on my own implementation of the Mamba Mixer in JAX, and I'm wondering whether I can get away with implementing the kernels in Pallas.
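For what it's worth, step 3 is a first-order linear recurrence, h_t = dA_t * h_{t-1} + dB_t * x_t with y_t = C_t · h_t, and its combine operator is associative, which is what makes the parallel scan (and a Pallas or Triton port) possible. Below is a hedged pure-JAX reference for diagonal A; the function name and shapes are illustrative, and unlike the fused CUDA kernel it materializes the (B, L, D, N) intermediates in HBM, which per the quoted section is precisely the traffic the hand-written kernel avoids by staying in SRAM and recomputing in the backward pass:

    import jax
    import jax.numpy as jnp

    def selective_scan_reference(u, delta, A, B, C):
        # u, delta: (batch, L, D); A: (D, N); B, C: (batch, L, N)
        dA = jnp.exp(jnp.einsum('bld,dn->bldn', delta, A))    # ZOH-discretized A
        dBu = jnp.einsum('bld,bln,bld->bldn', delta, B, u)    # simplified Delta*B*u term

        # Associative combine for h_t = a_t * h_{t-1} + b_t:
        # (a1, b1) op (a2, b2) = (a2 * a1, a2 * b1 + b2)
        def combine(left, right):
            a_l, b_l = left
            a_r, b_r = right
            return a_r * a_l, a_r * b_l + b_r

        _, h = jax.lax.associative_scan(combine, (dA, dBu), axis=1)  # states (batch, L, D, N)
        return jnp.einsum('bldn,bln->bld', h, C)               # contract states with C

Calling y = selective_scan_reference(u, delta, A, B, C) reproduces the sequential recurrence up to floating-point reordering; a fused port would mainly need to keep the scan's intermediates SRAM-resident.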

causal_conv1d import error

After installing from a prebuilt wheel (causal_conv1d-1.0.2+cu122torch2.1cxx11abiTRUE-cp310-cp310-linux_x86_64.whl), importing causal_conv1d fails with errors.

Is there a mismatch between the CUDA and torch versions? Thanks for your advice.

[screenshot of the import error traceback]

torch.__version__: '2.1.0a0+4136153'
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
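One thing to rule out before suspecting the build itself (a sketch, not a definitive diagnosis): the wheel filename encodes three compatibility constraints (cu122, torch2.1, cxx11abiTRUE), and all three must match the installed PyTorch. NGC-style source builds such as 2.1.0a0 do use the new C++ ABI, but the CUDA line is easy to mismatch, and here nvcc reports 12.1 while the wheel targets 12.2. A quick probe:

    import torch
    print(torch.__version__)                # 2.1.0a0+4136153 in this report
    print(torch.version.cuda)               # CUDA version torch was built against
    print(torch._C._GLIBCXX_USE_CXX11_ABI)  # True -> a cxx11abiTRUE wheel is correct

If torch.version.cuda prints 12.1, a cu121-matched wheel, or building from source against the local toolkit, is the likelier fix.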

Windows Support

I'm able to compile causal-conv1d by adding

                        "-DWIN32_LEAN_AND_MEAN",

to the nvcc flags.

When compiling mamba, after adding -DWIN32_LEAN_AND_MEAN to the nvcc flags, I find I also need to add

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

to selective_scan_common.h.

The build then gets a little further; however, it raises the following errors:

Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kIsEvenLen_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(26): note: see declaration of 'kIsEvenLen_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(521): note: see reference to function template instantiation 'void selective_scan_bwd_launch<32,4,input_t,weight_t>(SSMParamsBwd &,cudaStream_t)' being compiled
        with
        [
            input_t=c10::BFloat16,
            weight_t=c10::complex<float>
        ]
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_bf16_complex.cu(9): note: see reference to function template instantiation 'void selective_scan_bwd_cuda<c10::BFloat16,c10::complex<float>>(SSMParamsBwd &,cudaStream_t)' being compiled
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kIsVariableB_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(26): note: see declaration of 'kIsVariableB_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kIsVariableC_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(26): note: see declaration of 'kIsVariableC_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kDeltaSoftplus_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(27): note: see declaration of 'kDeltaSoftplus_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kHasZ_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(27): note: see declaration of 'kHasZ_'

This might be related to this issue (something about the Windows compiler being stricter). However, the fix is probably going to be a bit more involved, and I haven't had much luck yet.
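For context, C2975 means MSVC is not treating kIsEvenLen_ and the other switched flags as constant expressions at the point where they are passed as template arguments, a pattern that nvcc/GCC on Linux accept in the BOOL_SWITCH-style dispatch macros under csrc/selective_scan. A commonly reported workaround for this class of MSVC error, untested here, is to declare those booleans static constexpr inside the macros instead of plain constexpr, which satisfies MSVC's stricter constant-expression rules.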
