Comments (7)
Also, it would be nice if we could post this sort of idea to 'Discussions' instead of 'Issues'! :)
Well, there is a pretty big problem with the discrete uniform distribution assumption: it causes the weights to be scaled far too much... So, without actually being able to measure anything, the next best assumption is a Zipf distribution:
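(A sketch of the reasoning behind the script below, under the assumption that the expert outputs e_i are roughly independent with equal expected norm and are combined using the sorted gate probabilities p_1 >= p_2 >= ... >= p_N:

$$\mathbb{E}\left[\Big\|\sum_{i=1}^{k} p_i\, e_i\Big\|^2\right] \propto \sum_{i=1}^{k} p_i^2 \quad\Longrightarrow\quad \alpha = \sqrt{\frac{\sum_{i=1}^{n} p_i^2}{\sum_{i=1}^{m} p_i^2}}$$

so alpha rescales the m-expert sum back to the expected norm of the stock n-expert sum.)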
import numpy as np

def zipf_distribution(N, s):
    """Generate Zipf distribution for N experts with parameter s."""
    ranks = np.arange(1, N + 1)
    weights = 1 / np.power(ranks, s)
    normalization = np.sum(weights)
    probabilities = weights / normalization
    return probabilities

def expected_norm_squared(probabilities, num_experts):
    """Calculate the expected norm squared for a subset of experts."""
    return np.sum(probabilities[:num_experts]**2)

def calculate_scaling_factor(N, n, m, s):
    """Calculate the scaling factor alpha for given N, n, m, and s."""
    probabilities = zipf_distribution(N, s)
    norm_squared_n = expected_norm_squared(probabilities, n)
    norm_squared_m = expected_norm_squared(probabilities, m)
    alpha = np.sqrt(norm_squared_n / norm_squared_m)
    return alpha

N = 8  # num_local_experts
n = 2  # num_experts_per_tok
s = 0  # Skew parameter (0 = Uniform, 0.5 = Square-Root, 1 = Zipf's law)

# Print the Zipf distribution for the given s
probabilities = zipf_distribution(N, s)
print(f"Zipf distribution for s = {s}: {[f'{p:.4f}' for p in probabilities]}")

# Loop over all values of m from 1 to N
for m in range(1, N + 1):
    alpha = calculate_scaling_factor(N, n, m, s)
    print(f"For m = {m}, Scaling factor alpha: {alpha:.4f}")
Zipf distribution for s = 0: ['0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250']
For m = 1, Scaling factor alpha: 1.4142
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.8165
For m = 4, Scaling factor alpha: 0.7071
For m = 5, Scaling factor alpha: 0.6325
For m = 6, Scaling factor alpha: 0.5774
For m = 7, Scaling factor alpha: 0.5345
For m = 8, Scaling factor alpha: 0.5000

Zipf distribution for s = 0.5: ['0.2288', '0.1618', '0.1321', '0.1144', '0.1023', '0.0934', '0.0865', '0.0809']
For m = 1, Scaling factor alpha: 1.2247
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9045
For m = 4, Scaling factor alpha: 0.8485
For m = 5, Scaling factor alpha: 0.8105
For m = 6, Scaling factor alpha: 0.7825
For m = 7, Scaling factor alpha: 0.7606
For m = 8, Scaling factor alpha: 0.7429

Zipf distribution for s = 1: ['0.3679', '0.1840', '0.1226', '0.0920', '0.0736', '0.0613', '0.0526', '0.0460']
For m = 1, Scaling factor alpha: 1.1180
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9583
For m = 4, Scaling factor alpha: 0.9370
For m = 5, Scaling factor alpha: 0.9241
For m = 6, Scaling factor alpha: 0.9155
For m = 7, Scaling factor alpha: 0.9093
For m = 8, Scaling factor alpha: 0.9046
I'll see if I can run a grid-search overnight.
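A minimal sketch of what that grid search could look like, driving the same mergekit-yaml / convert.py / perplexity toolchain as in the YAML below via subprocess. The m=3 alpha values are taken from the tables above; the paths and loop structure are my assumptions about the setup, not the exact script used:

import subprocess
from pathlib import Path

# Hypothetical overnight grid-search driver (a sketch, not the actual script).
LLAMA = Path.home() / "LLMs/llama.cpp"
ALPHAS = {"0": 0.8165, "0.5": 0.9045, "1": 0.9583, "2": 0.9942}  # m=3 values from above

YAML = """\
const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1
const_tag: &RESIDUAL_SCALE_FACTOR {alpha}
models:
  - model: *MODEL
    parameters:
      scale:
        - filter: w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0
dtype: bfloat16
merge_method: passthrough
"""

for s, alpha in ALPHAS.items():
    Path("mixtral-scaled.yaml").write_text(YAML.format(alpha=alpha))
    subprocess.run(["mergekit-yaml", "--cuda", "mixtral-scaled.yaml", "mixtral-scaled-m"], check=True)
    # NOTE: still need to set `num_experts_per_tok: 3` in mixtral-scaled-m/config.json here!
    gguf = f"mixtral-scaled-s{s}.gguf"
    subprocess.run(["python", str(LLAMA / "convert.py"), "mixtral-scaled-m",
                    "--outfile", gguf, "--outtype", "q8_0"], check=True)
    subprocess.run([str(LLAMA / "build/bin/perplexity"), "-m", gguf,
                    "-f", str(Path.home() / "LLMs/misc/datasets/wikitext-2-raw/wiki.test.raw"),
                    "-ngl", "1000"], check=True)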
Here's the yaml file if anybody is interested:
# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw/wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

#const_tag: &RESIDUAL_SCALE_FACTOR 1.1180  # [s=0 --> 7.2995]
#const_tag: &RESIDUAL_SCALE_FACTOR 1.0     # 4.4103 +/- 0.02355
const_tag: &RESIDUAL_SCALE_FACTOR 0.9583   # [s=0 --> 4.6758]

# The `down_proj` of each MLP expert seems to be held in the `w2.weight` tensor for Mixtral:
# > current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
# > current_hidden_states = self.w2(current_hidden_states)

models:
  - model: *MODEL
    parameters:
      scale:
        - filter: w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough
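(As an aside, the same scaling can be applied without mergekit in a few lines of transformers code. A sketch, assuming the stock HF Mixtral module layout quoted in the YAML comments above:

import torch
from transformers import AutoModelForCausalLM

ALPHA = 0.9583  # s=1, m=3 value from the table above

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1", torch_dtype=torch.bfloat16
)
with torch.no_grad():
    # Scale every expert's down-projection (`w2`), i.e. the tensor that
    # writes back into the residual stream.
    for layer in model.model.layers:
        for expert in layer.block_sparse_moe.experts:
            expert.w2.weight.mul_(ALPHA)

model.config.num_experts_per_tok = 3  # per the reminder in the YAML above
model.save_pretrained("mixtral-scaled-m")
)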
This isn't doing much that's useful... For 3 experts:
s=0   --> PPL = 4.6758
s=0.5 --> PPL = 4.5406
s=1   --> PPL = 4.4835
s=2   --> PPL = 4.4546
and s=2 is almost the same as applying no scaling at all:
Zipf distribution for s = 2: ['0.6547', '0.1637', '0.0727', '0.0409', '0.0262', '0.0182', '0.0134', '0.0102']
For m = 1, Scaling factor alpha: 1.0308
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9942
For m = 4, Scaling factor alpha: 0.9924
For m = 5, Scaling factor alpha: 0.9917
For m = 6, Scaling factor alpha: 0.9913
For m = 7, Scaling factor alpha: 0.9912
For m = 8, Scaling factor alpha: 0.9910
vs 2 experts & stock settings --> PPL = 4.4103 +/- 0.02355
It still may be useful to try setting residual_scale for the mergekit-moe merges, as they are likely to be much more correlated and less likely to mess up the early embedding transformation layers...
So next I'm going to try to attenuate the MoE-routing softmax-gate's distribution:
# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw/wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

const_tag: &QK_ATTENUATION_FACTOR 1.0    # NOTE: The scaling effect is QK_ATTENUATION_FACTOR^2 because of the dot-product!!!
const_tag: &GATE_ATTENUATION_FACTOR 0.9  # NOTE: Setting this < 1 will attenuate the MoE-routing softmax-gate's distribution.
const_tag: &RESIDUAL_SCALE_FACTOR 1.0    # NOTE: Attempt to rescale the residual stream when we change `num_experts_per_tok`.

models:
  - model: *MODEL
    parameters:
      scale:
        - filter: q_proj.weight
          value: *QK_ATTENUATION_FACTOR
        - filter: k_proj.weight
          value: *QK_ATTENUATION_FACTOR
        - filter: block_sparse_moe.gate.weight
          value: *GATE_ATTENUATION_FACTOR
        - filter: experts.w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough
and then the attention score matrix, like we did for the frankenmerges:
- Attenuating the routing softmax will have a similar effect to what was hoped for above.
- Attenuating (or sharpening!) the score matrix might be beneficial when there are more experts working (see the sketch below).
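To make the two knobs concrete, a tiny numpy sketch (made-up logits, purely for illustration): scaling the gate logits by a factor below 1 flattens the routing softmax, above 1 sharpens it, and scaling both q_proj and k_proj by c multiplies the pre-softmax attention scores by c^2:

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Made-up gate logits for illustration.
logits = np.array([2.0, 1.0, 0.5, 0.1])
for factor in (0.9, 1.0, 1.2):
    print(f"factor={factor}: {np.round(softmax(factor * logits), 3)}")

# Attention: q_proj and k_proj are BOTH scaled, so the pre-softmax score
# q.k is multiplied by factor^2 -- hence the NOTE in the YAML above.
rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)
c = 0.95
print(np.dot(c * q, c * k) / np.dot(q, k))  # ~0.9025 == c**2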
Not really worth bothering with, I think. At best we're just going to get something about the same but slower to run:
# 2 experts & stock settings : PPL = 4.4103 +/- 0.02355
# 3 experts
# QK_ATTENUATION_FACTOR 1.10 : PPL = 4.5309 +/- 0.02444
# QK_ATTENUATION_FACTOR 1.05 : PPL = 4.4808 +/- 0.02415
# QK_ATTENUATION_FACTOR 0.95 : PPL = 4.4471 +/- 0.02401
# QK_ATTENUATION_FACTOR 0.90 : PPL = 4.4858 +/- 0.02431
# GATE_ATTENUATION_FACTOR 1.50 : PPL = 4.5641 +/- 0.02446
# GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4235 +/- 0.02377
# GATE_ATTENUATION_FACTOR 1.10 : PPL = 4.4329 +/- 0.02385
# GATE_ATTENUATION_FACTOR 0.98 : PPL = 4.4561 +/- 0.02404
# GATE_ATTENUATION_FACTOR 0.95 : PPL = 4.4639 +/- 0.02410
# GATE_ATTENUATION_FACTOR 0.90 : PPL = 4.4807 +/- 0.02422
# GATE_ATTENUATION_FACTOR 0.80 : PPL = 4.5236 +/- 0.02454
# QK_ATTENUATION_FACTOR 0.95 & GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4218 +/- 0.02380
# 4 experts
# GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4539 +/- 0.02402
Maybe this idea does have some use after all. If we can scale the gate weight tensor with n=8 to work as closely as possible to n=2, then very low-bit quantized models using 2-3 bpw might actually work better and see their perplexity grow more slowly (due to more active weights cancelling out more of the noise caused by quantization).
This assumes that the optimal scale factor doesn't just approximate the hard n=2 thresholding with a soft n=8 version that barely uses the other 6 sets of MLP weights (i.e. doesn't shift the lower-valued logits so far down that the Gumbel error distributions effectively head towards -inf and contribute almost nothing to the gated sum...).
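That caveat is easy to see with a small numpy sketch (made-up router logits): as the gate sharpening factor grows, the soft n=8 weights collapse onto the renormalized hard top-2 ones, and the other six experts contribute (and hence denoise) almost nothing:

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=8)  # made-up router logits, for illustration only

# Hard top-2 routing (stock Mixtral): keep the two largest gate values
# and renormalize them; the other six experts get exactly zero weight.
top2 = np.argsort(logits)[-2:]
hard = np.zeros(8)
hard[top2] = softmax(logits[top2])

# Soft n=8 routing with a sharpened gate: larger factors approach top-2.
for factor in (1.0, 2.0, 4.0, 8.0):
    soft = softmax(factor * logits)
    print(f"factor={factor}: L1 gap to hard top-2 = {np.abs(soft - hard).sum():.3f}")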