
Comments (7)

jukofyork commented on June 26, 2024

Also, it would be nice if we could post these sorts of ideas to 'Discussions' instead of 'Issues'! :)

jukofyork commented on June 26, 2024

Well, there is a pretty big problem with the discrete uniform distribution assumption: it's causing the weights to be scaled far too much... So, without actually being able to measure anything, the next best assumption is a Zipf distribution:

import numpy as np

def zipf_distribution(N, s):
    """Generate Zipf distribution for N experts with parameter s."""
    ranks = np.arange(1, N+1)
    weights = 1 / np.power(ranks, s)
    normalization = np.sum(weights)
    probabilities = weights / normalization
    return probabilities

def expected_norm_squared(probabilities, num_experts):
    """Calculate the expected norm squared for a subset of experts."""
    return np.sum(probabilities[:num_experts]**2)

def calculate_scaling_factor(N, n, m, s):
    """Calculate the scaling factor alpha for given N, n, m, and s."""
    probabilities = zipf_distribution(N, s)
    norm_squared_n = expected_norm_squared(probabilities, n)
    norm_squared_m = expected_norm_squared(probabilities, m)
    alpha = np.sqrt(norm_squared_n / norm_squared_m)
    return alpha

N = 8  # num_local_experts
n = 2  # num_experts_per_tok

# Loop over the skew parameter s (0 = Uniform, 0.5 = Square-Root, 1 = Zipf's law)
for s in [0, 0.5, 1]:
    # Print the Zipf distribution for the given s
    probabilities = zipf_distribution(N, s)
    print(f"Zipf distribution for s = {s}: {[f'{p:.4f}' for p in probabilities]}")

    # Loop over all values of m from 1 to N
    for m in range(1, N+1):
        alpha = calculate_scaling_factor(N, n, m, s)
        print(f"For m = {m}, Scaling factor alpha: {alpha:.4f}")

which prints:
Zipf distribution for s = 0: ['0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250', '0.1250']
For m = 1, Scaling factor alpha: 1.4142
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.8165
For m = 4, Scaling factor alpha: 0.7071
For m = 5, Scaling factor alpha: 0.6325
For m = 6, Scaling factor alpha: 0.5774
For m = 7, Scaling factor alpha: 0.5345
For m = 8, Scaling factor alpha: 0.5000
Zipf distribution for s = 0.5: ['0.2288', '0.1618', '0.1321', '0.1144', '0.1023', '0.0934', '0.0865', '0.0809']
For m = 1, Scaling factor alpha: 1.2247
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9045
For m = 4, Scaling factor alpha: 0.8485
For m = 5, Scaling factor alpha: 0.8105
For m = 6, Scaling factor alpha: 0.7825
For m = 7, Scaling factor alpha: 0.7606
For m = 8, Scaling factor alpha: 0.7429
Zipf distribution for s = 1: ['0.3679', '0.1840', '0.1226', '0.0920', '0.0736', '0.0613', '0.0526', '0.0460']
For m = 1, Scaling factor alpha: 1.1180
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9583
For m = 4, Scaling factor alpha: 0.9370
For m = 5, Scaling factor alpha: 0.9241
For m = 6, Scaling factor alpha: 0.9155
For m = 7, Scaling factor alpha: 0.9093
For m = 8, Scaling factor alpha: 0.9046
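
In case it isn't clear from the script, the scaling factor it computes is (in my notation)

$$\alpha(m) = \sqrt{\frac{\sum_{i=1}^{n} p_i^2}{\sum_{i=1}^{m} p_i^2}}$$

i.e. the square root of the ratio of the "expected norm squared" of the top-n routing weights to that of the top-m routing weights. For the uniform case (s = 0) with n = 2 this reduces to sqrt(2/m), which is where the 1.4142 (m = 1) and 0.5000 (m = 8) values above come from.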

I'll see if I can run a grid-search overnight.

jukofyork commented on June 26, 2024

Here's the yaml file if anybody is interested:

# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw/wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

#const_tag: &RESIDUAL_SCALE_FACTOR 1.1180  # [s=0 --> 7.2995]
#const_tag: &RESIDUAL_SCALE_FACTOR 1.0     # 4.4103 +/- 0.02355
const_tag: &RESIDUAL_SCALE_FACTOR 0.9583  # [s=0 --> 4.6758]

# The `down_proj` of each MLP expert seems to be held in the `w2.weight` tensor for Mixtral:
# > current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
# > current_hidden_states = self.w2(current_hidden_states)
models:
  - model: *MODEL
    parameters:
      scale:
        - filter: w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough
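
To cover the num_experts_per_tok reminder in the banner comment, here's a minimal sketch of the config.json edit, assuming the merge was written to ./mixtral-scaled-m as in the commands at the top:

import json

# Assumes the merged model lives in ./mixtral-scaled-m (see the mergekit-yaml command above).
config_path = "mixtral-scaled-m/config.json"

with open(config_path) as f:
    config = json.load(f)

# e.g. run with 3 active experts instead of the stock 2
config["num_experts_per_tok"] = 3

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)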

jukofyork commented on June 26, 2024

This isn't doing anything very useful... For 3 experts:

s=0 --> PPL = 4.6758
s=0.5 --> PPL = 4.5406
s=1 --> PPL = 4.4835
s=2 --> PPL = 4.4546

and s=2 gives scaling factors so close to 1 that it's almost the same as not scaling at all:

Zipf distribution for s = 2: ['0.6547', '0.1637', '0.0727', '0.0409', '0.0262', '0.0182', '0.0134', '0.0102']
For m = 1, Scaling factor alpha: 1.0308
For m = 2, Scaling factor alpha: 1.0000
For m = 3, Scaling factor alpha: 0.9942
For m = 4, Scaling factor alpha: 0.9924
For m = 5, Scaling factor alpha: 0.9917
For m = 6, Scaling factor alpha: 0.9913
For m = 7, Scaling factor alpha: 0.9912
For m = 8, Scaling factor alpha: 0.9910

vs 2 experts & stock settings --> PPL = 4.4103 +/- 0.02355

It may still be useful to try setting residual_scale for the mergekit-moe merges, as their experts are likely to be much more correlated and less likely to mess up the early embedding-transformation layers...

jukofyork commented on June 26, 2024

So next I'm going to try to attenuate the MOE-routing softmax-gate's distribution:

# mergekit-yaml --verbose --cuda mixtral-scaled.yaml mixtral-scaled-m
# ~/LLMs/llama.cpp/convert.py mixtral-scaled-m --outfile mixtral-scaled-m.gguf --outtype q8_0
# ~/LLMs/llama.cpp/build/bin/perplexity -m mixtral-scaled-m.gguf -f ~/LLMs/misc/datasets/wikitext-2-raw/wiki.test.raw -ngl 1000

const_tag: &MODEL Mixtral-8x7B-Instruct-v0.1

############################################################################
# Don't forget to also set `num_experts_per_tok` value in `config.json`!!! #
############################################################################

const_tag: &QK_ATTENUATION_FACTOR 1.0    # NOTE: The scaling effect is QK_ATTENUATION_FACTOR^2 because of the dot-product!!!
const_tag: &GATE_ATTENUATION_FACTOR 0.9  # NOTE: Setting this < 1 will attenuate the MOE-routing softmax-gate's distribution.
const_tag: &RESIDUAL_SCALE_FACTOR 1.0    # NOTE: Attempt to rescale the residual stream when we change `num_experts_per_tok`.

models:
  - model: *MODEL
    parameters:
      scale:
        - filter: q_proj.weight
          value: *QK_ATTENUATION_FACTOR
        - filter: k_proj.weight
          value: *QK_ATTENUATION_FACTOR
        - filter: block_sparse_moe.gate.weight
          value: *GATE_ATTENUATION_FACTOR
        - filter: experts.w2.weight
          value: *RESIDUAL_SCALE_FACTOR
        - value: 1.0

dtype: bfloat16
merge_method: passthrough

and then the score matrix, like we did for the frankenmerges:

  • Attenuating the routing softmax will have a similar effect to what was hoped above.
  • Attenuating (or sharpening!) the score matrix might be beneficial when there are more experts working.
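
To make the gate attenuation concrete: the router logits are just the hidden state multiplied by gate.weight, so scaling block_sparse_moe.gate.weight by GATE_ATTENUATION_FACTOR scales the logits, which acts like a softmax temperature. A toy sketch with made-up logits (not taken from the model):

import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

# Made-up router logits for 8 experts, purely for illustration.
logits = np.array([2.0, 1.5, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0])

for factor in [0.8, 0.9, 1.0, 1.1, 1.2]:
    probs = softmax(factor * logits)
    print(f"factor = {factor}: {[f'{p:.3f}' for p in probs]}")

With a factor below 1 the distribution over experts flattens out (the extra experts get more weight); above 1 it sharpens back towards the hard top-k behaviour.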

jukofyork commented on June 26, 2024

Not really worth bothering with, I think. At best we're just going to get something about the same but slower to run:

# 2 experts & stock settings   : PPL = 4.4103 +/- 0.02355

# 3 experts
# QK_ATTENUATION_FACTOR 1.10   : PPL = 4.5309 +/- 0.02444
# QK_ATTENUATION_FACTOR 1.05   : PPL = 4.4808 +/- 0.02415
# QK_ATTENUATION_FACTOR 0.95   : PPL = 4.4471 +/- 0.02401
# QK_ATTENUATION_FACTOR 0.90   : PPL = 4.4858 +/- 0.02431
# GATE_ATTENUATION_FACTOR 1.50 : PPL = 4.5641 +/- 0.02446
# GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4235 +/- 0.02377
# GATE_ATTENUATION_FACTOR 1.10 : PPL = 4.4329 +/- 0.02385
# GATE_ATTENUATION_FACTOR 0.98 : PPL = 4.4561 +/- 0.02404
# GATE_ATTENUATION_FACTOR 0.95 : PPL = 4.4639 +/- 0.02410
# GATE_ATTENUATION_FACTOR 0.90 : PPL = 4.4807 +/- 0.02422
# GATE_ATTENUATION_FACTOR 0.80 : PPL = 4.5236 +/- 0.02454
# QK_ATTENUATION_FACTOR 0.95 & GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4218 +/- 0.02380

# 4 experts
# GATE_ATTENUATION_FACTOR 1.20 : PPL = 4.4539 +/- 0.02402

jukofyork commented on June 26, 2024

Maybe this idea does have some use after all. If we can scale the gate weight tensor with n=8 to work as closely as possible to n=2, then very-low-bit quantized models using 2-3 bpw might actually work better and see their perplexity degrade more slowly (due to more active weights cancelling out more of the noise caused by quantization).

This assumes that the optimal scale factor doesn't just approximate the hard n=2 thresholding with a soft n=8 version that barely uses the other 6 sets of MLP weights (i.e. doesn't shift the lower-valued logits so far down that the Gumbel error distributions effectively head towards -inf and contribute almost nothing to the gated sum...).
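
As a rough sanity check on that last point, one could look at how much weight the 6 lowest-ranked experts actually receive under a sharpened full softmax, compared with the stock hard top-2 routing (same toy setup as the earlier sketch; the logits are made up, not taken from the real router):

import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

# Made-up router logits for 8 experts, purely for illustration.
logits = np.array([2.0, 1.5, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0])

# Stock routing: softmax over all 8 logits, keep the top 2 and renormalise,
# so the other 6 experts contribute exactly nothing.
p = softmax(logits)
top2 = np.sort(p)[-2:]
print("hard top-2 weights:", top2 / top2.sum())

# Soft top-8 with a sharpened gate: how much mass is left on the other 6 experts?
for factor in [1.0, 1.5, 2.0, 3.0]:
    p = softmax(factor * logits)
    print(f"gate factor = {factor}: mass on the 6 lowest-ranked experts = {1.0 - np.sort(p)[-2:].sum():.3f}")

If that leftover mass stays well above zero at the PPL-optimal scale factor, the extra experts really are contributing to the gated sum rather than just mimicking the hard n=2 thresholding.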
