
mup's Introduction

Maximal Update Parametrization (μP) and Hyperparameter Transfer (μTransfer)

Paper link | Blog link | YouTube link

In Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer, we show that optimal hyperparameters become stable across neural network sizes when we parametrize the model in maximal update parametrization (μP). This can be used to tune extremely large neural networks such as large pretrained transformers, as we have done in our work. More generally, μP reduces the fragility and uncertainty when transitioning from exploration to scaling up, which are not often talked about explicitly in the deep learning literature.

Figure above: Training loss against learning rate on Transformers of varying d_model trained with Adam.

μP turns out to be the unique "natural" parametrization that has this hyperparameter stability property across width, as empirically verified in the gif below on MLPs trained with SGD. Here, across time, we interpolate between PyTorch default and μP's learning rate and initialization scalings (right), and we scale up the width-256 model (log2(width)=8) to width 2^13 = 8192 using this interpolated scaling rule (left).

This repo contains the source code for the mup package, our tool that makes the implementation of μP in PyTorch models effortless and less error-prone.

Installation

pip install mup

Install From Source

Clone this repo, change to its directory, and do

pip install -r requirements.txt
pip install -e .

Basic Usage

import torch.nn as nn

import mup.init  # for mup.init.uniform_ etc. below
from mup import MuReadout, make_base_shapes, set_base_shapes, MuSGD, MuAdam

class MyModel(nn.Module):
    def __init__(self, width, ...):
        ...
        ### In model definition, replace output layer with MuReadout
        # readout = nn.Linear(width, d_out)
        readout = MuReadout(width, d_out)
        ### If tying weights with an input nn.Embedding layer, do
        # readout = MuSharedReadout(input_layer.weight)
        ...
    def forward(self, ...):
        ...
        ### If using a transformer, make sure to use
        ###   1/d instead of 1/sqrt(d) attention scaling
        # attention_scores = query @ key.T / d**0.5
        attention_scores = query @ key.T * 8 / d
        ### We use 8/d instead of 1/d here to be backward compatible
        ###   with 1/d**0.5 when d=64, a common head dimension.
        ...

### Instantiate a base model
base_model = MyModel(width=1)
### Optionally, use `torchdistx.deferred_init.deferred_init` to avoid instantiating the parameters
### Simply install `torchdistx` and use
# base_model = torchdistx.deferred_init.deferred_init(MyModel, width=1)
### Instantiate a "delta" model that differs from the base model
###   in all dimensions ("widths") that one wishes to scale.
### Here it's simple, but e.g., in a Transformer, you may want to scale
###   both nhead and dhead, so the delta model should differ in both.
delta_model = MyModel(width=2) # Optionally use `torchdistx` to avoid instantiating

### Instantiate the target model (the model you actually want to train).
### This should be the same as the base model except 
###   the widths could be potentially different.
### In particular, base_model and model should have the same depth.
model = MyModel(width=100)

### Set base shapes
### When `model` has same parameter shapes as `base_model`,
###   `model` behaves exactly the same as `base_model`
###   (which is in PyTorch's default parametrization).
###   This provides backward compatibility at this particular model size.
###   Otherwise, `model`'s init and LR are scaled by μP.
### IMPORTANT: this should be called as soon as possible,
###   before re-initialization and optimizer definition.
set_base_shapes(model, base_model, delta=delta_model)

### Alternatively, one can save the base model shapes in a file
# make_base_shapes(base_model, delta_model, filename)
### and later set base shapes directly from the filename
# set_base_shapes(model, filename)
### This is useful when one cannot fit both 
###   base_model and model in memory at the same time

### Replace your custom init, if any
for param in model.parameters():
    ### If initializing manually with fixed std or bounds,
    ### then replace with same function from mup.init
    # torch.nn.init.uniform_(param, -0.1, 0.1)
    mup.init.uniform_(param, -0.1, 0.1)
    ### Likewise, if using
    ###   `xavier_uniform_, xavier_normal_, kaiming_uniform_, kaiming_normal_`
    ### from `torch.nn.init`, replace with the same functions from `mup.init`

### Use the optimizers from `mup.optim` instead of `torch.optim`
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = MuSGD(model.parameters(), lr=0.1)

### Then just train normally

Note the base and delta models do not need to be trained --- we are only extracting parameter shape information from them. Therefore, optionally, we can avoid instantiating these potentially large models by using the deferred_init function in torchdistx. After installing torchdistx, use torchdistx.deferred_init.deferred_init(MyModel, **args) instead of MyModel(**args). See this page for more detail. In the MLP and Transformer examples (not mutransformers) we provided, you can activate this feature by passing --deferred_init.
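
For instance, a minimal sketch of the deferred-init variant of the Basic Usage setup above (assuming torchdistx is installed; MyModel is the toy class from that section):

# Sketch: build the base/delta models lazily so their parameters are never materialized.
from torchdistx.deferred_init import deferred_init
from mup import set_base_shapes

base_model = deferred_init(MyModel, width=1)    # records shapes without allocating weights
delta_model = deferred_init(MyModel, width=2)
model = MyModel(width=100)                      # the model you actually train is built normally
set_base_shapes(model, base_model, delta=delta_model)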

How mup Works Under the Hood

By invoking set_base_shapes(model, ...), each parameter tensor p of model gets a p.infshape attribute that stores, for each of its dimensions, the corresponding base dimension and whether that dimension should be considered infinite (i.e. will be scaled up/down, e.g., d_model of a Transformer) or finite (i.e. will be fixed, e.g., vocabulary size). This information is used in the initializers and optimizers to automatically scale the parameters or learning rates to be compliant with μP. For example, the Adam learning rate of hidden weights p is calculated as globalLR / p.infshape.width_mult(), where p.infshape.width_mult() essentially calculates fan_in / base_fan_in.
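
For instance, a minimal sketch of inspecting this metadata after set_base_shapes has been called (model as in the Basic Usage section; the attributes are those described above):

# Sketch: look at the infshape metadata mup attaches to each parameter.
for name, p in model.named_parameters():
    print(name, tuple(p.shape),
          'inf dims:', p.infshape.ninf(),          # number of dimensions treated as infinite
          'width_mult:', p.infshape.width_mult())  # essentially fan_in / base_fan_in

# For a hidden weight, MuAdam then uses roughly:
#   effective_lr = global_lr / p.infshape.width_mult()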

Current Limitations

  • set_base_shapes(model, ...) assumes that model has just been randomly initialized in the standard way and rescales its parameters using the base shape information so the model is in μP.
  • If you want data parallelism, please use torch.nn.parallel.DistributedDataParallel instead of torch.nn.DataParallel. This is because the latter removes the attributes the mup package adds to each parameter tensor of the model. Also, for performance, PyTorch recommends the former anyway.
  • We scale the learning rate according to μP explicitly by creating refined parameter groups from what is passed to the mup optimizer and by manipulating the lr attribute in those groups. This is compatible with PyTorch's learning rate schedulers. However, if you roll your own, make sure the scheduler sets the learning rate relative to what is currently in the refined parameter groups. The following is an example of what not to do and what is OK:
optimizer = mup.MuAdam(model.parameters(), lr=1e-3)
for pg in optimizer.param_groups:
  # what NOT to do: setting learning rate absolutely
  # pg['lr'] = 1e-3 * 2
  # what is an OK alternative: setting it relatively
  pg['lr'] *= 2
  • By default, any parameter matrix that has 2 "infinite" dimensions (i.e., dimensions that differ from the base dimensions) is considered by mup to have shape (fan_out, fan_in), i.e., in the forward pass, this matrix multiplies its input on the right. This is the case for all nn.Linear weights in PyTorch. If you have a custom parameter, say W, that violates this convention, you can manually set W.infshape.main_idx = 0; W.infshape.main = W.infshape[0] to let mup know that its shape corresponds to (fan_in, fan_out); see the sketch after this list. A similar discussion applies if you have a parameter tensor with many dimensions but exactly 2 "infinite" dimensions, of which the first is fan_in and the second is fan_out.
  • Currently, torch.save does not save the infshape objects attached to each parameter tensor. Until this is fixed, you have to set base shapes manually after loading a model checkpoint, like so:
model = torch.load('my/model/path.pt')
# Important: note the flag `rescale_params=False`!
set_base_shapes(model, 'my/base/shape/path.bsh', rescale_params=False)

(set_base_shapes by default rescales the parameters of model, assuming it's freshly initialized by PyTorch, to be consistent with μP. The rescale_params=False flag turns off this behavior.)
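
For the custom-parameter convention mentioned in the list above, a hypothetical sketch (the layer name my_custom_layer is made up for illustration):

# Hypothetical sketch: mark a custom weight W stored as (fan_in, fan_out) instead of
# the default (fan_out, fan_in). `model.my_custom_layer` is an illustrative name.
set_base_shapes(model, base_model, delta=delta_model)

W = model.my_custom_layer.weight
W.infshape.main_idx = 0            # per the note above, dimension 0 is fan_in here
W.infshape.main = W.infshape[0]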

Checking Correctness of Parametrization

Coord Check

Just as gradient checking is a simple way of verifying the correctness of an autograd implementation, coordinate checking is a simple way to verify that you have implemented μP correctly: calculate the average size (denoted l1 on the y-axis below) of the coordinates of each activation vector in, and of the output of, the model, for a few steps of training and a few different widths. If implemented correctly, this l1 stays stable across widths; otherwise, it blows up or shrinks to 0 with width. (We are essentially checking desideratum 1 described below.) (The l1 here is x.abs().mean() for each activation vector x and is just one measure of the "average size" of x's entries; one can also use analogously defined l2, l4, etc., though they may exhibit greater fluctuation across random seeds.)
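
As a rough sketch of what is being measured (independent of the mup helpers described below, which do this bookkeeping for you; model and batch are placeholders):

import torch

# Sketch: record the l1 coordinate size of every module's output with forward hooks.
l1_stats = {}

def make_hook(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            l1_stats[name] = output.detach().abs().mean().item()
    return hook

handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
_ = model(batch)   # one forward pass; repeat over a few training steps and widths
for h in handles:
    h.remove()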

For example, in the following, we plot width vs l1 for 2 steps of training, where t=1 means at initialization, before any gradient update. Each curve corresponds to a (pre-)activation vector of a layer or to the output of the network. The first set of 3 plots shows an MLP in standard parametrization (SP), trained with Adam. We see that after 1 step of update, the activation/output l1 explodes with width. This means SP is "incorrect." We now do the same for an MLP in maximal update parametrization (μP) (including using mup.optim.MuAdam instead of torch.optim.Adam). In contrast to the above, all curves stay horizontal, indicating that μP is implemented correctly. We call this way of checking implementation correctness a coord check, short for "coordinate check."

Making Your Own Coord Check Plots

We provide an easy way to implement this check via functions in the mup.coord_check module. The workflow typically looks like the following.

from mup.coord_check import get_coord_data, plot_coord_data
# construct a dictionary of lazy μP models with differing widths
def lazy_model(width):
    # `set_base_shapes` returns the model
    return lambda: set_base_shapes(MyMuModel(width), 'my/base/shape/path.bsh')
    # Note: any custom initialization with `mup.init` would need to
    # be done inside the lambda as well
models = {64: lazy_model(64), ..., 1024: lazy_model(1024)}
# make a dataloader with small batch size/seq len
#   just for testing
dataloader = ...
# record data from the model activations over a few steps of training
# this returns a pandas dataframe
df = get_coord_data(models, dataloader)
# This saves the coord check plots to filename.
plot_coord_data(df, save_to=filename)
# If you are in jupyter notebook, you can also do
#   `plt.show()`
# to show the plot

For example, the mup.coord_check.example_plot_coord_check function is implemented this way for toy MLP and CNN models.

If you see the curves blow up or shrink to 0 with width after a few steps of training, then there's a bug in your μP implementation (did you forget to vary some dimension, like d_ffn, in the delta model?). If instead you see the curves converge to the right, then most likely your implementation is correct. However, there are two typical exceptions to this; the following can shrink to 0 at initialization in μP (at a 1/sqrt(width) rate):

  • the network output
  • the attention logits in a Transformer

These are transient, and after a few steps their curves should be roughly flat. Nevertheless, to remove the discrepancy at init, we recommend

  • initializing the output layer (should be a MuReadout instance) weights to be 0 via the readout_zero_init=True option and
  • initializing the query matrix in a Transformer to 0 (this has to be done manually; see the sketch below). If symmetry breaking is desired in the attention logits at init, initialize the (relative) position biases with nonzero variance.
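
A minimal sketch of these two recommendations, assuming the readout is a MuReadout and the attention modules expose their query projection as q_proj (the attribute names blocks, attn, and q_proj are illustrative):

# Sketch: remove the 1/sqrt(width) discrepancy at init.
import torch

model = MyModel(width=100)   # internally: MuReadout(width, d_out, readout_zero_init=True)
set_base_shapes(model, base_model, delta=delta_model)

# zero the query projections manually (attribute names are assumptions)
for block in model.blocks:
    torch.nn.init.zeros_(block.attn.q_proj.weight)
    if block.attn.q_proj.bias is not None:
        torch.nn.init.zeros_(block.attn.q_proj.bias)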

Tips for Coord Check

  • Use a large learning rate (larger than you'd use for actual training). This would emphasize any potential exploding coordinates issue, which could be hidden by the initialization if the learning rate is too small.
  • If you reuse a module multiple times in the forward pass, then mup.get_coord_data will only record the statistics from the last usage. In this case, for testing purposes, you can wrap the different usages in nn.Identity modules with different names to distinguish them; see the sketch below.
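
For the second tip, a hypothetical sketch (the names Block, shared_ffn, first_use, and second_use are made up):

import torch.nn as nn

# Hypothetical sketch: give each reuse of a shared module its own named nn.Identity
# so that the coord check records their statistics separately.
class Block(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.shared_ffn = nn.Linear(d, d)   # reused twice below
        self.first_use = nn.Identity()
        self.second_use = nn.Identity()

    def forward(self, x):
        x = self.first_use(self.shared_ffn(x))
        x = self.second_use(self.shared_ffn(x))
        return x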

Wider is Always Better

Another sign that μP has not been implemented correctly is if going wider does worse (on training loss) after some width, at some point during training. The figure above illustrates this in a collection of training curves: (left) the correct implementation should always see performance improve with width, at any point in training; (middle) if you used standard parametrization (SP), sometimes you may see performance improve with width up to some point and then suddenly it becomes worse with wider models; (right) or you may immediately see worsening performance even for narrow models.

Examples

See the MLP, Transformer, and ResNet folders inside examples/ as well as the tests in mup/test for examples. People familiar with Huggingface Transformers may also find the examples/mutransformers submodule instructive (obtained via git submodule update --init), which is also available standalone at https://github.com/microsoft/mutransformers.

Native Integration With Huggingface

Frustrated that your Huggingface Transformer breaks when you scale up? Want to tune hyperparameters for your large multi-GPU Huggingface Transformer on a single GPU, right out of the box? If so, please upvote this GitHub issue!

Running Tests

To run tests, do

python -m mup.test

The Basic Math

μP is designed so as to satisfy the following desiderata:

At any time during training

  1. Every (pre)activation vector in a network should have Θ(1)-sized coordinates.
  2. The neural network output should be O(1).
  3. All parameters should be updated as much as possible (in terms of scaling in width) without leading to divergence.

It turns out these desiderata uniquely single out μP. To derive μP from them, one needs to carefully consider how the coordinate size of a vector Av, resulting from a square matrix A multiplying a vector v, depends on those of A and v when A and v are "correlated". Here you can think of A as weights and v as an activation vector. This in turn depends on what kind of matrix A is and what kind of vector v is. In the context of training a wide neural network, it turns out we only need to consider vectors that have approximately iid coordinates, and two kinds of matrices: 1) those that look like outer products of such vectors, and 2) random iid matrices. Matrices of type 1 cover things like weight gradients; those of type 2 cover things like weight initialization. Then, if A and v both have entry size Θ(1) and they are correlated in ways that arise naturally during training, we have the following table.

                     outer product A (type 1)    iid A (type 2)
Entry size of Av     Θ(n)                        Θ(sqrt(n))
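
As a quick sketch of where these two rates come from (a heuristic following the paper's reasoning), write $(Av)_i = \sum_{j=1}^{n} A_{ij} v_j$ with $A_{ij}, v_j = \Theta(1)$:

\[
\text{Type 1: } A = u w^{\top} \;\Rightarrow\; (Av)_i = u_i \,(w^{\top} v) = \Theta(n),
\quad \text{since } w^{\top} v \text{ sums } n \text{ correlated } \Theta(1) \text{ terms;}
\]
\[
\text{Type 2: } A_{ij} \text{ iid and independent of } v \;\Rightarrow\; (Av)_i = \Theta(\sqrt{n})
\quad \text{by the central limit theorem.}
\]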

Given this table, one can then trace the forward and backward computation of a network to derive μP straightforwardly.

See our blog post for a gentle primer and our paper for details.

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.

mup's People

Contributors

edwardjhu, microsoftopensource, msft-edward, tevenlescao, thegregyang, zanussbaum

mup's Issues

dim_feedforward

Sorry, my fat fingers clicked on the wrong field on this one. (-:

muP for contrastive losses

Hi, I have a question regarding the use of muP in contrastive losses: Assume we have anchor embedding x, positive embedding x_pos, and negative embedding x_neg. All x, x_pos, and x_neg are C-dim vectors where C represents the width that is categorized as an infinite dimension. The loss L is formulated as:

L = -log( exp(sim(x, x_pos)) / (exp(sim(x, x_pos)) + exp(sim(x, x_neg))) )

where sim(a, b) = cos(a, b) for each embedding pair. It seems the sim() merges two infinite-dim vectors to a finite one, which is similar to the Q K^T operation in self-attention. However, the difference is that the cosine similarity already bounds the output. Thus, I wonder if there is anything we need to change in the loss function when we use muP? Thanks!

Positional Embeddings should be MuReadout parameters?

Duplicate of question asked on the mutransformers repository (link)

Hi!
I was wondering if (learned) positional embeddings should be MuReadout layers, since they map to a finite-dimensional space. Specifically:

https://github.com/microsoft/mutransformers/blob/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/mutransformers/models/bert/modeling_bert.py#L174

self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

In addition to that, did you try using muP for sparse MoE models? I am curious about any findings for those. Specifically, I was wondering if the routing gate (hdim, num_experts) would also be a MuReadout layer (if we don't scale the number of experts).

Would be grateful for any advice :)

Thank you !

mu parametrization for channel attention

Hi, I have another question about the mu parametrization for a special attention mechanism - channel attention.

In standard scaled dot-product attention (also regarded as spatial attention), we have Q, K, V with shape n x d (ignoring heads) and we calculate softmax(scale * Q K^T) V to get an n x d output, where scale = 1/sqrt(d) in SP and scale = 1/d in muP (or 1/sqrt(d_0) / width_mult in muP for backward compatibility).

In channel attention, we still have Q, K, V with shape n x d (ignoring heads). The difference is that we calculate (softmax(scale * Q^T K) V^T)^T to get an n x d output, where scale = 1/sqrt(n) in SP. Since the attention map Q^T K now has shape d x d instead of n x n, I am not sure how the scale should be modified in SP accordingly. Should we use 1/sqrt(n) / width_mult?

In addition, Appendix B - Matrix-Like, Vector-Like, Scalar-Like Parameters has some interpretation behind the scale:

a multiplier of order 1/fan_in should accompany any weight that maps an infinite dimension to a finite one. This interpretation then nicely covers both the output logits and the attention logits (i.e. 1/d attention).

But such interpretation may not be directly used as a guidance to set up the scale in the channel attention.

Thanks!

µTransfer across batch size && weight decay setting

Hello there! I have reproduced the Transformer results with your code and am now looking to apply mup to my own model. However, I have some doubts and would greatly appreciate it if you could explain them to me.

Question 1: µTransfer across batch size
In the paper, it is mentioned that the optimal learning rate can be transferred between different batch sizes, for example, in Figure 19. However, I couldn't find the relevant implementation in the provided examples.
I would like to confirm whether the experiments were conducted by keeping all other hyperparameters fixed and only modifying the batch size. Regarding Init. Var. and LR, does it follow the approach used for width, where a base batch size is set and then Init. Var. and LR are scaled according to the current batch size? If I want to simultaneously adjust the model's width and batch size while keeping the optimal learning rate unchanged, is that feasible?

Question 2: weight decay setting
Regarding the setting of weight decay, the explanation in the paper is also concise, and I am a bit confused about the implementation in AdamW. I noticed in the paper that weight decay should be independent of the width, and I found a code snippet on Hugging Face which uses mup. https://huggingface.co/cerebras/btlm-3b-8k-base/blob/main/modeling_btlm.py

    def get_mup_param_groups(self, lr, weight_decay=0.0, decoupled_wd=True):
        """
        Returns list of dicts defining parameter groups for muP:
        group 0: most model params get scaled learning rate and weight decay.
        group 1: embedding layer gets non-scaled learning rate and weight decay.
        group 2: normalization layers and biases get non-scaled learning rate only.
        The output can be passed to Adam-base optimizers 
        e.g.
            param_groups = model.get_mup_param_groups(lr=1e-3, weight_decay=0.1)
            torch.optim.AdamW(param_groups, betas=(0.9, 0.95), eps=1e-8)
        """
        norm_modules = (
            torch.nn.LayerNorm,
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.GroupNorm,
            torch.nn.SyncBatchNorm,
            torch.nn.LocalResponseNorm,
        )

        def get_group_index(param_name):
            for name, module in self.named_modules():
                if name in param_name:
                    if isinstance(module, norm_modules):
                        return 2
                    elif isinstance(module, torch.nn.Embedding):
                        return 1
            return 0

        width_scale = self.config.mup_width_scale
        new_param_groups = []
        new_param_groups.append({"params": [], "lr": lr * width_scale, "weight_decay": weight_decay})
        if not decoupled_wd:
            new_param_groups[0]["weight_decay"] /= width_scale
        new_param_groups.append({"params": [], "lr": lr, "weight_decay": weight_decay})
        new_param_groups.append({"params": [], "lr": lr, "weight_decay": 0.0})

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            if name.endswith("bias"):
                new_param_groups[2]["params"].append(param)
            else:
                new_param_groups[get_group_index(name)]["params"].append(param)

        for idx, param_group in enumerate(new_param_groups):
            if len(param_group["params"]) == 0:
                del new_param_groups[idx]

        return new_param_groups

Is this the approach to follow: categorize the parameters into three types and set their respective learning rates and weight decays? And should decoupled_wd be set to True so that the weight decay will not be scaled?

Once again, I appreciate your valuable work and the significance it holds in this field. I am eagerly looking forward to your kind explanation.

Are Sequentials with list comprehension handled incorrectly?

Because:

class TheModel(nn.Module):

    def __init__(self, n_token_embed, n_layers):
        super().__init__()
        n_heads = n_token_embed // 2
        n_key_size = n_token_embed
        self.token_embedding_table = nn.Embedding(len(all_symbols), n_token_embed)
        self.position_embedding_table = nn.Embedding(N_CONTEXT, n_token_embed)
        self.blocks = nn.Sequential(*[Block(n_token_embed, n_key_size, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_token_embed) # final layer norm
        self.lm_head = nn.Linear(n_token_embed, N_CATEGORIES * 2)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-21-944f2038ce18> in <cell line: 86>()
     84 delta_model = TheModel(4, N_LAYERS)
     85 model = TheModel(N_TOKEN_EMBED, N_LAYERS)
---> 86 set_base_shapes(model, base_model, delta=delta_model)
     87 
     88 m = model.to(device)

1 frames
/usr/local/lib/python3.9/dist-packages/mup/shape.py in _zip_infshape_dict(base_shapes, shapes)
     93     basenames = set(base_shapes.keys())
     94     names = set(shapes.keys())
---> 95     assert basenames == names, (
     96         f'`base_shapes` has extra names {basenames - names}. '
     97         f'`shapes` has extra names {names - basenames}.'

AssertionError: `base_shapes` has extra names set(). `shapes` has extra names {'blocks.0.sa.heads.1.value.weight', 'blocks.2.sa.heads.1.query.weight', 'blocks.2.sa.heads.1.value.weight', 'blocks.0.sa.heads.1.key.weight', 'blocks.1.sa.heads.1.value.weight', 'blocks.1.sa.heads.1.key.weight', 'blocks.2.sa.heads.1.key.weight', 'blocks.0.sa.heads.1.query.weight', 'blocks.1.sa.heads.1.query.weight'}.

Consider decoupled weight decay optimizers?

Hi! I'm a big fan of this project and I noticed that MUP has some wrappers for common optimizers MuAdam, MuAdamW, and MuSGD. Given the focus on hyperparameter stability and tuning, I was wondering if you might be interested in adding / experimenting with the decoupled weight decay optimizers from this paper (https://arxiv.org/abs/1711.05101)?

For context, PyTorch's implementations of SGD, Adam, and even AdamW all scale the weight decay by the learning rate in their update step, and I've found that this makes it tough to tune the two values independently (if you increase LR, you also silently increase the effective WD). This is due to PyTorch's scheduler implementation which only updates an Optimizer's LR, and so they schedule the WD in sync by multiplying the two values together.

E.g. here is Pytorch's AdamW code, which shows why it is not really decoupled: https://github.com/pytorch/pytorch/blob/11231b0f935c092598c994a4bab686485aac1856/torch/optim/adamw.py#L248

The "correct" way to decouple LR, WD is described in the paper, and we have some PyTorch-ready implementations here (code, docs) in MosaicML's Composer library. Though I haven't seen any examples in MUP yet of tuning the WD across model sizes, I feel like this could be a common hparam that users want to tune and that DecoupledSGDW or DecoupledAdamW could help make it more stable :)

Optimizers for coord check

Thank you for your great work! When trying the coord check in the examples, I noticed that the original optimizers (e.g., SGD, Adam) are used instead of the muP optimizers (e.g., MuSGD, MuAdam). However, according to Table 8 in the paper, the optimizers should be adjusted accordingly to keep activations bounded. Is there any reason behind the use of the original optimizers?

Should `base=None` be used in `set_base_shapes` for model used for tuning?

Hello! First of all, thank you for doing such great work and making it so accessible. I'm looking at using mup for a project but I'm a bit confused about how to set the base shapes for the smaller model used for hyperparameter tuning.

Let's say I want to train an MLP with hidden dimension 1024, and I want to muTransfer the best learning rate from an MLP with hidden dimension 128. My top-level code might look like this:

best_loss = float('inf')
best_lr = 0.

# Hyperparameter sweep with hidden dimension 128
for lr in learning_rates:

    small_mlp = MLP(hidden_dim=128)

    # use `base=None` in `set_base_shapes`
    small_mlp = mup.set_base_shapes(small_mlp, base=None)

    final_loss = full_training_loop(small_mlp, lr=lr)

    if final_loss < best_loss:
        best_loss = final_loss
        best_lr = lr

# Transfer optimal LR to large model

base_mlp = MLP(hidden_dim=128)
big_mlp = MLP(hidden_dim=1024)

big_mlp = mup.set_base_shapes(big_mlp, base=base_mlp)

ultimate_loss = full_training_loop(big_mlp, lr=best_lr)

or like this:

best_loss = float('inf')
best_lr = 0.

for lr in learning_rates:

    small_mlp = MLP(hidden_dim=128)

    # use a base model in `set_base_shapes`
    smaller_mlp = MLP(hidden_dim=32)
    small_mlp = mup.set_base_shapes(small_mlp, base=smaller_mlp)

    final_loss = full_training_loop(small_mlp, lr=lr)

    if final_loss < best_loss:
        best_loss = final_loss
        best_lr = lr

# Transfer optimal LR to large model

base_mlp = MLP(hidden_dim=128)
big_mlp = MLP(hidden_dim=1024)

big_mlp = mup.set_base_shapes(big_mlp, base=base_mlp)

ultimate_loss = full_training_loop(big_mlp, lr=best_lr)

Could you please clarify which of these would be correct? Thank you very much for your time!

Batch size, Seq len, Step Transferring

Hi!
I didn't fully understand how the transfer of parameters such as batch_size/seq_len/steps should work (Figures 17 and 19 in the article). I also didn't find any mention of this in the article or in the library code.
It would seem that, according to the idea of mup, we shouldn't apply any scaling to these parameters, but then it is unclear how this works with batch size. Should I forget about all lr/batch_size dependency rules? What will happen to the convergence rate in this case?

missing os import in mup/examples/MLP/main.py ?

Just thought I'd let you know that, when I ran python mup/examples/MLP/main.py --load_base_shapes ./mup/examples/MLP/width64.bsh, I got an error at the very bottom (when saving the .tsv) because os wasn't imported (losing a few hours of runtime!). Adding import os to the header fixed that. Maybe I was doing something wrong, but thought I'd flag it just in case

MuP Coord Check not Working with Electra Style Model

I'm trying to use an Electra-style model with µP but am not able to get the coord plots to work correctly. Currently, I have Readout layers on both the Discriminator and Generator.

Creating coord checks for the Discriminator and Generator alone seems to work, but when they are combined, the µP plot does not look as expected.

Generator coord checks:
μp_electra_generator_adam_lr0 001_nseeds5_coord
sp_electra_generator_adam_lr0 001_nseeds5_coord

Discriminator coord checks:
μp_electra_adam_lr0 001_nseeds5_coord
sp_electra_adam_lr0 001_nseeds5_coord

Electra Model coord checks:

sp_electra_model_adam_lr0 001_nseeds5_coord
μp_electra_model_adam_lr0 001_nseeds5_coord

Will µP not work for "multi-task" losses like here where the overall loss is a weighted sum of mlm_loss and disc_loss?

Does mup support fine tuning pretrained models

Hi, I'm trying to tune the hyperparameters of a pretrained model (e.g. a ResNet or Swin Transformer) during the fine-tuning stage. If I scale the model using mup, the pretrained weights cannot be used anymore. And I think the best hyperparameters for fully training a model might be different from the best hyperparameters for fine-tuning a model. Can mup be applied to this scenario?
Thanks.

Can base model be larger than target model?

This might be a naive question. In the provided examples, it seems that all base models are small models and target models are larger ones. However, if I have a large model to be tuned, and I don't want to change the original training process of the large model, can I set the large model as the base model? In my understanding, mup behaves the same as torch when model.width == base_model.width.

To illustrate:

# suppose I have a large model: large_model = Model(width=100)
# build mup model
base_model = mupModel(width=100)
smaller_model = mupModel(width=10)

# train and tune lr in the smaller model
best_lr = tune(smaller_model)

# train large model with best lr once
# Does Model(width=100) == mupModel(width=100) ?
train(Model(width=100), best_lr)

Is the above example correct?

Issue in reproducing the training loss vs learning rates curve

Hi,
First of all, thanks for sharing your work.

We tried to reproduce the expected behavior of muP using ResNet18 and CIFAR10, as provided in the main script of your repository. The idea was to launch a training run for multiple learning rates and width_mult values and take the minimum loss each time, as you did in your paper, to check that the best learning rate doesn't change with a different width_mult.

We modified the main.py script a bit, to skip the saving/loading of the base shape file, as follows:

'''Train CIFAR10 with PyTorch.'''
import argparse
import os
from time import gmtime, strftime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from mup import MuAdam, MuSGD, get_shapes, set_base_shapes
from copy import deepcopy
from mup.infshape import InfShape
from mup.shape import clear_dims, zip_infshapes, get_infshapes
from torch.utils.tensorboard import SummaryWriter
import resnet


# Training
def train(epoch, net, writer):
#    from utils import progress_bar
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    writer.add_scalar("Train/Loss", train_loss/(batch_idx+1), epoch)
    writer.add_scalar("Train/Acc", 100.*correct/total, epoch)
    
    return train_loss/len(trainloader)

def test(epoch, net, writer):
#    from utils import progress_bar
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    writer.add_scalar("Test/Loss", test_loss, epoch)
    writer.add_scalar("Test/Acc", 100.*correct/total , epoch)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
        
    return test_loss/len(testloader), best_acc

# Custom method to skip save and load shapes
def get_base_shapes(base_shapes, delta_shapes):
    model_or_shapes = clear_dims(zip_infshapes(base_shapes, delta_shapes))
    if isinstance(model_or_shapes, nn.Module):
        sh = get_infshapes(model_or_shapes)
    elif isinstance(model_or_shapes, dict):
        sh = deepcopy(model_or_shapes)
    else:
        raise ValueError()
    sh = {k: s.base_shape() for k, s in sh.items()}
    return {k: InfShape.from_base_shape(v) for k, v in sh.items()}
    
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser(description=''
    '''
    PyTorch CIFAR10 Training, with μP.
    To save base shapes info, run e.g.
        python main.py --save_base_shapes resnet18.bsh --width_mult 1
    To train using MuAdam (or MuSGD), run
        python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer {muadam,musgd}
    To test coords, run
        python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check
        python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check
    If you don't specify a base shape file, then you are using standard parametrization, e.g.
        python main.py --width_mult 2 --optimizer {muadam,musgd}
    Here muadam (resp. musgd) would have the same result as adam (resp. sgd).
    Note that models of different depths need separate `.bsh` files.
    ''', formatter_class=argparse.RawTextHelpFormatter)
    
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--arch', type=str, default='resnet18')
    parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'adam', 'musgd', 'muadam'])
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--width_mult', type=float, default=1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=128)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--test_num_workers', type=int, default=2)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=1111, help='random seed')

    args = parser.parse_args()

    root_dir = "/out/"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='../dataset', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(
        root='../dataset', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck')

    # Model
    print('==> Building model..')
    net = getattr(resnet, args.arch)(wm=args.width_mult)
    net = net.to(device)
    if args.optimizer in ["musgd","muadam"]:
        print(f'using muP Parametrization')
        base_shapes = get_shapes(net)
        delta_shapes = get_shapes(getattr(resnet, args.arch)(wm=args.width_mult/2))
        dict_infshape = get_base_shapes(base_shapes, delta_shapes)
        
        set_base_shapes(net, dict_infshape)
    else:
        print(f'using Standard Parametrization')
        set_base_shapes(net, None)

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.pth')
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'musgd':
        optimizer = MuSGD(net.parameters(), lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    elif args.optimizer == 'muadam':
        optimizer = MuAdam(net.parameters(), lr=args.lr)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise ValueError()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    tb_time = strftime("%Y-%m-%d-%H:%M:%S", gmtime())
    sub_dir = root_dir + tb_time + "-" + str(args.arch) +  "-" + str(args.lr) + "-" + str(args.width_mult)
    
    os.makedirs(sub_dir, exist_ok = True)
    
    writer = SummaryWriter(sub_dir)
    
    for epoch in range(start_epoch, start_epoch+args.epochs):
        best_train_loss = train(epoch, net, writer)
        best_test_loss, best_acc = test(epoch, net, writer)
        scheduler.step()
        
    writer.add_hparams({"Epochs": args.epochs, "Width": args.width_mult, "BatchSize": args.batch_size}, {"Test/Score": best_acc})

Then, for each width multiplier wm from 1 to 5, we launched the following bash scripts, which train the models for a set of learning rates.

In muP mode:

#!/bin/bash
wm=5
cd /exp/ResNetTest
pip3 install mup
for lr in 6.10351562e-05 8.13100441e-05 1.08319921e-04 1.44302039e-04 \
          1.92236834e-04 2.56094789e-04 3.41165320e-04 4.54494899e-04 \
          6.05470725e-04 8.06598270e-04 1.07453712e-03 1.43148090e-03 \
          1.90699561e-03 2.54046859e-03 3.38437101e-03 4.50860411e-03 \
          6.00628919e-03 8.00148094e-03 1.06594430e-02 1.42003369e-02 \
          1.89174582e-02 2.52015307e-02 3.35730700e-02 4.47254987e-02 \
          5.95825832e-02 7.93749499e-02 1.05742019e-01 1.40867801e-01 \
          1.87661798e-01 2.50000000e-01
do
    echo "Training for wm = ${wm} and lr = ${lr}"
    python3 main.py --lr=$lr --epochs=150 --batch_size=128 --num_workers=4 --seed=1111 --width_mult=$wm
done

In SP mode:

#!/bin/bash
wm=5
cd /exp/ResNetTest
pip3 install mup
for lr in 2.56094789e-04 4.54494899e-04 \
          8.06598270e-04 1.43148090e-03 \
          2.54046859e-03 4.50860411e-03 \
          8.00148094e-03 1.42003369e-02 \
          2.52015307e-02 4.47254987e-02 \
          7.93749499e-02 1.40867801e-01 \
          2.50000000e-01
do
    echo "Training for wm = ${wm} and lr = ${lr}"
    python3 main.py --lr=$lr --epochs=150 --batch_size=128 --num_workers=4 --seed=1111 --width_mult=$wm --optimizer='sgd'
done

Then, we get the minimum loss and plot the two curves (loss vs lr) : one with mup, one without.

With muP :

loss-vs-lr-with-mup

Without muP :

loss-vs-lr-with-sp

As you can see in the two figures, there is no visible difference between the two scenarios: in both cases, the minima are aligned, except for those with wm=1.
Do you have an idea why this is happening?
Thanks for your help

MuAdam not adjusting lr for output weights

Hi, thank you for your great project for hyperparameter tuning!

As our team is migrating mup to another training framework, it occurred to us that MuAdam does not scale the learning rate for output weights as the TP5 paper illustrates:
image

mup/mup/optim.py

Lines 55 to 70 in c9d6700

for p in param_group['params']:
    assert hasattr(p, 'infshape'), (
        f'A parameter with shape {p.shape} does not have `infshape` attribute. '
        'Did you forget to call `mup.set_base_shapes` on the model?')
    if p.infshape.ninf() == 2:
        matrix_like_p[p.infshape.width_mult()]['params'].append(p)
    elif p.infshape.ninf() > 2:
        raise NotImplementedError('more than 2 inf dimensions')
    else:
        vector_like_p['params'].append(p)
for width_mult, group in matrix_like_p.items():
    # Scale learning rate and weight decay accordingly
    group['lr'] /= width_mult
    group['weight_decay'] *= width_mult
new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
return impl(new_param_groups, **kwargs)

It seems to us that only the lr of the hidden layers (those with 2 inf dimensions) is scaled w.r.t. fan_in, but the output weights are ignored. We wonder if this is intended. Thank you!

Embedding Multiplier for Transformer - Clarification

Hi!

I really love the work in muP! I wanted to clarify one point from the GPT-3 sweep in Appendix F.4:
image
Does this multiplier correspond to scaling the output of the input embedding layer:
i.e., inputs_embeds = self.embed_multiplier * self.wte(input_ids)?

I didn't notice this hyperparameter being set in either mutransformers or the Transformer example in this repo, but saw that Appendix F.4 recommended a value of 10 for this param (and that it seems different from the output temperature of the unembedding layer).

Thanks.

Examples with ConvNets

Could this be applied to ConvNets?
I would really love to have an example with ConvNets to understand better how to use the library for my company's use cases.

Conv1D Coord check looks good (I think), but μTransfer does not seem to work?

Hi all, attaching the coord check plots and also a screen shot of the train loss and auprc plots. I used the Conv1D from the branch, but have also tried

I was looking at the conv plots from the examples and I noticed that one of the layers is constant across width, but after the first step is significantly smaller. Is that an issue?

Mup plot
coord_conv_mup

Sp Plot
coord_conv_sp

Train Loss
Screen Shot 2022-08-01 at 1 12 52 PM

Train AUPRC
Screen Shot 2022-08-01 at 1 12 41 PM

I also tried this with a transformer based model and found similar results where the transferred HPs did not result in better performance. I can regenerate those plots if needed.

Is this expected? What can I do to fix this? Having mup work would be a huge unlock for us :D

Does MuReadout apply to all outputs on which loss is computed?

Hi,

I have an autoencoder-like structure, where I have a loss also on the intermediate representation (say z). The loss is computed as L=L_1(x_hat) + L_2(z), where the final output is x_hat, for a regression-style problem. Should I apply MuReadout to the intermediate representation too?

Related question (continuation of issue #3): How are the initialization and learning rate scales for a convolution operation computed according to this method?

Thanks for your help and the super cool project!

Some questions about the implementation of muP.

I have some questions about the implementation of muP, regarding parameter rescaling and hyperparameter transfer. Specifically:

  1. In linear.bias.data *= fanin_mult**0.5: as mentioned in Table 8 of your paper TP V, the bias should scale as O(1/fan_in), but I notice that you multiply by a width_mult here. Is that correct or did I miss something?
  2. Similarly, in self.weight.data *= self.width_mult()**0.5, I think the readout weight should not be rescaled because it is O(1), as in Table 8 of TP V.

I would greatly appreciate it if you could take the time to answer my question!

mu parametrization for multi-head attention / grouped convolution

Hi, in Appendix E.2 - Number of Attention Heads, there is a use case that fixes d_head (the dimension per head) and scales n_head (the number of heads). Do we need to change anything when we use such multi-head attention with scaled n_head? Or do we still follow the same recipe as in the provided Transformer example (scale d_head, change only 1/sqrt(d) to 1/d, and keep other settings the same)?

Similarly, when applying muP to grouped convolution, which keeps the dim size per group fixed and scales the number of groups, is there any special rule we should follow?

Thanks!

LayerNorm Gain and Bias Multipliers

Hi,
I'm wondering how to correctly implement the LayerNorm gain and bias multipliers.

In Section F.3 it is detailed that a hyperparameter search over the LayerNorm gain multiplier and bias multiplier is done in addition to the usual output multiplier and attention logits multiplier. However, I cannot find any example of this being done in the examples/Transformers/models.py module. I also checked the BERT example in the mutransformers repo and could find no evidence of how to correctly implement a LayerNorm gain multiplier and bias multiplier. Sorry if I somehow managed to miss this.

In the 3rd paragraph in Appendix A you note that any parameter tensor in a neural network can be multiplied by a constant c where c is defined as a parameter multiplier. Therefore, I think it is reasonable to deduce that the correct implementation of a LayerNorm gain multiplier is simply:
layernorm_gain_multiplier * torch.nn.LayerNorm(x)
Is this correct?

For the bias multiplier it is not so clear to me how this is correctly implemented. Of course one can just say it would be:
torch.nn.Linear(x, bias=False) + bias * bias_multiplier
but if this is the correct implementation, a few things are not clear to me:

  1. Why is the bias multiplier only being applied to the bias and not the entire linear transformation?
  2. Should this same bias multiplier be applied to all linear transformations? From Table 8 it seems to suggest there's nothing stopping you from doing this.
  3. Should this bias multiplier be applied to output weights that have an output multiplier being applied to it already? Should the bias multiplier in this case effectively be bias_multiplier * output_multiplier ?
  4. Should this bias multiplier be applied to other terms that have a bias-like term? For example, batch and layer normalization.

Thank you for this wonderful paper and excellent repo detailing the correct implementation and cross-checking via coord checking!
-AJ

Warmup schedule when changing the number of tokens/steps (GPT-3 experiment detail)

Hi! I had a few questions regarding the warmup schedule when changing the number of training tokens, as done in the GPT-3 experiments in your work.

  1. For the GPT-3 sweeps, is the batch size kept the same between the proxy model and target model?

  2. For the 40M proxy model, which was trained for 4B and 16B tokens respectively compared to the 300B tokens for the full 6.7B param model, is the warmup period set as a proportion of the total training steps (ex. 1% of training steps) or as an absolute number of steps (ex. 1B steps)?

Interpreting jitter in coordcheck

Hi,

I have coord check plots with mup that look roughly stable but are a bit more jittery compared to the example plots. I was wondering if this can be considered expected behavior.

mup_test

Thanks!

integration with Flax?

Is there any interest in integrating this work with Flax?

They already have an init function, decoupling parameter initialization from model definition, which could make introducing mup fairly plug-and-play.

Plus, they rely on optax for their optimizers. As that library has a focus on composability, you might be able to introduce a transformation that takes an optimizer and makes it mup compatible.

Overall, I believe the Flax ecosystem could make mup more easily accessible to people.

How to use 'attn_mult' config

Hi, thanks for your amazing work!

In the example of using mup in GPT-2: https://github.com/microsoft/mutransformers/tree/main/mutransformers/models/gpt2, I notice that you changed the attention scores from kq / sqrt(d) to kq * attn_mult / d, where attn_mult is a newly added config option (https://github.com/microsoft/mutransformers/blob/main/mutransformers/models/gpt2/modeling_gpt2.py#L205). However, the default value of attn_mult is sqrt(d) (https://github.com/microsoft/mutransformers/blob/main/mutransformers/models/gpt2/configuration_gpt2.py#L199), which brings the attention scores back to kq / sqrt(d).

So why do we need this attn_mult? How should I set its value?

Thanks!

Is it possible to also scale the depth of the model?

In the paper and blog post you provide examples of scaling the number of layers / depth of the model, but when I try doing this in the coordinate check function from the mutransformers package I get the following error.
Is there currently a way to make this work?

AssertionError: `base_shapes` has extra names set(). `shapes` has extra names 
{'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.6.attention.output.dense.bias', 
'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.4.attention.ln.bias', 
'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.bias', 
'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.value.weight', 
'bert.encoder.layer.7.ln.weight', 'bert.encoder.layer.4.attention.self.key.bias', 
'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.5.output.dense.bias', 
'bert.encoder.layer.5.attention.ln.weight', 'bert.encoder.layer.7.attention.self.query.weight', 
'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 
'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 
'bert.encoder.layer.4.attention.self.value.weight', 'bert.encoder.layer.7.ln.bias', 
'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.key.weight', 
'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.6.attention.self.key.bias', 
'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.attention.ln.bias', 
'bert.encoder.layer.4.attention.ln.weight', 'bert.encoder.layer.7.attention.self.key.bias', 
'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.6.ln.bias', 
'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 
'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.6.attention.self.value.bias', 
'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.dense.weight', 
'bert.encoder.layer.5.ln.weight', 'bert.encoder.layer.5.intermediate.dense.bias', 
'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.4.ln.weight', 
'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.5.attention.self.query.bias', 
'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.4.output.dense.weight', 
'bert.encoder.layer.4.ln.bias', 'bert.encoder.layer.6.attention.ln.weight', 
'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.7.attention.ln.bias', 
'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.4.attention.self.value.bias', 
'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.5.attention.ln.bias', 
'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.6.ln.weight', 
'bert.encoder.layer.7.attention.ln.weight', 'bert.encoder.layer.4.attention.output.dense.weight', 
'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.5.ln.bias', 
'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 
'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.4.attention.output.dense.bias', 
'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.6.output.dense.bias'}.

Finetuning a Pretrained Model Using MuP

Somewhat of a naive question, but say we have pretrained a model and now want to finetune it on a downstream task. Is there any reason we shouldn't replace the MuP layers with the equivalent torch layers? I have to imagine that we don't need to use MuP here, but want to make sure that this doesn't break anything if we replace them

Multiple nn.Linear layers

Hi,
Your project is really interesting, so I am learning how to apply it to some specific models.
For example, a model may have multiple nn.Linear layers, as in wav2vec 2.0 (self.post_extract_proj, self.project_q, self.project_inp, self.target_glu, self.final_proj); should I replace all of these layers with MuReadout?

class Wav2Vec2Model(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2Config):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=cfg.quantizer_depth,
                weight_proj_factor=cfg.quantizer_factor,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                    weight_proj_depth=cfg.quantizer_depth,
                    weight_proj_factor=cfg.quantizer_factor,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

Thank you! ^^
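For reference, the assertion quoted further down this page ("has infinite fan-in and finite fan-out dimensions but is not type MuReadout") suggests the rule of thumb: only linear layers whose fan-in scales with width while their fan-out stays fixed act as readouts. Below is a hypothetical helper (not part of mup or fairseq) that flags such layers by comparing a base-width model against a wider one built from a scaled config.

import torch.nn as nn

def readout_candidates(base_model: nn.Module, wide_model: nn.Module):
    """Names of nn.Linear layers whose in_features grow with width but whose
    out_features stay fixed -- these are the MuReadout candidates."""
    base_modules = dict(base_model.named_modules())
    names = []
    for name, wide in wide_model.named_modules():
        narrow = base_modules.get(name)
        if isinstance(wide, nn.Linear) and isinstance(narrow, nn.Linear):
            if (wide.in_features != narrow.in_features
                    and wide.out_features == narrow.out_features):
                names.append(name)
    return names

Whether, e.g., final_proj qualifies then depends purely on whether final_dim is held fixed or scaled along with encoder_embed_dim in the configs being compared.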

PyTorch Lightning example

Dear team behind mup,

This is some great work! I believe providing a PyTorch Lightning example could help users adopt this library.

I also wonder if this technique could be embedded in an even less boilerplate-heavy way. I was thinking about an extension to the PyTorch Lightning Tuner that would automatically apply mup and the µTransferable hyperparameters.

I wondered if someone from the mup team would be interested in investigating these ideas to make this work even more widely accessible.

Best,
T.C

Does mup support Swin Transformer v2 model?

Hi, we are trying to use the mup tool to tune the Swin Transformer v2 model.
I modified the code of Swin Transformer v2 to adapt it to mup and ran the "save base shape" and "coordinate check" steps.
The results of the "coordinate check" show that it does not meet the requirements of mup.

Does mup support the Swin Transformer v2 model?

For the code of "swin_transformer_v2.py", I modified the following (because Swin Transformer v2 doesn't use 1/sqrt(d) attention scaling, I didn't modify that part):

  1. replaced the output layer nn.Linear with MuReadout
  2. replaced the fixed-std normal init with mup's normal init
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
### muP: replace nn.Linear with MuReadout
self.head = MuReadout(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

self.apply(self._init_weights)
for bly in self.layers:
    bly._init_respostnorm()

def _init_weights(self, m, readout_zero_init=False, query_zero_init=False):
    ### muP: swap constant std normal init with normal_ from `mup.init`.
    ### Because `_init_weights` is called in `__init__`, before `infshape` is set,
    ### we need to manually call `self.apply(self._init_weights)` after calling
    ### `set_base_shape(model, base)`
    if isinstance(m, nn.Linear):
        if isinstance(m, MuReadout) and readout_zero_init:
            m.weight.data.zero_()
        else:
            if hasattr(m.weight, 'infshape'):
                normal_(m.weight, mean=0.0, std=.02)
            else:
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    ### End muP
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

For the code of "main.py" of Swin Transformer, I added "save base shape" and "coordinate check" functions.

def main(config, args):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))

    ### muP
    if args.save_base_shapes:
        print(f'saving base shapes at {args.save_base_shapes}')
        base_shapes = get_shapes(model)
        delta_config = copy.deepcopy(config)
        delta_config.defrost()
        delta_config.MODEL.SWINV2.EMBED_DIM *= 2  # Modify SwinV2 embed dim
        delta_config.MODEL.SWIN.EMBED_DIM *= 2  # Modify Swin embed dim
        # delta_config.MODEL.SWIN_MOE.EMBED_DIM *= 2  # Modify Swin_moe embed dim
        delta_config.MODEL.SWIN_MLP.EMBED_DIM *= 2  # Modify Swin_mlp embed dim

        delta_shapes = get_shapes(
            # just need to change whatever dimension(s) we are scaling
            build_model(delta_config)
        )
        make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
        print('done and exit')
        import sys
        sys.exit()
    if args.load_base_shapes:
        print(f'loading base shapes from {args.load_base_shapes}')
        set_base_shapes(model, args.load_base_shapes)
        print('done')
    else:
        print(f'using own shapes')
        set_base_shapes(model, None)
        print('done')
### muP
def coord_check(mup, config, lr, optimizer, nsteps, nseeds, args, plotdir='', legend=False):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    def gen(w, standparam=False):
        def f():
            delta_config = copy.deepcopy(config)
            delta_config.defrost()
            delta_config.MODEL.SWINV2.EMBED_DIM = w  # Modify SwinV2 embed dim
            delta_config.MODEL.SWIN.EMBED_DIM = w  # Modify Swin embed dim
            # delta_config.MODEL.SWIN_MOE.EMBED_DIM = w  # Modify Swin_moe embed dim
            delta_config.MODEL.SWIN_MLP.EMBED_DIM = w  # Modify Swin_mlp embed dim
            model = build_model(delta_config)

            if standparam:
                set_base_shapes(model, None)
            else:
                assert args.load_base_shapes, 'load_base_shapes needs to be nonempty'
                set_base_shapes(model, args.load_base_shapes)
            return model
        return f

    optimizer = optimizer.replace('mu', '')
    widths = (12, 24, 48, 96, 192)
    models = {w: gen(w, standparam=not mup) for w in widths}

    # train_data = batchify(corpus.train, batch_size, device=args.device)
    df = get_coord_data(models, data_loader_train, mup=mup, lr=lr, optimizer=optimizer, flatten_output=True,
                        nseeds=nseeds, nsteps=nsteps, lossfn='xent')

    prm = 'muP' if mup else 'SP'
    return plot_coord_data(df, legend=legend,
                           save_to=os.path.join(plotdir, f'{prm.lower()}_trsfmr_{optimizer}_coord.png'),
                           suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
                           face_color='xkcd:light grey' if not mup else None)
if __name__ == '__main__':
    args, config = parse_option()

    ......

    ### muP
    if args.coord_check:
        print('testing parametrization')
        import os
        os.makedirs('coord_checks', exist_ok=True)
        plotdir = 'coord_checks'
        coord_check(mup=True, config=config, lr=0.0001, optimizer='adamw',
                    nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args,
                    plotdir=plotdir, legend=False)
        coord_check(mup=False, config=config, lr=0.0001, optimizer='adamw',
                    nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args,
                    plotdir=plotdir, legend=False)
        import sys
        sys.exit()

    main(config, args)

The results of the "coordinate check" show only a small difference between "mup" and "SP". Sorry, I can't upload pictures.
Could you please help us check whether mup can support the Swin Transformer v2 model, or whether there is some other reason for this? Thanks a lot.
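One detail worth double-checking, per the comment inside _init_weights above: the mup initialization has to be re-applied after set_base_shapes attaches the infshapes, but the main() shown here never calls it again. A minimal sketch of that extra step, reusing the model and args names from the snippet (whether this explains the reported coord-check result is not established here):

    if args.load_base_shapes:
        print(f'loading base shapes from {args.load_base_shapes}')
        set_base_shapes(model, args.load_base_shapes)
        ### muP: re-apply the init now that infshapes exist, as noted in _init_weights
        model.apply(model._init_weights)  # normal_ from mup.init now sees `infshape`
        print('done')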

Coord check looks good, but μTransfer is not working as expected

Hello, μP team! Very excited to see you open source your excellent work! I was looking to apply μP to our work: on Megatron-DeepSpeed I modified the training script as suggested in the tutorial, set the infshapes, reset the parameter initialization, switched to MuAdam, and got a coord check that looked successful. But when we transferred the learning rate that performed well on the 350M GPT model to the larger 1.3B model, we found that the 1.3B model could not withstand such a large learning rate and eventually produced NaNs.

I was wondering what details might not have been taken into account, or which conditions were not met, causing μTransfer to fail. How should I debug this, or does μTransfer simply not work under these conditions?

The following is the experimental information.

[Experiment configuration screenshots omitted.]

350M -> 1.3B GPT model μTransfer training loss (TensorBoard link): [plot omitted]

I think it may be a bit redundant, but if you are interested, the transformation of μP is here:

  1. Replace output layer with MuReadout, https://github.com/shjwudp/Megatron-LM/blob/mup/megatron/model/gpt_model.py#L250
  2. Make sure to use 1/d instead of 1/sqrt(d) attention scaling, https://github.com/shjwudp/Megatron-LM/blob/mup/megatron/model/transformer.py#L175
  3. Set the infshapes and do the mup parameter initialization, https://github.com/shjwudp/Megatron-LM/blob/mup/pretrain_gpt.py#L110
  4. Put on MuAdam, https://github.com/shjwudp/Megatron-LM/blob/mup/megatron/optimizer/__init__.py#L65
  5. Implement the equivalent of the MuReadout._rescale_parameters operation (a rough sketch of this rescaling follows the list), https://github.com/shjwudp/Megatron-LM/blob/mup/megatron/mpu/layers.py#L191
  6. Modify lr scheduler to update lr according to width, https://github.com/shjwudp/Megatron-LM/blob/mup/megatron/learning_rates.py#L127
  7. Coord check, https://github.com/shjwudp/Megatron-LM/blob/mup/megatron/mup_utils.py#L16
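For item 5, the operation being mirrored is the readout rescaling that MuReadout performs once base shapes are set. Per the description of _rescale_parameters() later on this page (weights multiplied by width_mult()**0.5), a rough sketch of the equivalent, with bias handling omitted and width_mult() assumed available on the readout as in mup:

import torch

def rescale_readout_(readout):
    """Sketch of MuReadout._rescale_parameters: scale the readout weight by
    width_mult()**0.5 once infshapes are attached."""
    with torch.no_grad():
        readout.weight.mul_(readout.width_mult() ** 0.5)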

FSDP support?

Hi, does mup support training with FSDP? I got a model training with DDP, but when switching to FSDP I get the following assertion.

assert hasattr(self.weight, 'infshape'), (
AssertionError: Please call set_base_shapes(...). If using torch.nn.DataParallel, switch to distributed training with torch.nn.parallel.DistributedDataParallel instead

Reproducing the training loss vs learning rates curve on MLP

Hello!
We tried to reproduce the experiment in your paper (Figure 3: MLPs with different hidden sizes trained for 20 epochs on CIFAR-10 using SGD). We made some modifications to examples/MLP/main.py:

nonlin = torch.relu
criterion = F.cross_entropy
method = 'mup' if args.load_base_shapes else 'sp'
for width in [64, 512, 4096]:
    for lr in np.logspace(-14, -2, 13, base=2):
        ...

And we ran the following commands:

# sp
python main.py

# mup
python main.py --load_base_shapes width64.bsh 

However, we didn't observe the shift of the best LR across widths in SP. There doesn't seem to be much difference between SP and MuP. Is there anything wrong with our implementation? Thanks.

[Attached: SP and μP results (logs_sp and logs_mup JSON logs).]

Questions on learning schedule and binary classification

I'm trying mup on a deep transformer structure and have the following questions:

  1. Warmup Ratio
    I tuned the learning rate alone at width=256 and transferred the result to width=512, but the training curve diverges. I ran coord_check and there seems to be no problem. Then I increased the warmup ratio from 0.01 to 0.1 for width=512 and the loss converges. I was wondering whether the warmup ratio is a μTransferable hyperparameter, or one that must be re-tuned at each width (seemingly the latter, according to my experiments above?). Do you have theoretical insights or experience on this?

Specifically, say your full training is 1M steps and you do hyperparameter tuning with 0.1M steps: how did you set the warmup ratio on both sides? Also, I'd like to confirm that with a linear schedule, the learning rate decays to 0 at the end of each training run (i.e., with a sharper decrease for the 0.1M-step run); a small sketch of such a schedule follows below.

  2. Binary Classification
    If I have a binary classification head (linear + softmax) trained along with the language model, and follow the practice of all-zero initialization for this head, the weights and logits of the two output neurons will always be "x and -x" throughout training due to the gradient property of binary softmax. Although I'm not sure whether this actually has bad effects, I was wondering whether all-zero initialization is necessary for a classification head in transformer pretraining.

Help would be appreciated!
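To make the schedule described in question 1 concrete, here is a small sketch of a linear warmup followed by a linear decay to zero, so a 0.1M-step tuning run decays proportionally faster than a 1M-step full run (names like total_steps and warmup_ratio are illustrative, not taken from any particular codebase):

import torch

def linear_warmup_then_decay(optimizer, total_steps, warmup_ratio):
    """Linear warmup for warmup_ratio * total_steps, then linear decay to 0."""
    warmup_steps = max(1, int(warmup_ratio * total_steps))

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)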

Hyperparameter search on base models

Following up on the conversation here #11 since it wasn't related to the original issue.

How exactly should the learning rates be split up when doing hyperparameter search on the base model? You said input/hidden/output, but Table 8 groups all biases with input weights. Does the output bias also fall into the input/bias group? (it has finite fan-in and fan-out, unlike the other biases which have infinite fan-out)

Unclear `assert_hidden_size_inf` triggers

My code is triggering the "has infinite fan-in and finite fan-out dimensions but is not type MuReadout" assertion in "non-obvious" situations (i.e., not on the last linear layer of the model):

  • Often it happens on the first linear layer of a module when I change the size of the last linear layer;
  • Sometimes it happens on intermediate attention layers of a module.

What am I doing wrong? Is there a good way to debug those situations?
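One generic way to debug these cases is to dump the infshape that set_base_shapes attaches to every parameter and look for entries whose fan-in scales with width while their fan-out stays fixed; a minimal sketch (the printing format is just illustrative):

def dump_infshapes(model):
    """Print every parameter's shape next to its mup infshape so that layers
    with infinite fan-in but finite fan-out are easy to spot."""
    for name, p in model.named_parameters():
        infshape = getattr(p, 'infshape', None)  # attached by set_base_shapes
        print(f'{name:60s} shape={tuple(p.shape)} infshape={infshape}')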

_rescale_parameters() inconsistent with the paper for the tied embedding scenario?

Hi! I've been looking into the integration of muP into the Megatron-LM setup and I was wondering about the _rescale_parameters() method of MuReadout in case of shared (tied) input/output embeddings. Specifically, in the Transformer example I am not really sure that it is in line with the suggested embedding initialisation (i.e., constant) from the paper.

Currently, in the example:

  • encoder is initialised from N(0,1) <- default nn.Embedding init
    self.encoder = nn.Embedding(ntoken, ninp)
  • decoder is firstly initialised within MuSharedReadout from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) <- default nn.Linear init
    super().__init__(*weight.shape, bias=bias, **kwargs)
  • but then the decoder weights are overwritten with those from the encoder (on the next line, line 68) -> they end up with the N(0,1) init
  • finally, once set_base_shapes() is called, both encoder and decoder weights will be rescaled within _rescale_parameters() by *= self.width_mult()**0.5 -> which makes them initialised from N(0, sqrt(d/d_0)) and so scale with width.

However, in the muP paper it is suggested to initialise them as constants to be muP compatible. It also should be mentioned that in the untied case, the output embeddings are set to 0, so _rescale_parameters() doesn't have an effect and things are consistent with the paper.

Below I also attach the coordinate check plots for the Transformer example for untied, tied+rescaling (current implementation) and tied+no rescaling (_rescale_parameters() disabled), respectively. One can see that for untied the norms are nicely flat, for tied+rescaling some layers have growing activations, and for tied+no rescaling one layer has a vanishing trend.

So I was wondering if _rescale_parameters() should be disabled for the tied embedding scenario to keep the init constant, assuming the inheritance of N(0,1) initialisation in nn.Embedding()?

[Attached coordinate-check plots: untied, tied + rescaling, tied + no rescaling.]
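For anyone who wants to reproduce the "tied + no rescaling" variant above: since, as described in this issue, set_base_shapes() is what triggers _rescale_parameters() on the readout, one way to disable the rescaling for just the shared readout is a per-instance monkey-patch (a sketch, not an endorsed fix; model.decoder and the base-shapes argument follow the Transformer example):

### Sketch: skip _rescale_parameters() for the tied readout only, so the
### shared N(0, 1) embedding initialisation is left untouched.
model.decoder._rescale_parameters = lambda: None  # per-instance no-op
set_base_shapes(model, base_shapes)               # base-shapes file/path as used elsewhere in the example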

in mlp example: 2 problems

  1. https://github.com/microsoft/mup/blob/main/examples/MLP/main.py#L61
    If you don't specify a base shape file, then you are using standard parametrization; yet in the code, the optimizer still uses MuSGD? https://github.com/microsoft/mup/blob/main/examples/MLP/main.py#L257

  2. Why does the init function not use mup.init?
    https://github.com/microsoft/mup/blob/main/examples/MLP/main.py#L139

### Replace your custom init, if any
for param in model.parameters():
    ### If initializing manually with fixed std or bounds,
    ### then replace with same function from mup.init
    # torch.nn.init.uniform_(param, -0.1, 0.1)
    mup.init.uniform_(param, -0.1, 0.1)

interpreting coord checks

Hi there, I'm working on a flax port of this and I'm trying to use the coord check scripts on a variant of your MLP example to see if I've done it right. I'm struggling to interpret the results though:

[Attached coord-check plots: sp_mlp_sgd_coord (SP) and μp_mlp_sgd_coord (μP).]

The point I'm confused about is the green line at step 1 in the μP graph: if I understood your paper correctly, this should be a flat line, right?
Looking through my code, I can't spot a mistake, so I have to ask: is my assumption about step 1 of the coord check wrong?

Coord-check for conv1d

I modified the muconv2d branch to get a conv1d variant of the output layer for mup, and I applied it to a shallow variant of a unet model I've been testing.

repo for model: https://github.com/bob80333/audio_bandwidth_extension
fork of mup with conv1d: https://github.com/bob80333/mup/tree/muconv2d

Here are the coord-check results; they don't look quite as smooth as in the paper, but there's definitely a big difference between mup and no mup.

[Attached coord-check plots: plot_mup (mup) and plot_nomup (no mup).]

Does this look about right for the coordinate check? The figures I saw in the example looked much smoother than this.
