
mamba-chat's Introduction

Mamba-Chat 🐍

Mamba-Chat is the first chat language model based on a state-space model architecture, not a transformer.

The model is based on Albert Gu's and Tri Dao's work Mamba: Linear-Time Sequence Modeling with Selective State Spaces (paper) as well as their model implementation. This repository provides training and fine-tuning code for the model, built on a modified Hugging Face Trainer class.

Mamba-Chat is based on Mamba-2.8B and was fine-tuned on 16,000 samples of the HuggingFaceH4/ultrachat_200k dataset.


Run Mamba-Chat

We provide code for testing and fine-tuning our model. Here's how to get started and what you can do with it:


Clone repository and install dependencies:

git clone https://github.com/havenhq/mamba-chat.git
cd mamba-chat
pip install -r requirements.txt

Talk to Mamba-Chat (CLI chatbot):

python chat.py
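
If you prefer to call the model from your own Python code instead of the CLI, here is a rough sketch of what the chat loop does, assembled from the loading and generation calls that appear elsewhere on this page. The Hugging Face model id havenhq/mamba-chat and the exact sampling arguments are assumptions and may not match chat.py line for line:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
# Assumption: the chat model and its tokenizer are published as havenhq/mamba-chat
tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device=device, dtype=torch.float16)

messages = [dict(role="user", content="What is a state-space model?")]

# Build the prompt from the chat history and sample a reply
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)
out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7,
                     eos_token_id=tokenizer.eos_token_id)

decoded = tokenizer.batch_decode(out)
print("Model:", decoded[0].split("<|assistant|>\n")[-1])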

Talk to Mamba-Chat (gradio app):

pip install gradio==4.8.0
python app.py --share

Fine-Tune Mamba (the base model) on a subset of the Ultrachat dataset:

python train_mamba.py --model state-spaces/mamba-2.8b --tokenizer EleutherAI/gpt-neox-20b --learning_rate 5e-5 --batch_size 4 --data_path ./data/ultrachat_small.jsonl --num_epochs 3

If you have a 24GB card (3090, 4090, etc.) you can use these settings:

python train_mamba.py --model state-spaces/mamba-2.8b --tokenizer EleutherAI/gpt-neox-20b --learning_rate 5e-5 --batch_size 1 --gradient_accumulation_steps 4 --optim paged_adamw_8bit --data_path ./data/ultrachat_small.jsonl --num_epochs 3

Citation

bibtex
@misc{haven2023mambachat,
  title        = {Mamba-Chat},
  author       = {Justus Mattern and Konstantin Hohr},
  year         = {2023},
  howpublished = {GitHub},
  url          = {https://github.com/havenhq/mamba-chat}
}

mamba-chat's People

Contributors

blenderwang9487, justusmattern27, rwl4, tohrnii


mamba-chat's Issues

Bug in train_mamba.py line 53

The default code produces:

~/site-packages/transformers/trainer.py", line 1597, in _inner_training_loop
    max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: must be real number, not str

Changing line 53 to parse an int fixes the error: parser.add_argument("--num_epochs", type=int, default=1)

Can't use trained model

Traceback (most recent call last):
File "/tdx/llm-download/lhylab/mamba-chat-main/chat.py", line 11, in
model = MambaLMHeadModel.from_pretrained("mamba-chat/checkpoint-50000", device="cuda", dtype=torch.float16)
File "/home/asus/miniconda3/envs/labv2/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 230, in from_pretrained
config = load_config_hf(pretrained_model_name)
File "/home/asus/miniconda3/envs/labv2/lib/python3.10/site-packages/mamba_ssm/utils/hf.py", line 11, in load_config_hf
return json.load(open(resolved_archive_file))
TypeError: expected str, bytes or os.PathLike object, not NoneType

help
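
A likely cause (an assumption, not confirmed in this thread): mamba_ssm's load_config_hf returns None because the Trainer checkpoint directory contains no config.json, so json.load receives None. A hedged workaround is to copy the base model's config into the checkpoint folder before loading it; the paths below are illustrative:

import shutil
from huggingface_hub import hf_hub_download

# Assumption: the fine-tuned checkpoint lives in "mamba-chat/checkpoint-50000"
# and only the config.json expected by mamba_ssm is missing.
config_path = hf_hub_download(repo_id="state-spaces/mamba-2.8b", filename="config.json")
shutil.copy(config_path, "mamba-chat/checkpoint-50000/config.json")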

Demo

Hi,
Very interesting! Are you planning to release a demo on HF spaces?

TypeError: MixerModel.__init__() got an unexpected keyword argument 'bos_token_id'

I trained the model via axolotl.

Here's the chat.py error:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
File "/workspace/mamba-chat/xdan-chat.py", line 12, in
model = MambaLMHeadModel.from_pretrained(model_path, device="cuda", dtype=torch.float16)
File "/root/miniconda3/envs/axo/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 231, in from_pretrained
model = cls(**config, device=device, dtype=dtype, **kwargs)
File "/root/miniconda3/envs/axo/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 190, in init
self.backbone = MixerModel(
TypeError: MixerModel.init() got an unexpected keyword argument 'bos_token_id'

Feature 'cvt with .bf16' requires .target sm_80 or higher Error

When I try to fine-tune Mamba, I get an error stating "Feature 'cvt with .bf16' requires .target sm_80 or higher". I looked up this repo. I'm fine-tuning on a Tesla T4, which uses the Turing architecture and only supports sm_75, while the error indicates sm_80 or higher (Ampere, such as the A100) is needed. Is the error caused by insufficient hardware, or is there another cause? I would appreciate clarification. Is there another way to fine-tune on a T4?
Thanks in advance
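
For context: the .bf16 PTX instructions that Triton emits here require CUDA compute capability 8.0 (Ampere, e.g. A100) or newer, and the T4 is a Turing card at compute capability 7.5, which is why ptxas rejects the kernel. A quick way to check what your GPU supports before training:

import torch

major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")          # 7.5 on a T4, 8.0+ on Ampere
print("bf16 supported:", torch.cuda.is_bf16_supported())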

MoE https://arxiv.org/abs/2401.04081

I'm reading through this paper and it looks like it would improve the efficiency of the model. I'm still trying to wrap my head around Mixture of Experts (MoE) as a concept, but from the results it seems like a natural fit for the training code in this repository.
https://github.com/llm-random/llm-random

Question about how mamba chat training is done

Hi,

First off, thanks for providing training code for mamba use cases.

I was looking at how the training for mamba-chat is done, and something I'm unclear on is the "preprocess" function used in the class "ChatDataset" (in "/trainer/data.py"). Why does it return a dictionary where the labels are just the input ids rather than separate labels?

dict(input_ids = all_input_ids, labels=all_input_ids)

I'm a little confused: wouldn't we need the data of both the user and the assistant to train a chatbot? I noticed this same pattern a few other times in the training code, so I wanted to ask.

Thanks!
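
For readers with the same question: in causal language-model training it is conventional to pass labels equal to input_ids, because the loss shifts the sequence by one position internally, so each token is predicted from the tokens before it. A minimal sketch of that shifted loss (not the repository's exact trainer code):

import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels):
    # Predict token t+1 from positions <= t: drop the last logit, drop the first label.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )

With labels=input_ids and no masking, both the user and assistant turns contribute to the training signal, since all of them are in input_ids.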

sentencepiece version

Hi,
When I load the tokenizer of EleutherAI/gpt-neox-20b or zephyr-7b-base, the following error happens:
ValueError: Couldn't instantiate the backend tokenizer from one of:
(1) a tokenizers library serialization file,
(2) a slow tokenizer instance to convert or
(3) an equivalent slow tokenizer class to instantiate and convert.
You need to have sentencepiece installed to convert a slow tokenizer to a fast one.

Even after I install sentencepiece, the error still exists.

My environment is Python 3.8 with sentencepiece 0.20.0. What is your environment?

Memory requirements for training

I was able to run the 2.8B model for inference and it uses about 6 GB of VRAM. In your readme there is a 24 GB requirement for training. Does the model use much more memory during training (32-bit?), or is it because of the space required for input batches?
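
Rough answer (a back-of-the-envelope estimate, assuming bf16 weights and gradients plus standard AdamW with two fp32 moments per parameter): training holds several full-size copies of the model, while fp16 inference only holds the weights (~5.6 GB for 2.8B parameters).

params = 2.8e9
weights = params * 2          # bf16 weights
grads   = params * 2          # bf16 gradients
adam_m  = params * 4          # fp32 first moment (AdamW)
adam_v  = params * 4          # fp32 second moment (AdamW)
total_gb = (weights + grads + adam_m + adam_v) / 1024**3
print(f"~{total_gb:.0f} GB before activations")   # roughly 31 GB

That is why the 24 GB recipe in the readme uses paged_adamw_8bit (which shrinks the optimizer state) and batch_size 1 with gradient accumulation (which limits activation memory).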

Error When Inferencing

When I run !python chat.py, I got AssertionError: libcuda.so cannot found!.
I'm using colab.
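
A common cause in Colab (an assumption, since the runtime type isn't stated here) is that the notebook is running on a CPU runtime, where libcuda.so is simply not present. Switching the runtime to a GPU and checking with torch usually clarifies it:

import torch

print(torch.cuda.is_available())   # should be True on a GPU runtime
print(torch.version.cuda)          # CUDA version torch was built with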

Issue while installing requirements.txt

Requirement already satisfied: packaging in ./venv/lib/python3.11/site-packages (from -r requirements.txt (line 1)) (21.3)
Collecting torch==2.1.0 (from -r requirements.txt (line 2))
Using cached torch-2.1.0-cp311-cp311-manylinux1_x86_64.whl.metadata (25 kB)
Collecting transformers==4.35.0 (from -r requirements.txt (line 3))
Using cached transformers-4.35.0-py3-none-any.whl.metadata (123 kB)
Collecting causal-conv1d==1.0.0 (from -r requirements.txt (line 4))
Using cached causal_conv1d-1.0.0.tar.gz (6.4 kB)
Installing build dependencies ... done
Getting requirements to build wheel ... error
error: subprocess-exited-with-error

× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> [20 lines of output]
Traceback (most recent call last):
File "/home/rexommendation/Programs/mamba-chat/venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in
main()
File "/home/rexommendation/Programs/mamba-chat/venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
json_out['return_val'] = hook(**hook_input['kwargs'])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rexommendation/Programs/mamba-chat/venv/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 118, in get_requires_for_build_wheel
return hook(config_settings)
^^^^^^^^^^^^^^^^^^^^^
File "/tmp/pip-build-env-wiab8kl5/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 325, in get_requires_for_build_wheel
return self._get_build_requires(config_settings, requirements=['wheel'])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/pip-build-env-wiab8kl5/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 295, in _get_build_requires
self.run_setup()
File "/tmp/pip-build-env-wiab8kl5/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 480, in run_setup
super(_BuildMetaLegacyBackend, self).run_setup(setup_script=setup_script)
File "/tmp/pip-build-env-wiab8kl5/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 311, in run_setup
exec(code, locals())
File "", line 8, in
ModuleNotFoundError: No module named 'packaging'
[end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.
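
This failure usually means the causal-conv1d source build starts before its build-time dependency packaging is visible inside pip's isolated build environment. A hedged workaround using standard pip options is to install the build prerequisites (including torch, which the setup script imports) first and disable build isolation for the package that needs them:

pip install packaging wheel setuptools torch==2.1.0
pip install causal-conv1d==1.0.0 --no-build-isolation
pip install -r requirements.txt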

Error importing MambaLMHeadModel from mamba_ssm.models.mixer_seq_simple in Google Colab

ERROR: Could not find a version that satisfies the requirement mamba_ssm.models.mixer_seq_simple (from versions: none)
ERROR: No matching distribution found for mamba_ssm.models.mixer_seq_simple
ERROR: Could not find a version that satisfies the requirement causal_conv1d_cuda (from versions: none)
ERROR: No matching distribution found for causal_conv1d_cuda

ImportError Traceback (most recent call last)
in <cell line: 5>()
3 get_ipython().system('pip install mamba_ssm.models.mixer_seq_simple')
4 get_ipython().system('pip install causal_conv1d_cuda')
----> 5 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
6
7 device = "cuda"

3 frames
/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py in
5
6
----> 7 import causal_conv1d_cuda
8
9

ImportError: /usr/local/lib/python3.10/dist-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorEN3c108optionalINS5_10ScalarTypeEEENS6_INS5_6LayoutEEENS6_INS5_6DeviceEEENS6_IbEENS6_INS5_12MemoryFormatEEE
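
Two separate things seem to be going on here. First, mamba_ssm.models.mixer_seq_simple and causal_conv1d_cuda are module paths, not PyPI packages, so the pip install lines above can never succeed; the installable packages are mamba-ssm and causal-conv1d. Second, the undefined-symbol error usually means the prebuilt causal_conv1d extension was compiled against a different torch version than the one in the runtime. A hedged fix is to reinstall both packages against the currently installed torch:

pip uninstall -y causal-conv1d mamba-ssm
pip install --no-cache-dir causal-conv1d==1.0.0 mamba-ssm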

Error during training

I can get the model to perform inference just fine, but in my colab env this is what I run into

`2023-12-21 14:14:22.093031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-21 14:14:22.093134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-21 14:14:22.162482: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-21 14:14:24.060295: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Tokenizing dataset...
100% 1000/1000 [00:02<00:00, 383.23it/s]
0% 0/750 [00:00<?, ?it/s]Traceback (most recent call last):
File "/content/mamba-chat/train_mamba.py", line 60, in
run(args)
File "/content/mamba-chat/train_mamba.py", line 45, in run
trainer.train()
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
loss = self.compute_loss(model, inputs)
File "/content/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
lm_logits = model(input_ids).logits
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
fn()
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "", line 63, in _layer_norm_fwd_1pass_kernel
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 351, in
lambda src: ptx_to_cubin(src, arch))
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 150, in ptx_to_cubin
return compile_ptx_to_cubin(ptx, ptxas, arch)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 988; error : Feature '.bf16' requires .target sm_80 or higher
....
ptxas /tmp/compile-ptx-src-eed3b0, line 2801; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors

0% 0/750 [00:02<?, ?it/s]`

How can I run this on Windows 10?

Hello,

When I tried to install packages with requirements.txt, I got the following error.

Error
ERROR: Could not find a version that satisfies the requirement triton (from mamba-ssm) (from versions: none)
ERROR: No matching distribution found for triton

As far as I know, the triton package only supports Linux.
What should I do?

Interesting chat example

This is not an issue, just an interesting example from chatting with the model:

Your message: what is the loudest animal in the world
Model: The loudest animal in the world is the blue whale, which can produce sounds up to 150 decibels.<|endoftext|>

Your message: is the sperm whale louder than the blue whale
Model: Yes, the sperm whale is the loudest animal in the world, with a sound pressure level of up to 170 decibels.<|endoftext|>

Your message: what is the loudest animal in the world
Model: The loudest animal in the world is the sperm whale, with a sound pressure level of up to 170 decibels.<|endoftext|>

Fine-tuning on 3090 but loss equals zero

Thanks for the great work! I tried to fine-tune the Mamba model using four 3090 GPUs following your code, but there's a problem: the loss drops to zero after just two steps. Can you help me figure out what's going wrong and how to fix it?

ImportError: /usr/local/lib/python3.10/dist-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

ImportError Traceback (most recent call last)
in <cell line: 4>()
2 import argparse
3
----> 4 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
5 from transformers import AutoTokenizer, TrainingArguments
6 from trainer.data import ChatDataModule

3 frames
/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py in
5
6
----> 7 import causal_conv1d_cuda
8
9

ImportError: /usr/local/lib/python3.10/dist-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi



I am facing above error when starting the fine-tuning in Google Colab.

Colab notebook has error, numpy array used instead of torch

while True:
    user_message = input("\nYour message: ")
    messages.append(dict(
        role="user",
        content=user_message
    ))
 

    out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)

    decoded = tokenizer.batch_decode(out)
    messages.append(dict(
        role="assistant",
        content=decoded[0].split("<|assistant|>\n")[-1])
    )

    print("Model:", decoded[0].split("<|assistant|>\n")[-1])

Error:

AttributeError                            Traceback (most recent call last)
[<ipython-input-6-8618df48d4c0>](https://localhost:8080/#) in <cell line: 2>()
      9     input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
     10 
---> 11     out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
     12 
     13     decoded = tokenizer.batch_decode(out)

18 frames
[/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py](https://localhost:8080/#) in forward(ctx, x, weight, bias, seq_idx, activation)
     17             x = x.contiguous()
     18         bias = bias.contiguous() if bias is not None else None
---> 19         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
     20         ctx.save_for_backward(x, weight, bias, seq_idx)
     21         ctx.activation = activation in ["silu", "swish"]

AttributeError: 'str' object has no attribute 'contiguous'

The line most likely causing the error: the torch conversion did not happen, and numpy arrays do not have a contiguous method. Fix this by explicitly requesting torch tensors:

input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
