Comments (7)
Hi!
Full fine-tuning won't work as the model is quantized, but you could try fine-tuning the model using various PEFT techniques which work with quantized base models. Check out QLoRA for example.
Hope this is helpful.
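To make the suggestion a bit more concrete: the idea behind LoRA-style PEFT on a quantized model is that the base weights stay frozen and only small adapter matrices receive gradients. Below is a minimal sketch in plain PyTorch, using an ordinary `nn.Linear` as a stand-in for a quantized layer. The `LowRankAdapter` name is made up for illustration and is not a peft API.

```python
import torch
import torch.nn as nn


class LowRankAdapter(nn.Module):
    """Hypothetical wrapper: frozen base layer plus a trainable low-rank update."""

    def __init__(self, base: nn.Module, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.base = base
        # Freeze the base layer; only the two small matrices below are trained.
        for p in self.base.parameters():
            p.requires_grad_(False)
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # the update starts at zero

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.up(self.down(x))


base = nn.Linear(64, 64, bias=False)  # stand-in for a quantized layer
layer = LowRankAdapter(base, 64, 64, rank=4)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
print(f"trainable={trainable}, frozen={frozen}")
```

Only the adapter parameters would be handed to the optimizer; the frozen base never needs gradients, which is why this family of methods can sit on top of quantized weights.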
from mixtral-offloading.
@dvmazur, is there a link to an implementation of this? If you have done something similar, please share it; that would be helpful!
@asmith26 Did you find a method that works?
The structure of the loaded model is:
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): HQQLinearTritonSavable()
          (k_proj): HQQLinearTritonSavable()
          (v_proj): HQQLinearTritonSavable()
          (o_proj): HQQLinearTritonSavable()
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): SparseMoeWrapper(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
When I try to train with:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
I get an error showing that peft does not support HQQLinearTritonSavable:
ValueError                                Traceback (most recent call last)
in <cell line: 12>()
     10 )
     11
---> 12 model = get_peft_model(model, config)
     13 print_trainable_parameters(model)

7 frames
/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py in _create_new_module(lora_config, adapter_name, target, **kwargs)
    255         if new_module is None:
    256             # no module could be matched
--> 257             raise ValueError(
    258                 f"Target module {target} is not supported. Currently, only the following modules are supported: "
    259                 "torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D."

ValueError: Target module HQQLinearTritonSavable() is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D.
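The error comes from peft matching each targeted layer's class against a fixed list of supported types. One way to see in advance which targeted layers would be rejected is to list their classes. The sketch below runs on a toy model: `FakeQuantLinear`, `ToyAttention` and `unsupported_targets` are hypothetical names (the first stands in for `HQQLinearTritonSavable`), and the `SUPPORTED` tuple only mirrors the torch classes named in the error message.

```python
import torch.nn as nn

# Types named in the error message (transformers' Conv1D left out to avoid the dependency).
SUPPORTED = (nn.Linear, nn.Embedding, nn.Conv2d)


class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in for a custom quantized layer such as HQQLinearTritonSavable."""

    def forward(self, x):
        return x


class ToyAttention(nn.Module):
    """Hypothetical attention block with one custom and one standard projection."""

    def __init__(self):
        super().__init__()
        self.q_proj = FakeQuantLinear()
        self.k_proj = nn.Linear(8, 8, bias=False)


def unsupported_targets(model: nn.Module, target_names):
    """Return (qualified name, class name) for targeted modules of an unsupported type."""
    found = []
    for name, module in model.named_modules():
        if name.split(".")[-1] in target_names and not isinstance(module, SUPPORTED):
            found.append((name, type(module).__name__))
    return found


toy = nn.Sequential(ToyAttention())
print(unsupported_targets(toy, ["q_proj", "k_proj"]))
```

On the real model, every attention projection would show up in this list, which matches the traceback.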
Hey, @nmarafo and @complete-dope!
It looks like using huggingface's peft for fine-tuning the offloaded model is a bit tricky (due to custom layers mostly), but I haven't looked into it myself.
A LoRA fine-tuning setup similar to the original paper can be hacked together quite simply:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

class LoRALayer(nn.Module):
    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module
        # adapter_A: (in_features, rank), randomly initialized
        self.adapter_A = nn.Parameter(torch.empty(module.in_features, rank, device=module.weight.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        # adapter_B: (rank, out_features), zero-initialized so the residual starts at zero
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=module.weight.device))

    def forward(self, input):
        bottleneck = F.linear(input, self.adapter_A.T)
        residual = F.linear(bottleneck, self.adapter_B.T)
        return self.module(input) + residual

def custom_get_peft_model(model, rank):
    for _, module in model.named_modules():
        if not isinstance(module, MixtralAttention):
            continue
        module.q_proj = LoRALayer(module.q_proj, rank)
        # TODO: {k, v, o}_proj
    return model
Note that this example only applies LoRA to the attention parameters. Doing the same for the expert layers is trickier, as it might break the ExpertCache (I haven't looked into that myself yet).
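One way the remaining `{k, v, o}_proj` TODO could be filled in is to loop over the projection names and then hand only the adapter parameters to the optimizer. This is a sketch under assumptions: `ToyAttention` stands in for `MixtralAttention` with plain `nn.Linear` projections, and `wrap_attention_projections` is a made-up helper name. The expert layers are left untouched.

```python
import torch
import torch.nn as nn

PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj")


class LoRALayer(nn.Module):
    """Frozen module plus a zero-initialized low-rank residual."""

    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module
        device = module.weight.device
        self.adapter_A = nn.Parameter(torch.empty(module.in_features, rank, device=device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=device))

    def forward(self, input):
        return self.module(input) + input @ self.adapter_A @ self.adapter_B


class ToyAttention(nn.Module):
    """Hypothetical stand-in for MixtralAttention (projections only)."""

    def __init__(self, hidden: int = 16, kv: int = 4):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, kv, bias=False)
        self.v_proj = nn.Linear(hidden, kv, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)


def wrap_attention_projections(model: nn.Module, attention_cls, rank: int):
    # Freeze everything first, so only the adapters created below are trainable.
    for p in model.parameters():
        p.requires_grad_(False)
    # Materialize the module list before mutating the module tree.
    for module in list(model.modules()):
        if isinstance(module, attention_cls):
            for name in PROJ_NAMES:
                setattr(module, name, LoRALayer(getattr(module, name), rank))
    return [p for p in model.parameters() if p.requires_grad]


model = nn.ModuleList([ToyAttention(), ToyAttention()])
adapter_params = wrap_attention_projections(model, ToyAttention, rank=2)
optimizer = torch.optim.AdamW(adapter_params, lr=1e-4)
print(len(adapter_params))
```

Freezing before wrapping keeps the optimizer state small, since it only ever sees `adapter_A` and `adapter_B` tensors.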
Thank you very much for the answer.
Sorry for my inexperience, I'm trying to implement it like this:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

class LoRALayer(nn.Module):
    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module
        self.adapter_A = nn.Parameter(torch.empty(module.in_features, rank, device=module.weight.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=module.weight.device))

    def forward(self, input):
        bottleneck = F.linear(input, self.adapter_A.T)
        residual = F.linear(bottleneck, self.adapter_B.T)
        return self.module(input) + residual

def custom_get_peft_model(model, rank):
    for _, module in model.named_modules():
        if not isinstance(module, MixtralAttention):
            continue
        module.q_proj = LoRALayer(module.q_proj, rank)
        # TODO: {k, v, o}_proj
    return model

model = custom_get_peft_model(model, rank=8)
and I get this error:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1693             if name in modules:
   1694                 return modules[name]
-> 1695         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   1696
   1697     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'HQQLinearTritonSavable' object has no attribute 'in_features'
Perhaps this solves it:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from src.custom_layers import HQQLinearTritonSavable

class LoRALayer(nn.Module):
    def __init__(self, module: HQQLinearTritonSavable, rank: int):
        super().__init__()
        self.module = module
        # The quantized layer has no in_features/out_features, so read them from its metadata
        in_features = module.meta['shape'][1]
        out_features = module.meta['shape'][0]
        self.adapter_A = nn.Parameter(torch.empty(in_features, rank, device=module.W_q.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, out_features, device=module.W_q.device))

    def forward(self, input):
        bottleneck = F.linear(input, self.adapter_A.T)
        residual = F.linear(bottleneck, self.adapter_B.T)
        return self.module(input) + residual

def custom_get_peft_model(model, rank):
    for _, module in model.named_modules():
        if not isinstance(module, MixtralAttention):
            continue
        module.q_proj = LoRALayer(module.q_proj, rank)
        # TODO: {k, v, o}_proj
    return model

model = custom_get_peft_model(model, rank=8)
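Whether `meta['shape']` really holds `(out_features, in_features)` is an assumption here. A cheap way to test such a guess is to push a dummy batch through the wrapped layer: because `adapter_B` starts at zero, the output must match the unwrapped layer exactly, and a wrong shape fails immediately with a matmul error. The sketch uses `FakeQuantLinear`, a hypothetical stand-in that mimics only the `meta['shape']` and `W_q` attributes; it is not the real `HQQLinearTritonSavable`.

```python
import torch
import torch.nn as nn


class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in exposing meta['shape'] and W_q like the quantized layer."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.meta = {"shape": (out_features, in_features)}  # assumed layout
        self.W_q = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)

    def forward(self, x):
        return x @ self.W_q.T


class LoRALayer(nn.Module):
    def __init__(self, module: nn.Module, rank: int):
        super().__init__()
        self.module = module
        out_features, in_features = module.meta["shape"]
        self.adapter_A = nn.Parameter(torch.empty(in_features, rank, device=module.W_q.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, out_features, device=module.W_q.device))

    def forward(self, input):
        return self.module(input) + input @ self.adapter_A @ self.adapter_B


base = FakeQuantLinear(in_features=32, out_features=8)
wrapped = LoRALayer(base, rank=4)
x = torch.randn(5, 32)

# With adapter_B == 0 the wrapper must reproduce the base output exactly.
assert torch.equal(wrapped(x), base(x))
print(wrapped(x).shape)
```

Running the same check on one real attention projection, with an input of width `hidden_size`, would confirm or refute the shape guess before any training starts.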
I'm not sure whether (module.meta['shape'][1], module.meta['shape'][0]) is the correct shape. Maybe you should try pulling the correct shape from the original model's config:
from transformers import AutoConfig
config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
head_dim = config.hidden_size // config.num_attention_heads
# (in_features, out_features)
q_proj_shape = (config.hidden_size, config.num_attention_heads * head_dim)
k_proj_shape = (config.hidden_size, config.num_key_value_heads * head_dim)
v_proj_shape = (config.hidden_size, config.num_key_value_heads * head_dim)
o_proj_shape = (config.num_attention_heads * head_dim, config.hidden_size)
I haven't verified these shapes, but they should be correct.
If this snippet doesn't work, you could try reconstructing the original shapes from here.
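For reference, here is how the arithmetic works out when the config holds `hidden_size=4096`, `num_attention_heads=32` and `num_key_value_heads=8`. These values are assumptions to check against the downloaded config, although the 4096 matches the embedding width in the printed model. `AttnDims` and `projection_shapes` are hypothetical names; the dataclass is only a local stand-in so the sketch runs without downloading anything.

```python
from dataclasses import dataclass


@dataclass
class AttnDims:
    """Local stand-in for the relevant config fields (values are assumptions)."""
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 8


def projection_shapes(cfg):
    """Return {name: (in_features, out_features)} for the attention projections."""
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    q_out = cfg.num_attention_heads * head_dim
    kv_out = cfg.num_key_value_heads * head_dim
    return {
        "q_proj": (cfg.hidden_size, q_out),
        "k_proj": (cfg.hidden_size, kv_out),
        "v_proj": (cfg.hidden_size, kv_out),
        "o_proj": (q_out, cfg.hidden_size),
    }


for name, shape in projection_shapes(AttnDims()).items():
    print(name, shape)
```

The function only reads those three attributes, so it should also accept the object returned by `AutoConfig.from_pretrained`. With these values, `k_proj` and `v_proj` come out narrower than `q_proj`, so the adapter shapes cannot simply be copied from one projection to the next.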