
Comments (18)

yangtiming commented on September 18, 2024

Oh wow, so did that simple solution fix the problem for you and enable FSDP training for Mamba, or did you also have to do something else? Thanks for your quick reply as well :)

I'm trying higher-resolution image training, so maybe I'll try an embedding dimension of 512, but that will depend on my GPU limit, I guess...

Yes, I only changed the dim (I think the dim for Mamba should be 256, 512, etc.), and all the problems with Mamba + DINOv2 disappeared.


yangtiming commented on September 18, 2024

Thank you, that helps a lot! If you don't mind, can you tell me whether you used these configs for the student & teacher?

backbone:
  sharding_strategy: SHARD_GRAD_OP
  mixed_precision:
    param_dtype: fp32
    reduce_dtype: bf16
    buffer_dtype: fp16

Also, if you don't mind, may I ask what modifications you made to the Mamba implementation for FSDP (auto_wrap_policy, chunking the model blocks, etc.)? Thank you so much! 👍

I use:

param_dtype: bf16
reduce_dtype: fp32
buffer_dtype: fp32

block_chunks = 0

No changes in any other part.

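For readers mapping these YAML fields onto PyTorch directly, here is a minimal sketch of how the combination that worked here (bf16 parameters, fp32 reduction, fp32 buffers, SHARD_GRAD_OP) can be expressed with torch.distributed.fsdp. The DINOv2 codebase builds its FSDP wrapping from its own config system, so the placeholder model and the wrapping below are only an illustration, not the project's actual code.

    # Run with: torchrun --nproc_per_node=1 fsdp_mixed_precision_sketch.py
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        MixedPrecision,
        ShardingStrategy,
    )

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())

    # Placeholder standing in for the Mamba-based backbone; the real model comes
    # from the modified DINOv2/Vim code and is not reproduced here.
    backbone = torch.nn.Sequential(
        torch.nn.Linear(256, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 256)
    ).cuda()

    # The dtype combination reported to work: bf16 params, fp32 reduce, fp32 buffers.
    policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )

    model = FSDP(
        backbone,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        mixed_precision=policy,
    )

    loss = model(torch.randn(8, 256, device="cuda")).float().pow(2).mean()
    loss.backward()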

tridao commented on September 18, 2024

Please send a script that we can use to reproduce the issue, e.g. by saving the tensors right before a function call that causes NaN:

# torch.save(a bunch of tensors)
tensors = torch.load(blah)
# calling func on tensors causes NaN


yangtiming commented on September 18, 2024

Thank you for your reply. I save the tensors right before this call:

out = mamba_split_conv1d_scan_combined(

I save them in the following format:

        torch.save({
            'zxbcdt': zxbcdt,
            'conv1d_weight': rearrange(self.conv1d.weight, "d 1 w -> d w"),
            'conv1d_bias': self.conv1d.bias,
            'dt_bias': self.dt_bias,
            'A': A,
            'D': self.D,
            'seq_idx': seq_idx,
            'activation': self.activation,
            'rmsnorm_weight': self.norm.weight,
            'rmsnorm_eps': self.norm.eps,
            'outproj_weight': self.out_proj.weight,
            'outproj_bias': self.out_proj.bias,
            'headdim': self.headdim,
            'ngroups': self.ngroups,
            'norm_before_gate': False,
            'initial_states': initial_states,
            'dt_limit_kwargs': dt_limit_kwargs
        }, '/cis/home/gwei10/dinov2_vim_v3/debug_tensors_before_mamba_split.pth')

        out = mamba_split_conv1d_scan_combined(
            zxbcdt,
            rearrange(self.conv1d.weight, "d 1 w -> d w"),
            self.conv1d.bias,
            self.dt_bias,
            A,
            D=self.D,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.eps,
            outproj_weight=self.out_proj.weight,
            outproj_bias=self.out_proj.bias,
            headdim=self.headdim,
            ngroups=self.ngroups,
            norm_before_gate=False,
            initial_states=initial_states,
            **dt_limit_kwargs,
        )
        torch.save(out,'/cis/home/gwei10/dinov2_vim_v3/debug_output.pth')

You can download the tensors from https://drive.google.com/file/d/10BEHQIZb_cBMoCv38NH6SyWOHLbruiYb/view?usp=sharing


tridao commented on September 18, 2024

Does that cause the output to be NaN?
You mentioned the gradient being NaN. How do we reproduce a NaN gradient from the saved tensors?


tridao commented on September 18, 2024

E.g. can you post a short script that loads the tensors, calls the relevant function, and then shows that the function produces NaN?

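As a reference point, here is a minimal sketch of such a script, reconstructed from the keys saved in the dict above. The import path and the chunk_size value (which is not in the saved dict) are assumptions and may need adjusting; the tensors must be on a CUDA device for the Triton kernel.

    import torch
    # Assumed import path (mamba-ssm 2.x); adjust if your install differs.
    from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

    saved = torch.load("debug_tensors_before_mamba_split.pth", map_location="cuda")

    # Make the main input a leaf requiring grad so its gradient can be inspected.
    zxbcdt = saved["zxbcdt"].detach().clone().requires_grad_(True)

    out = mamba_split_conv1d_scan_combined(
        zxbcdt,
        saved["conv1d_weight"],
        saved["conv1d_bias"],
        saved["dt_bias"],
        saved["A"],
        D=saved["D"],
        chunk_size=256,  # not saved above; assumed, use the model's self.chunk_size
        seq_idx=saved["seq_idx"],
        activation=saved["activation"],
        rmsnorm_weight=saved["rmsnorm_weight"],
        rmsnorm_eps=saved["rmsnorm_eps"],
        outproj_weight=saved["outproj_weight"],
        outproj_bias=saved["outproj_bias"],
        headdim=saved["headdim"],
        ngroups=saved["ngroups"],
        norm_before_gate=saved["norm_before_gate"],
        initial_states=saved["initial_states"],
        **saved["dt_limit_kwargs"],
    )
    print("NaN in output:", torch.isnan(out).any().item())

    out.sum().backward()
    print("NaN in grad:  ", torch.isnan(zxbcdt.grad).any().item())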

yangtiming commented on September 18, 2024

Thank you for your reply. I finally found that the problem is caused by the mixed_precision setting.
Do you have any recommendations? I use FullyShardedDataParallel in DINOv2, and the config is:

backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: bf16
        buffer_dtype: fp16

Do you think this config is good for Mamba?


tridao commented on September 18, 2024

We use torch.amp. I'm not familiar with FSDP configs.

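For context, torch.amp autocast keeps the parameters in fp32 and only runs selected ops in bf16/fp16, whereas FSDP's MixedPrecision casts the sharded parameters themselves, which can matter for numerical stability. A minimal sketch of the torch.amp pattern (a generic training step, not the actual Mamba training code):

    import torch

    model = torch.nn.Linear(256, 256).cuda()           # placeholder model
    optimizer = torch.optim.AdamW(model.parameters())  # fp32 master weights

    x = torch.randn(8, 256, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Ops inside this context run in bf16 where it is safe; parameters stay fp32.
        loss = model(x).pow(2).mean()
    loss.backward()   # gradients land on the fp32 parameters
    optimizer.step()
    optimizer.zero_grad()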

chokevin8 commented on September 18, 2024

@yangtiming Hi, I had a similar issue where I ran into NaNs using DINOv2 with Mamba under DDP (not even FSDP as you mentioned). However, I do plan to expand to FSDP in the future; did you get it figured out by any chance? Thank you!


yangtiming commented on September 18, 2024

@yangtiming Hi, I had a similar issue where I ran into NaNs using DINOv2 with Mamba under DDP (not even FSDP as you mentioned). However, I do plan to expand to FSDP in the future; did you get it figured out by any chance? Thank you!

The basic solution is to change the dim from 384 to 256 (there is a multiple-of-8 constraint; you can find some existing issues about it).

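A quick way to see why 256 works while 384 does not, assuming this is the Mamba-2 head-count constraint discussed in #352 and #345: with the default expand=2 and headdim=64, the fused Triton kernels want the number of heads (d_model * expand / headdim) to be a multiple of 8. The defaults and the exact rule are assumptions here, and Mamba-1 may not be subject to the same constraint (see the discussion below).

    # Hypothetical helper illustrating the head-count check (not part of mamba-ssm).
    def nheads(d_model: int, expand: int = 2, headdim: int = 64) -> int:
        return d_model * expand // headdim

    for d_model in (256, 384, 512):
        h = nheads(d_model)
        print(f"d_model={d_model}: nheads={h}, multiple of 8: {h % 8 == 0}")

    # d_model=256: nheads=8, multiple of 8: True
    # d_model=384: nheads=12, multiple of 8: False
    # d_model=512: nheads=16, multiple of 8: True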

chokevin8 commented on September 18, 2024

Oh wow, so did that simple solution fix the problem for you and enable FSDP training for Mamba, or did you also have to do something else? Thanks for your quick reply as well :)

I'm trying higher-resolution image training, so maybe I'll try an embedding dimension of 512, but that will depend on my GPU limit, I guess...


chokevin8 commented on September 18, 2024

Thank you, that helps a lot! If you don't mind, can you tell me whether you used these configs for the student & teacher?

backbone:
  sharding_strategy: SHARD_GRAD_OP
  mixed_precision:
    param_dtype: fp32
    reduce_dtype: bf16
    buffer_dtype: fp16

Also, if you don't mind, may I ask what modifications you made to the Mamba implementation for FSDP (auto_wrap_policy, chunking the model blocks, etc.)? Thank you so much! 👍


chokevin8 commented on September 18, 2024

I see! So block_chunks = 0 will run fine; I didn't know that, thank you! Bless you for all of the help, much appreciated :)


chokevin8 commented on September 18, 2024

#352
#345

@yangtiming It seems like it is only Mamba-2 that requires the embedding dimension (d_model) to be a multiple of 256 or 512, but did you find that Mamba requires the same? Currently I am using Mamba, not Mamba-2, for DINOv2. Thank you :)


yangtiming commented on September 18, 2024

#352 #345

@yangtiming It seems like it is only Mamba-2 that requires the embedding dimension (d_model) to be a multiple of 256 or 512, but did you find that Mamba requires the same? Currently I am using Mamba, not Mamba-2, for DINOv2. Thank you :)

If you are using Vim, then try setting
rms_norm=False, fused_add_norm=False.

Or, you can set

param_dtype: fp32
reduce_dtype: fp32 (I forget, maybe fp16)
buffer_dtype: fp32 (I forget, maybe fp16)

It will work well.

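For illustration only, a sketch of where those two flags would be set when building a Vim backbone. The constructor name and the other arguments below are assumptions based on this thread (check the Vim repository for the exact signature); only rms_norm and fused_add_norm come from the suggestion above.

    # Hypothetical construction of a Vim backbone with the fused norm paths disabled;
    # the actual VisionMamba signature in the Vim repo may differ.
    from vim.models_mamba import VisionMamba  # assumed import path

    backbone = VisionMamba(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=24,
        rms_norm=False,        # plain LayerNorm instead of the fused RMSNorm kernel
        fused_add_norm=False,  # residual add and norm as separate (unfused) ops
    )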

chokevin8 commented on September 18, 2024

#352 #345
@yangtiming It seems like it is only Mamba-2 that requires the embedding dimension (d_model) to be a multiple of 256 or 512, but did you find that Mamba requires the same? Currently I am using Mamba, not Mamba-2, for DINOv2. Thank you :)

If you are using Vim, then try setting rms_norm=False, fused_add_norm=False.

Or, you can set

param_dtype: fp32
reduce_dtype: fp32 (I forget, maybe fp16)
buffer_dtype: fp32 (I forget, maybe fp16)

It will work well.

Hmm, yes, I am using the Vim architecture, and I modified the code to make DINOv2 training work (I'm sure you did this as well). The current Vim defaults have rms_norm and fused_add_norm both set to True; setting these to False made the training stable for you (no more NaNs, and convergence)?

Interesting how doing the add and norm separately by setting it to False makes the training better... And yes, I will try those dtypes, thank you so much!


chokevin8 commented on September 18, 2024

@yangtiming @tridao Hello, I'm sorry to keep asking questions, but may I ask which environment (CUDA, PyTorch, mamba-ssm, and causal_conv1d versions) you used to train Vision Mamba (Vim) on the DINOv2 task with bfloat16? When I use fp32 or fp16 it runs fine and the code compiles (though fp16 gives unstable training), but bfloat16 keeps giving me errors like this:

TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: Optional[torch.Tensor], arg5: Optional[torch.Tensor], arg6: bool) -> torch.Tensor

I'm currently using CUDA 11.8, Python 3.10.13, torch 2.1.1, causal_conv1d 1.1.1, and mamba-ssm 1.2.0.post1. I've tried changing the torch, causal_conv1d, and mamba-ssm versions, but with no success; I still cannot train with bfloat16. I am installing the causal_conv1d and mamba-ssm packages from wheels via pip (pip install /path/to/.whl) to make sure the packages are compatible with my Python, torch, and CUDA versions. Any help would be appreciated!

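For what it's worth, this particular TypeError means the Python wrapper is calling the compiled causal_conv1d_fwd binding with an argument list it does not expect, which usually points to a version mismatch between mamba-ssm and causal_conv1d rather than a bfloat16 limitation per se. A standalone smoke test of the installed causal_conv1d package in bf16 can help separate the two; the causal_conv1d_fn call below assumes the high-level wrapper exposed by recent causal_conv1d releases.

    # Standalone bf16 smoke test for the installed causal_conv1d package.
    import torch
    from causal_conv1d import causal_conv1d_fn  # assumed high-level wrapper

    batch, dim, seqlen, width = 2, 256, 128, 4
    x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = causal_conv1d_fn(x, weight, bias, activation="silu")
    out.sum().backward()
    print("bf16 forward/backward OK:", out.dtype, x.grad.dtype)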

chokevin8 commented on September 18, 2024

Also, I've looked at other issues, but the only solution mentioned in this issue doesn't apply to me, since my GPU compute capability is >= 7 (I'm using V100s and/or A100s).

I could submit a new issue with more details if that would be better.

