Integrate xlstm cleanly. #35377
base: main
Conversation
4a4e347 to fe5759b (compare)
Thanks! How can this PR keep failing inside other models' tests? How did those failures get into 'main'?
Comment on the text-generation-level changes: the xLSTMCache class is missing proper documentation (see MambaCache for an example) and should be added to __init__.py.
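A hedged sketch of what such documentation could look like, loosely modeled on the MambaCache docstring; the argument and attribute names below (max_batch_size, dtype, device, rnn_state) are illustrative assumptions, not the final API:

class xLSTMCache:
    """
    Cache for the xLSTM model, analogous to MambaCache for Mamba models.

    Arguments:
        config (xLSTMConfig):
            Model configuration, used to derive the shapes of the recurrent states.
        max_batch_size (int):
            Batch size the state tensors are allocated for.
        dtype (torch.dtype, *optional*, defaults to torch.bfloat16):
            dtype of the recurrent state tensors.
        device (torch.device or str, *optional*):
            Device on which the state tensors are allocated.

    Attributes:
        rnn_state (dict):
            Per-block recurrent state tensors, updated in place at each decoding step.
    """

The class would then also be exported from the relevant __init__.py so it is importable from the top-level package.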
Super grateful for the PR! In general we're happy to have more architectures, but we want to make sure we're aligned on what can be done! 🤗
xlstm_block_config = xLSTMLargeConfig(
    vocab_size=config.vocab_size,
    embedding_dim=config.embedding_dim,
    num_blocks=config.num_blocks,
    num_heads=config.num_heads,
    use_bias=config.use_bias,
    add_out_norm=config.add_out_norm,
    norm_eps=config.norm_eps,
    norm_reduction_force_float32=config.norm_reduction_force_float32,
    # mlstm_layer
    qk_dim_factor=config.qk_dim_factor,
    v_dim_factor=config.v_dim_factor,
    # mlstm backend
    chunkwise_kernel=config.chunkwise_kernel,
    sequence_kernel=config.sequence_kernel,
    step_kernel=config.step_kernel,
    mode=config.mode,
    chunk_size=config.chunk_size,
    return_last_states=config.return_last_states,
    autocast_kernel_dtype=config.autocast_kernel_dtype,
    eps=config.eps,
    inference_state_dtype=config.inference_state_dtype,
    # feedforward
    ffn_proj_factor=config.ffn_proj_factor,
    ffn_round_up_to_multiple_of=config.ffn_round_up_to_multiple_of,
    # capping
    gate_soft_cap=config.gate_soft_cap,
    output_logit_soft_cap=config.output_logit_soft_cap,
    weight_mode=config.weight_mode,
)
we should align xLSTMLargeConfig to match the inputs of mLSTMBlock
and then we would not have to do this here
There are still slight deviations of the xLSTMLargeConfig compared to the xLSTMConfig in configuration_xlstm.py, so I think this conversion is actually necessary.
Some refactoring needed for the camel casing of classes!
Is the casing ok as it is now?
4545076 to 69b9d5c (compare)
b925fc6 to 3f6aec9 (compare)
So, finally we got all tests passing and have now integrated a standalone version in the code, so the external libraries are only needed for further speed-ups. Thanks for the further review!
I think the documentation should now be similar, thanks for pointing this out!
@Rocketknight1 @Cyrilvallez, sorry for the long stall on this. I think it should now be ready for review, thanks! :)
cc @Cyrilvallez as I'll be off for 2 weeks!
What does this PR do?
This PR integrates xLSTM via the xlstm library, including certain optimizations (potentially using torch.compile and CUDA graphs for speed-ups). This enables using NX-AI/xLSTM-7b without a special fork of transformers (see the usage sketch below).

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
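As a quick illustration of the intended end state, here is a hedged usage sketch. It assumes this PR is merged and that the NX-AI/xLSTM-7b checkpoint resolves through the standard Auto classes; the prompt and generation arguments are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the xLSTM-7b checkpoint with stock transformers, no fork required.
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
model = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")

inputs = tokenizer("xLSTM is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))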
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker