
Integrate xlstm cleanly. #35377

Open · kpoeppel wants to merge 50 commits into main from integrate_xlstm_clean
Conversation

kpoeppel

What does this PR do?

This PR integrates xLSTM via the xlstm library, including certain optimizations (optionally using torch.compile and CUDA graphs for speed-ups). This enables using NX-AI/xLSTM-7b without a special fork of transformers.
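For reference, a minimal usage sketch (not part of the diff), assuming the xLSTM architecture is registered with the Auto classes once this PR is merged:

```python
# Minimal sketch, assuming the xLSTM architecture is registered with the Auto classes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
model = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", torch_dtype=torch.bfloat16)

inputs = tokenizer("xLSTM is a recurrent architecture that", return_tensors="pt")
# The recurrent state is carried in the model's cache across generation steps.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```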

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests? Yes, I adapted the tests of the recurrent Mamba2 model.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@kpoeppel force-pushed the integrate_xlstm_clean branch 4 times, most recently from 4a4e347 to fe5759b on December 21, 2024 at 13:41
@stevhliu (Member) left a comment:

Thanks!

@kpoeppel (Author) commented:

How can this PR keep failing on other models' tests, and how did those failing tests get into 'main'?

@Cyrilvallez self-assigned this Jan 13, 2025
@Cyrilvallez self-requested a review January 13, 2025 10:55
@gante (Member) left a comment:

comment on the text-generation-level changes: the xLSTMCache class is missing proper documentation (see MambaCache for an example) and should be added to __init__.py
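For illustration, a rough sketch of the kind of documented cache skeleton being requested here, modeled loosely on MambaCache; the attribute names and shapes are assumptions, not the PR's actual implementation:

```python
# Illustrative sketch only (modeled loosely on MambaCache); attribute names and
# shapes are assumptions, not the PR's actual xLSTMCache implementation.
import torch


class xLSTMCache:
    """
    Cache for the xLSTM model, holding the recurrent per-block states that are
    carried between generation steps.

    Arguments:
        config: model configuration, used to size the per-block states.
        batch_size (`int`): batch size the cache is allocated for.
        dtype (`torch.dtype`, *optional*): dtype of the cached states.
        device (`torch.device` or `str`, *optional*): device the states live on.
    """

    def __init__(self, config, batch_size: int, dtype=torch.float32, device=None):
        self.dtype = dtype
        # One recurrent state per block; the exact shape depends on the backend.
        self.rnn_state = {
            block: torch.zeros(batch_size, config.num_heads, dtype=dtype, device=device)
            for block in range(config.num_blocks)
        }

    def reset(self):
        # Zero the states before starting a new generation.
        for state in self.rnn_state.values():
            state.zero_()
```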

@ArthurZucker (Collaborator) left a comment:

Super grateful for the PR! In general we're happy to have more architectures, but we want to make sure we are aligned in terms of what can be done! 🤗

Comment on lines 219 to 248
xlstm_block_config = xLSTMLargeConfig(
    vocab_size=config.vocab_size,
    embedding_dim=config.embedding_dim,
    num_blocks=config.num_blocks,
    num_heads=config.num_heads,
    use_bias=config.use_bias,
    add_out_norm=config.add_out_norm,
    norm_eps=config.norm_eps,
    norm_reduction_force_float32=config.norm_reduction_force_float32,
    # mlstm_layer
    qk_dim_factor=config.qk_dim_factor,
    v_dim_factor=config.v_dim_factor,
    # mlstm backend
    chunkwise_kernel=config.chunkwise_kernel,
    sequence_kernel=config.sequence_kernel,
    step_kernel=config.step_kernel,
    mode=config.mode,
    chunk_size=config.chunk_size,
    return_last_states=config.return_last_states,
    autocast_kernel_dtype=config.autocast_kernel_dtype,
    eps=config.eps,
    inference_state_dtype=config.inference_state_dtype,
    # feedforward
    ffn_proj_factor=config.ffn_proj_factor,
    ffn_round_up_to_multiple_of=config.ffn_round_up_to_multiple_of,
    # capping
    gate_soft_cap=config.gate_soft_cap,
    output_logit_soft_cap=config.output_logit_soft_cap,
    weight_mode=config.weight_mode,
)
A collaborator commented:

we should align xLSTMLargeConfig to match the inputs of mLSTMBlock

A collaborator commented:

and then we would not have to do this here

@kpoeppel (Author) replied:

There are still slight deviations between the xLSTMLargeConfig and the xLSTMConfig in configuration_xlstm.py, so I think this conversion is actually necessary.

A collaborator commented:

Some refactoring needed for the camel casing of classes!

@kpoeppel (Author) replied:

Is the casing ok as it is now?

@superbock force-pushed the integrate_xlstm_clean branch 3 times, most recently from 4545076 to 69b9d5c on March 27, 2025 at 16:01
@superbock force-pushed the integrate_xlstm_clean branch from b925fc6 to 3f6aec9 on March 28, 2025 at 10:25
@kpoeppel (Author) commented Apr 5, 2025:

So, we finally got all tests passing and have now integrated a standalone version in the code, so the external libraries are only needed for further speed-ups. Thanks for the further review!
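To illustrate the split described above, a rough sketch of the optional-dependency pattern (the helper name `is_xlstm_available` and the backend labels are assumptions, not necessarily the PR's actual code):

```python
# Rough sketch of the optional-dependency pattern; the helper name and backend
# labels are assumptions, not necessarily what the PR actually uses.
import importlib.util


def is_xlstm_available() -> bool:
    """Return True if the external `xlstm` package (optional fast kernels) is installed."""
    return importlib.util.find_spec("xlstm") is not None


def select_backend() -> str:
    # Prefer the external fused kernels when available; otherwise fall back to the
    # standalone PyTorch implementation that now ships with the model code.
    return "external_xlstm_kernels" if is_xlstm_available() else "native_pytorch"
```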

@kpoeppel (Author) commented Apr 5, 2025:

> comment on the text-generation-level changes: the xLSTMCache class is missing proper documentation (see MambaCache for an example) and should be added to __init__.py

I think the documentation should now be similar, thanks for pointing this out!

@kpoeppel (Author) commented Apr 8, 2025:

@Rocketknight1 @Cyrilvallez, sorry for the long stall on this; I think it should now be ready for a review, thanks! :)

@ArthurZucker (Collaborator) commented:

cc @Cyrilvallez as I'll be off for 2 weeks!
