Comments (3)
We couldn't get it to run ourselves either; it gets stuck at the backward pass. Perhaps I can push a version to my branch, and you can continue modifying it and open a PR.
I don't think I have that permission. I ended up solving the LoRA problem by writing some code myself. It only implements the most basic version, but I have tested that it works and it meets my own needs. Full LoRA fine-tuning support will still need to come from the official team.
import logging
import math

import torch.nn as nn


class LoRABlock(nn.Module):
    """A simple implementation of LoRA for a single nn.Linear layer."""

    def __init__(self, linear: nn.Linear, rank: int) -> None:
        super().__init__()
        assert isinstance(linear, nn.Linear), "LoRA only supports Linear modules!"
        linear_dtype = linear.weight.dtype
        input_dim = linear.weight.shape[-1]
        out_dim = linear.weight.shape[0]
        # Frozen copy of the original layer; only create a bias if one existed.
        self.original_linear = nn.Linear(
            input_dim, out_dim, bias=linear.bias is not None, dtype=linear_dtype
        )
        self.original_linear.weight.data = linear.weight.data.clone().detach()
        if linear.bias is not None:
            self.original_linear.bias.data = linear.bias.data.clone().detach()
        self.original_linear.requires_grad_(False)
        # LoRA only saves parameters while rank * (input_dim + out_dim) is
        # smaller than input_dim * out_dim; halve the rank until it fits.
        rank_upper_bound = (input_dim * out_dim) / (input_dim + out_dim + 1)
        while rank > rank_upper_bound:
            logging.warning(
                "Preset rank (%d) was too high, degrading to %d", rank, rank // 2
            )
            rank = rank // 2
        if rank == 0:
            raise ValueError(
                "rank_upper_bound error: current value: {}.\n"
                "The cause of this issue is: input_dim: {}, out_dim: {}".format(
                    rank_upper_bound, input_dim, out_dim
                )
            )
        # Down-projection B gets a random float init; up-projection A starts at
        # zero so the block is an exact no-op before training.  Note that
        # Tensor.random_() fills with random *integers*, so a proper float
        # initializer is used here instead.
        self.B = nn.Linear(input_dim, rank, bias=False, dtype=linear_dtype)
        nn.init.kaiming_uniform_(self.B.weight, a=math.sqrt(5))
        self.A = nn.Linear(rank, out_dim, bias=False, dtype=linear_dtype)
        self.A.weight.data.zero_()
        # Expose the frozen weight/bias so code that inspects .weight still works.
        self.weight = self.original_linear.weight
        self.bias = self.original_linear.bias

    def forward(self, x):
        origin_output = self.original_linear(x)
        lora_modification = self.A(self.B(x))
        return origin_output + lora_modification
def substitute_model_with_lora(model: nn.Module, rank: int = 32):
    """Recursively replace every nn.Linear in a PyTorch module with a LoRABlock."""
    # named_children() only yields registered sub-modules, so there is no need
    # to filter out methods and properties as a dir()-based scan would require.
    for name, child in list(model.named_children()):
        if name == "base_model":
            # Skip a wrapped base model to avoid descending into it twice.
            continue
        if isinstance(child, nn.Linear):
            setattr(model, name, LoRABlock(child, rank))
        else:
            # Covers plain sub-modules and containers such as nn.ModuleList,
            # whose children are registered under their string indices.
            substitute_model_with_lora(child, rank)
    return model
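One property of the block above worth checking: because the up-projection A is zero-initialized, the wrapped layer must produce exactly the same output as the original layer before any training. A minimal, self-contained sketch of that check (the layer sizes here are arbitrary, chosen only for illustration):

```python
import math

import torch
import torch.nn as nn

# Stand-in for the pieces of LoRABlock: a frozen base layer plus the
# B (down) and A (up) projections, initialized as in the class above.
torch.manual_seed(0)
linear = nn.Linear(16, 8)
B = nn.Linear(16, 4, bias=False)
nn.init.kaiming_uniform_(B.weight, a=math.sqrt(5))
A = nn.Linear(4, 8, bias=False)
A.weight.data.zero_()

x = torch.randn(3, 16)
# A's weight is all zeros, so A(B(x)) contributes nothing at init.
assert torch.allclose(linear(x) + A(B(x)), linear(x))
print("zero-init check passed")
```

The same reasoning explains why B must not also be zero: with both projections at zero, the gradient with respect to B would vanish and the adapter could never leave its initial state.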
Have you tried using the peft library provided by huggingface? If so, any issues with it?