The training curves show MLA-o has roughly 2x higher gradient norms compared to MLA. Does this tell us something?
Thoughts:
- Where are the added gradients coming from: spread across the whole model, or concentrated in the output decomposition? (A per-module breakdown would answer this; see the sketch after this list.)
- Maybe try comparing the norms of the output weights between the two models (second sketch below).
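A minimal sketch for the first idea, assuming PyTorch; the grouping by name prefix is generic, and module names like `attn.o_proj` are hypothetical and should be replaced with whatever the actual MLA / MLA-o implementations use. Call it right after `loss.backward()` and compare the per-group norms between the two runs.

```python
from collections import defaultdict

import torch


def grad_norms_by_prefix(model: torch.nn.Module, depth: int = 2) -> dict[str, float]:
    """Group squared gradient norms by the first `depth` components of each
    parameter name, then return the per-group L2 norms."""
    sq_norms: dict[str, float] = defaultdict(float)
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        group = ".".join(name.split(".")[:depth])
        sq_norms[group] += param.grad.detach().float().norm().item() ** 2
    return {group: total ** 0.5 for group, total in sq_norms.items()}


# Hypothetical usage: run once per model after a backward pass, before optimizer.step().
# for group, norm in sorted(grad_norms_by_prefix(model).items()):
#     print(f"{group}: {norm:.4f}")
```

If the extra norm in MLA-o shows up mostly under the output-decomposition parameters rather than uniformly, that would point at the decomposition itself rather than a global scale difference.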
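And a rough sketch for the second idea: compare the norms of the output weights between the two checkpoints. Loading from saved state dicts is assumed here, and the key substring `o_proj` is a placeholder for however the output weights are actually named in MLA and MLA-o.

```python
import torch


def weight_norms(state_dict: dict[str, torch.Tensor], key_substring: str = "o_proj") -> dict[str, float]:
    """Return the L2 norm of every parameter whose name contains `key_substring`."""
    return {
        name: tensor.float().norm().item()
        for name, tensor in state_dict.items()
        if key_substring in name
    }


# Hypothetical usage: checkpoint paths and the key substring are placeholders.
mla = torch.load("mla.pt", map_location="cpu")
mla_o = torch.load("mla_o.pt", map_location="cpu")
for name, norm in weight_norms(mla).items():
    print(f"MLA   {name}: {norm:.4f}")
for name, norm in weight_norms(mla_o).items():
    print(f"MLA-o {name}: {norm:.4f}")
```

If the output weights in MLA-o sit at systematically smaller norms, larger gradients there would be less surprising, since gradient and weight scale tend to trade off under norm-based reparameterizations.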