
muP (maximum update parametrization) #650

Open · wants to merge 37 commits into master

Conversation

gordicaleksa
Contributor

@gordicaleksa commented Jun 26, 2024

Main changes (see mup.md file for more details):

  • Modify random initialization
  • Scale attention scores by 1/d instead of 1/sqrt(d), and add a tunable attn_mult coefficient
  • Scale activations by 1/width_mult before mapping into logits (both scalings are sketched below)
  • Update learning rate & weight decay for a subset of layers
  • Add a coordinate check test - it's like a gradient check, but for muP

where:

  • width_mult is the ratio of the width of the current model to that of the base model
  • d is the number of channels in a single attention head
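
For illustration, here is a minimal PyTorch sketch of the two forward-pass scalings listed above (the function names and call sites are mine, not the PR's; the actual changes live in the CUDA kernels and train_gpt2.py):

```python
import math

def attention_scores(q, k, attn_mult=1.0, use_mup=True):
    # q, k: (batch, heads, seq, head_dim)
    d = q.size(-1)
    if use_mup:
        # muP: scale q.k by attn_mult / d instead of the usual 1 / sqrt(d)
        return (q @ k.transpose(-2, -1)) * (attn_mult / d)
    return (q @ k.transpose(-2, -1)) / math.sqrt(d)

def readout_logits(x, lm_head_weight, width_mult=1.0, use_mup=True):
    # x: (batch, seq, channels); lm_head_weight: (vocab_size, channels)
    if use_mup:
        x = x / width_mult  # scale activations by 1/width_mult before the readout
    return x @ lm_head_weight.t()
```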

Test

To test muP vs SP (standard parametrization):

  • Run the scripts/mup_coordinate_check.sh script
  • Run the dev/mup_coordinate_check_visualize.py script
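
For intuition about what the coordinate check measures, here is a generic toy sketch (not the repo's script; a small MLP stands in for GPT-2): record the mean absolute activation of each layer over the first few training steps at several widths. Under muP these magnitudes should stay roughly O(1) as width grows, whereas under SP they typically drift with width.

```python
import torch
import torch.nn as nn

def coord_check(widths=(64, 128, 256, 512), steps=3, lr=1e-3, seed=0):
    """Toy coordinate check: mean |activation| per layer, per width."""
    results = {}
    for width in widths:
        torch.manual_seed(seed)
        model = nn.Sequential(nn.Linear(32, width), nn.ReLU(),
                              nn.Linear(width, width), nn.ReLU(),
                              nn.Linear(width, 10))
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        acts = []

        def hook(_module, _inputs, output):
            acts.append(output.abs().mean().item())

        for layer in model:
            layer.register_forward_hook(hook)

        for _ in range(steps):
            acts.clear()                      # keep only the current step's stats
            x = torch.randn(16, 32)
            y = torch.randint(0, 10, (16,))
            loss = nn.functional.cross_entropy(model(x), y)
            opt.zero_grad(); loss.backward(); opt.step()
        results[width] = list(acts)           # per-layer mean |activation|, last step
    return results

if __name__ == "__main__":
    for width, per_layer in coord_check().items():
        print(width, ["%.3f" % a for a in per_layer])
```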

Run

  • Set use_mup to 1
  • Set mup_width_mult to the ratio of the width of your target model to that of your base model (see the example below)
  • mup_base_attn_mult is a tunable param; 1 seems to work nicely for our family of models
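
For example (illustrative numbers, taking "width" to mean the model's channel count): if the base model has 256 channels and the target model has 1024, you would set mup_width_mult = 1024 / 256 = 4.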

Ablations

The coord check results are highly dependent on the learning rate and max width used.

In my preliminary ablations (max width = 1024 & lr = 0.0006) I concluded that the only thing that would mess up the coordinate check was this line:
scale = (model->use_mup && i != 0 && i != 1) ? mup_scale_inv*scale : scale;

In my subsequent ablations (max width = 1024 & lr = 0.006, i.e. an lr almost the same as in the reference muP GPT-2 implementation, which uses 0.01) I concluded that the results are much more sensitive: the Adam modifications also matter, the 1/width_mult logit scaling matters, and so does whether we use 1/d or 1/sqrt(d) attention scaling.
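
For context, the "Adam modifications" refer to the usual muP recipe of scaling the Adam learning rate of matrix-like hidden weights by 1/width_mult. A hedged PyTorch sketch of what such param grouping could look like (the layer-selection predicate and the weight-decay handling are my assumptions, not the PR's exact logic, which lives in the C/CUDA code):

```python
import torch

def mup_param_groups(model, base_lr, base_wd, width_mult, is_matrix_like):
    """Split parameters into muP-scaled (hidden matrix weights) and regular ones.
    `is_matrix_like` is a hypothetical predicate over (name, param)."""
    scaled, regular = [], []
    for name, param in model.named_parameters():
        (scaled if is_matrix_like(name, param) else regular).append(param)
    return [
        # muP (Adam): hidden-weight lr is scaled by 1/width_mult.
        {"params": scaled, "lr": base_lr / width_mult, "weight_decay": base_wd},
        # Embeddings, biases, layernorm params keep the base lr.
        {"params": regular, "lr": base_lr, "weight_decay": base_wd},
    ]

# Usage sketch (predicate is illustrative):
# groups = mup_param_groups(model, base_lr=6e-4, base_wd=0.1, width_mult=4.0,
#                           is_matrix_like=lambda n, p: p.ndim == 2 and "wte" not in n)
# optimizer = torch.optim.AdamW(groups)
```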

See the next comment for more thorough ablation results.


gordicaleksa force-pushed the mup branch 3 times, most recently from ce71c19 to 9864277 on June 29, 2024
gordicaleksa changed the title from "muP (maximum update parametrization)" to "[WIP] muP (maximum update parametrization)" on Jun 29, 2024
@gordicaleksa
Contributor Author

Ablation study

Setting:
lr = 0.006
width goes from 64 -> 4096 (geometric sequence, 2x coefficient)

Baseline muP:
[coordinate check plot]

At step 4 it seems we observe a bit more oscillation. This does disappear with a slightly lower learning rate, so I'm not sure what to make of it. Ultimately, running a sweep of runs and confirming we have stable HPs will be the final test.

Here is what happens when setting attn_mult to 1. It looks like it's a bit more stable? My implementation follows that of mutransformers, but it's possible they had a bug in the attn_mult logic; I'll have to consult the paper again:
[coordinate check plot]

SP (standard parametrization) baseline:
[coordinate check plot]
These clearly explode.


Comment out the zeroing of the embedding/readout layer & queries:
[coordinate check plot]
Observe that one of the layers is decreasing with width, which is undesirable.

Remove learning rate / weight decay modulating logic:
[coordinate check plot]

Remove scale = (model->use_mup && i != 0 && i != 1) ? mup_scale_inv*scale : scale;:
[coordinate check plot]

Remove logit scaling by 1/width_mult:
[coordinate check plot]

Set attn_mult to 1 and replace 1/d q*k scaling with the usual 1/sqrt(d):
[coordinate check plot]

@alxndrTL

alxndrTL commented Jul 15, 2024

Hello, I looked over your implementation of muP in both CUDA and Python, and found that in CUDA (layernorm.cuh) you scale the output of all the layernorms by mup_scale (which is defined as sqrt(mup_width_multiplier)):

[screenshot of the layernorm.cuh scaling code]

but I haven't seen the same scaling in your PyTorch version (train_gpt2.py) (which is very concise and clear btw!).
Also, this is not mentioned in the doc/mup/mup.md file, so I was wondering where it comes from. Maybe I missed something.
Thank you

@gordicaleksa
Contributor Author

gordicaleksa commented Jul 19, 2024

@alxndrTL it is mentioned in mup.md, under 3.3 (1. / model->mup_width_mult). I don't use sqrt in the C/CUDA code?

@alxndrTL

Ok, I didn't realize that the layernorm code I showcased is only used pre-logits, as per 3.3 (I thought it was applied to every layernorm).

@gordicaleksa
Contributor Author

gordicaleksa commented Jul 19, 2024

Hyperparam sweeps

Note:

  • y-axis is always training loss unless mentioned otherwise.
  • These checkpoints were trained to convergence on a 10B FineWeb subset.
  • For more info/experiments check out the following Discord thread and a few threads below it.

scheduler sweep:
[scheduler sweep plot]

Conclusion: cosine is a good choice.

attn_mult tunable param sweep:
[attn_mult sweep plot]

Conclusion: Using 1 is a good default.

lr sweeps (note x-axis should be parsed as 1/2^x):
[lr sweep plot]

Conclusion: ~1/2^10 is a sweet spot for lr. The curves are stable as we increase the depth, i.e. the optimal lr is invariant to depth scaling.
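(For reference on that axis convention: x = 10 corresponds to lr = 1/2^10 ≈ 0.00098.)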

out_mult sweep (out_mult is currently not supported in this PR, but it's a minor tweak I've implemented locally):
[out_mult sweep plot]

Conclusion: 1 is a good default.

next steps:

cc: @karpathy

@gordicaleksa
Contributor Author

@YuchenJin would be great to kick off a 7B mup run if you have some bandwidth! :)

@YuchenJin
Contributor

> @YuchenJin would be great to kick off a 7B mup run if you have some bandwidth! :)

Hey @gordicaleksa, happy to! Do you want me to just run the two scripts (scripts/mup_coordinate_check.sh and dev/mup_coordinate_check_visualize.py)? What LR should I use for the 7B model?

@habanoz

habanoz commented Nov 5, 2024

@gordicaleksa

if self.config.use_mup: torch.nn.init.zeros_(module.weight)

AFAIK this line zero-initializes modules with the 'LLMC_SKIP_INIT' flag if muP is enabled. There is only one module with the 'LLMC_SKIP_INIT' flag: lm_head, and lm_head's weight is tied to wte.weight.

Since embedding layers are initialized later in the code, what is the purpose of the zero initialization referenced above?

@habanoz

habanoz commented Nov 6, 2024

> @gordicaleksa
>
> if self.config.use_mup: torch.nn.init.zeros_(module.weight)
>
> AFAIK this line zero-initializes modules with the 'LLMC_SKIP_INIT' flag if muP is enabled. There is only one module with the 'LLMC_SKIP_INIT' flag: lm_head, and lm_head's weight is tied to wte.weight.
>
> Since embedding layers are initialized later in the code, what is the purpose of the zero initialization referenced above?

After reading mup.md, I can see now that muP requires output layers to be initialized to zero.

The code assumes that embeddings are initialized before linear layers, which is correct, but IMHO a weak assumption.
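
(To make the ordering concern concrete, here is a toy illustration, not the repo's code: because lm_head.weight is tied to wte.weight, the two names refer to one shared tensor, so whichever init call runs last determines the final values.)

```python
import torch.nn as nn

vocab_size, channels = 8, 4
wte = nn.Embedding(vocab_size, channels)
lm_head = nn.Linear(channels, vocab_size, bias=False)
lm_head.weight = wte.weight             # weight tying: one tensor, two names

nn.init.normal_(wte.weight, std=0.02)   # embedding init runs first...
nn.init.zeros_(lm_head.weight)          # ...then the muP zero-init of the readout
print(wte.weight.abs().sum().item())    # 0.0 -- the tied embedding is zeroed too

# If the two init calls ran in the opposite order, the later normal_() would
# silently overwrite the zeros -- hence the ordering assumption matters.
```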

Thanks for the great work.
