
Comments (6)

fabianlim avatar fabianlim commented on August 29, 2024

Maybe this can be addressed by distributing the adapters using DDP, just as was done with the AutoGPTQ version.

achew010 avatar achew010 commented on August 29, 2024

Device Mapping Error

It turns out this error is thrown because the base_layer weight W is on the cpu device when it is passed into W = fast_dequantize(W.t(), W_quant) in BNB's fast_lora.py, while the dequantization of W needs to happen on cuda.

The adapters A and B are also on cpu, so they subsequently throw a device-mismatch error when matmul'd with X (which is on cuda) inside the matmul_lora function.

fused_ops/unsloth_lora/utils.py

def matmul_lora(X, W, W_quant, A, B, s, out = None):
    dtype = X.dtype
    # W must already be on cuda here; dequantizing a cpu-resident base weight fails
    W = fast_dequantize(W.t(), W_quant)

    if X.dim() == 3:
        batch, seq_len, d = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False
    pass

    out = torch.matmul(X, W, out = out)
    if W_quant is not None: del W

    if A is not None:
        # LoRA is enabled; A and B must be on the same device as X (cuda)
        A, B = A.t(), B.t()
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))
    pass

    return out.view(batch, seq_len, -1) if reshape else out
pass
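
For reference, the failure mode can be reproduced outside the framework with a minimal sketch (illustrative only, assuming a CUDA device is available): a matmul between a cuda input and a cpu-resident weight raises the same kind of device-mismatch RuntimeError described above.

import torch

# The activations X are on cuda during training ...
X = torch.randn(4, 16, device="cuda", dtype=torch.float16)
# ... but under low-memory mode the base weight is still resident on cpu.
W = torch.randn(16, 32, dtype=torch.float16)

try:
    torch.matmul(X, W)
except RuntimeError as err:
    # e.g. "Expected all tensors to be on the same device, but found at least
    # two devices, cuda:0 and cpu!"
    print(err)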

Checking Device Preparation

Just before the FOAK patching, the model itself has been cast to cuda, but the self_attn base layer weights and adapters are still on cpu due to low-memory mode:

  (Pdb) model.device
  device(type='cuda', index=0)
  (Pdb) model.get_base_model().model.layers[0].self_attn.q_proj.base_layer.weight.device
  device(type='cpu')
  (Pdb) model.get_base_model().model.layers[0].self_attn.q_proj.lora_A.default.weight.device
  device(type='cpu')

However, removing the FOAK patching reverses the problem, and FSDP-QLoRA with low-memory mode trains perfectly fine.

My guess is that since the FOAK patch happens before the trainer prepares the model, the patching is performed on weights still residing on cpu, which subsequently causes problems when the patched functions reference module weights that were never placed on the gpu.
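
To generalize the pdb checks above, a small helper along these lines (illustrative only, not part of the repo) lists every parameter that is still resident on cpu, e.g. before and after the FOAK patching:

import torch

def cpu_resident_params(model: torch.nn.Module):
    """List (name, device) for every parameter that is not on a cuda device."""
    return [
        (name, param.device)
        for name, param in model.named_parameters()
        if param.device.type != "cuda"
    ]

# e.g. inspect the model right before and right after the FOAK patching:
# print(cpu_resident_params(model)[:10])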

My temporary workaround is to cast the attention module to X.device when X is passed in, as shown below. This is not the correct solution, but it avoids the error.

Temp Workaround

fused_ops/unsloth_lora/bnb/fast_lora.py

def apply_lora_qkv(self, X):
    self = self.to(X.device) # TEMPFIX: adding this will cast the module to device
    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
    Q, K, V = LoRA_QKV.apply(X,
        QW, QW_quant, QA, QB, QS,
        KW, KW_quant, KA, KB, KS,
        VW, VW_quant, VA, VB, VS,
    )
    return Q, K, V
pass

Testing Command

CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file scripts/benchmarks/accelerate.yaml \
    --num_processes=2 --main_process_port=29500 \
    -m tuning.sft_trainer \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --acceleration_framework_config_file sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml \
    --packing True --max_seq_len 4096 \
    --fp16 True --learning_rate 2e-4 --torch_dtype float16 \
    --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.0 \
    --target_modules q_proj k_proj v_proj o_proj \
    --use_flash_attn True \
    --response_template '\n### Response:' \
    --dataset_text_field 'output' \
    --include_tokens_per_second True \
    --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True \
    --evaluation_strategy no --save_strategy no \
    --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear \
    --logging_strategy steps --logging_steps 10 --max_steps 30 \
    --training_data_path benchmark_outputs/data/cache.json \
    --skip_memory_metrics True \
    --per_device_train_batch_size 4 \
    --output_dir benchmark_outputs/exp_1/hf

fabianlim avatar fabianlim commented on August 29, 2024

@achew010 are you sure this fixes the BNB case? I realized I was getting the exact same error in the GPTQ case. The reason is #26, where in low_mem mode we now do not move the whole model directly to GPU, and we also exclude the adapters from FSDP wrapping, which is why the adapters stayed on CPU.

So I fixed it in #29.

fabianlim avatar fabianlim commented on August 29, 2024

@achew010 Update: the root cause is not the lora weights staying on cpu. You can try the following:

  • Try with my fix in #29. In this case the lora weights will be on GPU, but the problem persists.
  • I tried your workaround and it works, but only because it moves the base layer weights to gpu.

We can see that after my fix in #29, the lora weights are on the GPU but the base layer weights are still on cpu:

[('q_proj.base_layer.weight', device(type='cpu')),
 ('q_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('q_proj.lora_B.default.weight', device(type='cuda', index=0)),
 ('k_proj.base_layer.weight', device(type='cpu')),
 ('k_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('k_proj.lora_B.default.weight', device(type='cuda', index=0)),
 ('v_proj.base_layer.weight', device(type='cpu')),
 ('v_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('v_proj.lora_B.default.weight', device(type='cuda', index=0)),
 ('o_proj.base_layer.weight', device(type='cpu')),
 ('o_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('o_proj.lora_B.default.weight', device(type='cuda', index=0))]
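
For reference, a listing like the one above can be produced with named_parameters on the attention module, e.g. (illustrative, assuming model is the PEFT-wrapped model from the pdb session above):

attn = model.get_base_model().model.layers[0].self_attn
print([(name, param.device) for name, param in attn.named_parameters()])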

I think the real issue is that BNB QLoRA does not work with FSDP low-memory mode, and we need to fix it at the root cause. I feel the workaround is dangerous because in FSDP the parameters are being sharded and unsharded, so putting a .to() in a forward function is not very safe.

achew010 avatar achew010 commented on August 29, 2024

Update

@fabianlim you are right, QLoRA doesn't work with FSDP and low-memory mode; the weights stay on cpu until the FSDP wrapping here. Similar to the issue in #29, the base layer wasn't cast because of the FSDP ignored_modules in the FSDP-FOAK workaround. I made a fix here; it is done before Trainer.train is called and will not interfere with FSDP sharding and unsharding at forward time. I'm encountering the same device errors with set_module_tensor_to_device, so I'm using the .cuda() method for now.
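
The general shape of such a fix is sketched below (illustrative only; the names ignored_modules and cast_ignored_modules_to_device are hypothetical, and the actual change lives in the linked fix): the modules excluded from FSDP wrapping are moved to the local GPU once, before Trainer.train is called, instead of being cast inside a forward pass.

import torch

def cast_ignored_modules_to_device(ignored_modules, device_index=None):
    """Move the FSDP-ignored modules (quantized base layers and LoRA adapters)
    to the local GPU once, before Trainer.train is called, rather than casting
    inside a forward pass where FSDP shards and unshards parameters."""
    if device_index is None:
        device_index = torch.cuda.current_device()
    for module in ignored_modules:
        # .cuda() is used here because set_module_tensor_to_device hit the
        # same device errors in this setup.
        module.cuda(device_index)

# hypothetical usage, before trainer.train():
# cast_ignored_modules_to_device(fsdp_ignored_modules)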

For GPTQ, the fix in #29 resolved the casting of the adapters in the ignored modules to cuda. The base layer is mapped to a meta device at model initialization (see in code), and the weights don't materialize until training.

I was thinking of loading the QLoRA weights on the meta device (reference) and maybe adopting the same way GPTQ reinitializes its meta tensors, but I haven't found out exactly how GPTQ does the re-initialization in its code yet.
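
For context, the generic meta-device pattern in PyTorch looks roughly like the sketch below (illustrative only; it does not show how GPTQ actually re-initializes its meta tensors): parameters are created on the meta device without allocating storage, and later materialized on the target device, after which real values still need to be loaded or re-initialized.

import torch
import torch.nn as nn

# Construct parameters on the meta device: shapes and dtypes exist, but no
# storage is allocated.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)
print(layer.weight.device)  # meta

# Materialize the (uninitialized) parameters on the GPU; weights must then be
# loaded from a checkpoint or re-initialized.
layer = layer.to_empty(device="cuda")
print(layer.weight.device)  # cuda:0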

fabianlim avatar fabianlim commented on August 29, 2024

This has been addressed by #31.
