Comments (18)
Oh wow, so did that simple solution fix the problem for you and enable FSDP training for Mamba, or did you also have to do something else? Thanks for your quick reply as well :)
I'm training at a higher image resolution, so I might try an embedding dimension of 512, but that will depend on my GPU memory limit I guess...
Yes, only change the dim (I think the dim of Mamba should be 256, 512, etc.), and all the problems with Mamba + DINOv2 disappear.
Thank you, that helps a lot! If you don't mind, can you tell me if you used these configs for student & teacher?
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: bf16
buffer_dtype: fp16
Also, if you don't mind, may I ask what modifications you made to the Mamba implementation for FSDP (auto_wrap_policy, chunking model blocks, etc.)? Thank you so much! 👍
param_dtype: bf16
reduce_dtype: fp32
buffer_dtype: fp32
block_chunks = 0 (no chunking of model blocks)
No change in any other part.
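For reference, the dtypes above (bf16 params, fp32 reduce and buffers) map onto torch's FSDP `MixedPrecision` policy roughly like this. This is a sketch on my side, not the exact DINOv2 code; the actual wrapping of the backbone module is omitted:

```python
# Sketch of an FSDP mixed-precision policy matching the working config above.
# MixedPrecision and ShardingStrategy are real torch.distributed.fsdp APIs;
# `backbone` would be the actual model module and is left out here.
import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # bf16 params (the setting that worked)
    reduce_dtype=torch.float32,   # fp32 gradient all-reduce
    buffer_dtype=torch.float32,   # fp32 buffers
)
sharding = ShardingStrategy.SHARD_GRAD_OP

# model = FSDP(backbone, sharding_strategy=sharding, mixed_precision=mp_policy)
```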
Please send a script with which we can reproduce this, e.g. by saving tensors right before the function call that causes NaN:
# torch.save(a bunch of tensors)
tensors = torch.load(blah)
# calling the function on these tensors causes NaN
Thank you for your reply. I save the tensors from here:
mamba/mamba_ssm/modules/mamba2_simple.py
Line 138 in f9dbb4f
The format I use to save them is like this:
torch.save({
    'zxbcdt': zxbcdt,
    'conv1d_weight': rearrange(self.conv1d.weight, "d 1 w -> d w"),
    'conv1d_bias': self.conv1d.bias,
    'dt_bias': self.dt_bias,
    'A': A,
    'D': self.D,
    'seq_idx': seq_idx,
    'activation': self.activation,
    'rmsnorm_weight': self.norm.weight,
    'rmsnorm_eps': self.norm.eps,
    'outproj_weight': self.out_proj.weight,
    'outproj_bias': self.out_proj.bias,
    'headdim': self.headdim,
    'ngroups': self.ngroups,
    'norm_before_gate': False,
    'initial_states': initial_states,
    'dt_limit_kwargs': dt_limit_kwargs
}, '/cis/home/gwei10/dinov2_vim_v3/debug_tensors_before_mamba_split.pth')
out = mamba_split_conv1d_scan_combined(
    zxbcdt,
    rearrange(self.conv1d.weight, "d 1 w -> d w"),
    self.conv1d.bias,
    self.dt_bias,
    A,
    D=self.D,
    chunk_size=self.chunk_size,
    seq_idx=seq_idx,
    activation=self.activation,
    rmsnorm_weight=self.norm.weight,
    rmsnorm_eps=self.norm.eps,
    outproj_weight=self.out_proj.weight,
    outproj_bias=self.out_proj.bias,
    headdim=self.headdim,
    ngroups=self.ngroups,
    norm_before_gate=False,
    initial_states=initial_states,
    **dt_limit_kwargs,
)
torch.save(out, '/cis/home/gwei10/dinov2_vim_v3/debug_output.pth')
You can download the tensor from https://drive.google.com/file/d/10BEHQIZb_cBMoCv38NH6SyWOHLbruiYb/view?usp=sharing
Does that cause the output to be NaN?
You mentioned the gradient being NaN. How do we reproduce the NaN gradient from the saved tensors?
e.g. can you post a short script that loads the tensors, calls the relevant function, then shows that the function produces NaN?
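A minimal version of such a check could start by scanning the saved dict for inputs that are already non-finite. This is a hypothetical sketch, assuming the tensors were saved as the dict shown earlier (the loading line is commented out and replaced with a stand-in dict):

```python
# Hypothetical NaN/Inf scan over a dict of saved tensors, matching the
# torch.save({...}) dict from the comment above. The stand-in `saved` dict
# here is only for illustration.
import torch

def nonfinite_entries(saved: dict) -> list:
    """Return the keys whose tensors already contain NaN or Inf."""
    bad = []
    for name, value in saved.items():
        if torch.is_tensor(value) and value.is_floating_point():
            if not torch.isfinite(value).all():
                bad.append(name)
    return bad

# saved = torch.load("debug_tensors_before_mamba_split.pth")
saved = {"zxbcdt": torch.tensor([1.0, float("nan")]), "A": torch.tensor([2.0])}
print(nonfinite_entries(saved))  # -> ['zxbcdt']
```

If the inputs are all finite but the output (or gradient) is not, the problem is inside the call itself rather than upstream.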
Thank you for your reply. I finally found that the problem is caused by the mixed_precision setting.
Do you have any recommendations? I use FullyShardedDataParallel in DINOv2, and the config is
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: bf16
buffer_dtype: fp16
Do you think this config is good for mamba?
We use torch.amp. I'm not familiar with FSDP configs.
@yangtiming Hi, I had a similar issue where I ran into NaN issues using DINOv2 with Mamba under DDP (not even FSDP as you mentioned). However, I do plan to expand to FSDP in the future. Did you get it figured out by any chance? Thank you!
The basic solution is to change the dim from 384 to 256 (a multiple of 8; you can find some related questions about it).
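If you need to adjust an arbitrary embedding dimension to whatever multiple the kernels require (8 or 256, depending on the report), a tiny helper like this does it. This is my own illustration, not code from the repo:

```python
def round_up_dim(dim: int, multiple: int = 8) -> int:
    """Round `dim` up to the nearest multiple of `multiple`."""
    return ((dim + multiple - 1) // multiple) * multiple

print(round_up_dim(250))       # -> 256
print(round_up_dim(384, 256))  # -> 512
```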
I see! So block_chunks = 0 will run fine, didn't know that, thank you! Bless you for all of the help, much appreciated :)
@yangtiming It seems like only Mamba-2 requires the embedding dimension (d_model) to be a multiple of 256 or 512, but did you find that Mamba requires the same? Because currently I am using Mamba, not Mamba-2, for DINOv2. Thank you :)
If you are using Vim, then try setting
rms_norm=False, fused_add_norm=False.
Or, you can set
param_dtype: fp32
reduce_dtype: fp32 (I forget, maybe fp16)
buffer_dtype: fp32 (I forget, maybe fp16)
It will work well.
#352 #345
Hmm, yes, I am using the Vim architecture, and I modified the code to make DINOv2 training work (I'm sure you did this as well). The Vim defaults currently have rms_norm and fused_add_norm both equal to True; would setting these to False make the training stable (no more NaNs, and convergence)?
Interesting how doing the add and norm separately by setting it to False makes the training better... And yes, I will try those dtypes, thank you so much!
@yangtiming @tridao Hello, I'm sorry to keep asking questions, but may I ask which environment (CUDA, PyTorch, mamba-ssm, and causal_conv1d versions) you used to train Vision Mamba (Vim) on the DINOv2 task with bfloat16? When I use fp32 or fp16 it runs fine (fp16 gives unstable training though) and the code compiles, but bfloat16 keeps giving me errors like this:
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: Optional[torch.Tensor], arg6: bool) -> torch.Tensor
I'm currently using CUDA 11.8, Python 3.10.13, torch 2.1.1, causal_conv1d 1.1.1, and mamba-ssm==1.2.0.post1. I've tried changing the torch, causal_conv1d, and mamba-ssm versions, but with no success; I still cannot train with bfloat16. I am installing the causal_conv1d and mamba-ssm packages via wheel using pip (pip install /path/to/.whl) to make sure the packages are compatible with the Python, torch, and CUDA versions. Any help would be appreciated!
Also, I've looked at other issues, but the only solution mentioned there doesn't apply to me since my GPU compute capability is >= 7 (using V100 and/or A100).
I could submit a new issue with more details if that would be better.
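The `causal_conv1d_fwd` signature mismatch above usually means the installed binding does not match the arguments the Python wrapper passes for that version. One generic fallback, which is my own assumption and not a confirmed fix for this repo, is to keep the conv in fp32 and cast back to bf16:

```python
# Hypothetical workaround sketch: if the installed causal_conv1d binding
# rejects bf16 inputs, run the (depthwise) conv in fp32 and cast the result
# back. This is a generic PyTorch fallback, not the repo's official fix.
import torch
import torch.nn as nn

def conv1d_bf16_safe(conv: nn.Conv1d, x: torch.Tensor) -> torch.Tensor:
    """Run `conv` in fp32 regardless of input dtype; return in the input dtype."""
    orig_dtype = x.dtype
    with torch.autocast(device_type=x.device.type, enabled=False):
        out = conv(x.float())  # fp32 compute avoids the bf16 binding path
    return out.to(orig_dtype)

# Toy usage on CPU with a depthwise conv, shapes chosen arbitrarily.
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1, groups=4)
x = torch.randn(2, 4, 16, dtype=torch.bfloat16)
y = conv1d_bf16_safe(conv, x)
```

This trades some speed for compatibility; matching wheel, torch, and CUDA versions is still the proper fix.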