I can get the model to perform inference just fine, but in my colab env this is what I run into
`2023-12-21 14:14:22.093031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-21 14:14:22.093134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-21 14:14:22.162482: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-21 14:14:24.060295: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Tokenizing dataset...
100% 1000/1000 [00:02<00:00, 383.23it/s]
0% 0/750 [00:00<?, ?it/s]Traceback (most recent call last):
File "/content/mamba-chat/train_mamba.py", line 60, in
run(args)
File "/content/mamba-chat/train_mamba.py", line 45, in run
trainer.train()
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
loss = self.compute_loss(model, inputs)
File "/content/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
lm_logits = model(input_ids).logits
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
fn()
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "", line 63, in _layer_norm_fwd_1pass_kernel
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 351, in
lambda src: ptx_to_cubin(src, arch))
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 150, in ptx_to_cubin
return compile_ptx_to_cubin(ptx, ptxas, arch)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 988; error : Feature '.bf16' requires .target sm_80 or higher
....
ptxas /tmp/compile-ptx-src-eed3b0, line 2801; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors
0% 0/750 [00:02<?, ?it/s]`