Integrate xlstm cleanly. #35377

Open · wants to merge 147 commits into main from integrate_xlstm_clean

Conversation

kpoeppel

What does this PR do?

This PR integrates xLSTM via the xlstm library, including certain optimizations (optionally using torch.compile and CUDA graphs for speed-ups). This enables using NX-AI/xLSTM-7b without a special fork of transformers.
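For illustration, usage after the integration should look roughly like this (a sketch assuming the model is registered with the auto classes; the generation settings below are just examples):

```python
# Minimal usage sketch, assuming the checkpoint is served through the standard
# auto classes once this PR is merged (generation settings are illustrative only).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
model = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", torch_dtype=torch.bfloat16)

inputs = tokenizer("xLSTM is a recurrent architecture that", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```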

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests? Yes, I adapted the tests of the recurrent Mamba2 model.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@kpoeppel kpoeppel force-pushed the integrate_xlstm_clean branch 4 times, most recently from 4a4e347 to fe5759b Compare December 21, 2024 13:41
@stevhliu (Member) left a comment

Thanks!

@kpoeppel (Author)

How can this PR keep failing inside other models' tests, and how did those failing tests get into `main`?

@Cyrilvallez Cyrilvallez self-assigned this Jan 13, 2025
@Cyrilvallez Cyrilvallez self-requested a review January 13, 2025 10:55
@gante (Member) left a comment

comment on the text-generation-level changes: the xLSTMCache class is missing proper documentation (see MambaCache for an example) and should be added to __init__.py
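For reference, a quick way to see the pattern being asked for (MambaCache is used here only as the example to imitate):

```python
# MambaCache is importable from the top level and carries a full docstring;
# the request is for xLSTMCache to follow the same pattern.
from transformers import MambaCache

help(MambaCache)  # prints the documented Arguments / Attributes / Example sections
```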

@ArthurZucker (Collaborator) left a comment

Super grateful for the PR! In general happy to have more arch, but we want to make sure we are aligned in terms of what can be done!
🤗

Comment on lines 219 to 248
xlstm_block_config = xLSTMLargeConfig(
    vocab_size=config.vocab_size,
    embedding_dim=config.embedding_dim,
    num_blocks=config.num_blocks,
    num_heads=config.num_heads,
    use_bias=config.use_bias,
    add_out_norm=config.add_out_norm,
    norm_eps=config.norm_eps,
    norm_reduction_force_float32=config.norm_reduction_force_float32,
    # mlstm_layer
    qk_dim_factor=config.qk_dim_factor,
    v_dim_factor=config.v_dim_factor,
    # mlstm backend
    chunkwise_kernel=config.chunkwise_kernel,
    sequence_kernel=config.sequence_kernel,
    step_kernel=config.step_kernel,
    mode=config.mode,
    chunk_size=config.chunk_size,
    return_last_states=config.return_last_states,
    autocast_kernel_dtype=config.autocast_kernel_dtype,
    eps=config.eps,
    inference_state_dtype=config.inference_state_dtype,
    # feedforward
    ffn_proj_factor=config.ffn_proj_factor,
    ffn_round_up_to_multiple_of=config.ffn_round_up_to_multiple_of,
    # capping
    gate_soft_cap=config.gate_soft_cap,
    output_logit_soft_cap=config.output_logit_soft_cap,
    weight_mode=config.weight_mode,
)
Collaborator

we should align xLSTMLargeConfig to match the inputs of mLSTMBlock

Collaborator

and then we would not have to do this conversion here

Author

There are still slight deviations of the xLSTMLargeConfig compared to the xLSTMConfig in configuration_xlstm.py, so I think this conversion is actually necessary.

Collaborator

Some refactoring needed for the camel casing of classes!

Author

Is the casing ok as it is now?

@superbock superbock force-pushed the integrate_xlstm_clean branch 3 times, most recently from 4545076 to 69b9d5c Compare March 27, 2025 16:01
@kpoeppel (Author) commented Jul 2, 2025

> Hey! Super super sorry for the delay! Here is a new round of reviews! Let me know if something is unclear 🤗 Still mostly concerned about the Cache (is it needed??), unnecessary abstractions, single-letter variables, and asserts 🤗

Hey! Sorry for the delay from my side as well. I integrated all your comments. I think the xLSTMCache is still necessary (like a MambaCache or a KV cache), but as it has a different structure than these, we need a separate class.
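To illustrate the structural difference (shapes follow the xLSTMCache code quoted further down in this thread; the example sizes are made up):

```python
import torch

# Per-layer recurrent state held by the xLSTM cache (B=batch, NH=heads,
# DK/DV=key/value head dims); example sizes only, not the 7B model's config.
B, NH, DK, DV = 1, 8, 64, 128
layer_state = (
    torch.zeros(B, NH, DK, DV),  # C: matrix memory
    torch.zeros(B, NH, DK),      # n: normalizer state
    torch.zeros(B, NH, 1),       # m: stabilizer state
)
# Unlike a KV cache, nothing here grows with sequence length; unlike MambaCache
# there is no conv_state/ssm_state pair, hence the separate class.
```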

@Cyrilvallez (Member) left a comment

Hey! Sorry for the delay! Here is a new review 🤗 Let me know if something is still unclear!

Comment on lines 2176 to 2245
class xLSTMCache:
    """
    Cache for the xLSTM model, which does not have an attention mechanism or key-value states.

    Arguments:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The batch size with which the model will be used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype

    Example:

        ```python
        >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache

        >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

        >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
        >>> outputs.cache_params
        xLSTMCache()
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: Optional[str] = None,
        **kwargs,
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.config = config
        self.rnn_state = {
            layer: (
                torch.zeros(
                    [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
                torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
            )
            for layer in range(config.num_hidden_layers)
        }

    def reset(self):
        self.rnn_state = {
            layer: (
                torch.zeros_like(self.rnn_state[layer][0]),
                torch.zeros_like(self.rnn_state[layer][1]),
                torch.zeros_like(self.rnn_state[layer][2]),
            )
            for layer in self.rnn_state
        }


Member

All right, but it should be moved to the modeling file instead then, not general cache_utils

Comment on lines 1286 to 1344
for param_name, param in self.named_parameters():
    if "bias" in param_name and param is not None:
        torch.nn.init.zeros_(param)
    elif "weight" in param_name and param is not None and param.ndim > 1:
        small_init_method(self.config.hidden_size)(param)

small_init_method(self.config.hidden_size)(self.embeddings.weight)
torch.nn.init.ones_(self.out_norm.weight)

for block in self.blocks:
    torch.nn.init.ones_(block.mlstm_layer.multihead_norm.weight)
    torch.nn.init.ones_(block.norm_mlstm.weight)
    torch.nn.init.ones_(block.norm_ffn.weight)

    wang_init_method(dim=block.ffn.up_proj_dim, n_layers=self.config.num_hidden_layers)(
        block.ffn.proj_down.weight
    )
    wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(
        block.mlstm_layer.out_proj.weight
    )

    if self.config.weight_mode == "single":
        torch.nn.init.zeros_(block.mlstm_layer.ogate_preact.weight)
        torch.nn.init.zeros_(block.mlstm_layer.igate_preact.weight)
        torch.nn.init.zeros_(block.mlstm_layer.fgate_preact.weight)

        with torch.no_grad():
            block.mlstm_layer.igate_preact.bias.copy_(
                -10.0 * torch.ones_like(block.mlstm_layer.igate_preact.bias)
            )
            block.mlstm_layer.fgate_preact.bias.copy_(
                torch.linspace(
                    3.0,
                    6.0,
                    block.mlstm_layer.fgate_preact.bias.shape[-1],
                ).to(
                    device=block.mlstm_layer.fgate_preact.bias.device,
                    dtype=block.mlstm_layer.fgate_preact.bias.dtype,
                )
            )
    elif self.config.weight_mode == "fused":
        torch.nn.init.zeros_(block.mlstm_layer.ifgate_preact.weight)

        with torch.no_grad():
            block.mlstm_layer.ifgate_preact.bias[: self.config.num_heads] += (
                -block.mlstm_layer.ifgate_preact.bias[: self.config.num_heads]
                - 10.0 * torch.ones_like(block.mlstm_layer.igate_preact.bias)
            )
            block.mlstm_layer.ifgate_preact.bias[: self.config.num_heads] += (
                -block.mlstm_layer.ifgate_preact.bias[self.config.num_heads :]
                + torch.linspace(
                    3.0,
                    6.0,
                    block.mlstm_layer.fgate_preact.bias.shape[-1],
                ).to(
                    device=block.mlstm_layer.fgate_preact.bias.device,
                    dtype=block.mlstm_layer.fgate_preact.bias.dtype,
                )
            )
Member

This function is applied iteratively on each module in the model -> we should not iterate on them again, see how it is usually done in e.g. Llama (each module decides what to do)
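For context, the per-module pattern being referred to looks roughly like this (paraphrased from the usual `_init_weights` in models such as Llama; `initializer_range` is that config's field, not necessarily this PR's):

```python
import torch.nn as nn

# Sketch of the standard per-module init: PreTrainedModel calls _init_weights
# once per submodule, so no extra iteration over the model is needed.
def _init_weights(self, module):
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
```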

Author

Since there are some special nn.Linear modules (the gates) that need a certain init, I have now added an additional utility method that gets the global name of a module within xLSTMPreTrainedModel and uses it for adaptive initialization. I hope this better matches how the init is intended to work. Otherwise I would have needed to wrap many special modules (also the FF down-projection), both within the HF code and the original xLSTM repo.
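Roughly, the idea is the following (a sketch for illustration; the helper name `_module_name` and the matching rules are simplified, not the exact PR code, and `small_init_method` / `wang_init_method` are the helpers from the diff above):

```python
import torch.nn as nn

def _module_name(model: nn.Module, module: nn.Module) -> str:
    """Return the fully qualified name of `module` inside `model`, or '' if absent."""
    for name, mod in model.named_modules():
        if mod is module:
            return name
    return ""

def _init_weights(self, module):
    name = _module_name(self, module)
    if isinstance(module, nn.Linear):
        if "gate_preact" in name:
            nn.init.zeros_(module.weight)  # gate pre-activations start at zero
        elif name.endswith("ffn.proj_down"):
            wang_init_method(dim=module.in_features, n_layers=self.config.num_hidden_layers)(module.weight)
        elif name.endswith("mlstm_layer.out_proj"):
            wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
        else:
            small_init_method(self.config.hidden_size)(module.weight)
        if module.bias is not None and "gate_preact" not in name:
            nn.init.zeros_(module.bias)  # gate biases get their special values separately
```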

@kpoeppel (Author) commented Jul 8, 2025

> Hey! Sorry for the delay! Here is a new review 🤗 Let me know if something is still unclear!

Thanks for the next review!

I moved the xLSTMCache to modeling_xlstm.py and resolved all other issues. However, the auto_docstring decorator now fails to work, as xLSTMCache is probably no longer global. Should I switch back to a non-auto_docstring docstring, or is there a better way to fix this?

@Cyrilvallez (Member)

I don't think it has anything to do with the class being public or not! But you can find everything you need about auto_docstring here! Basically, you only need to add a docstring for "unknown" args; for example, cache_params is unknown in the library here.
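Roughly something like this (a sketch only, not the exact final code; the decorator import path and docstring format here are assumptions, please check the linked guide):

```python
from transformers.utils import auto_docstring  # assumed import path

@auto_docstring
def forward(self, input_ids=None, cache_params=None, use_cache=None, **kwargs):
    r"""
    cache_params (`xLSTMCache`, *optional*):
        The recurrent state of the xLSTM blocks; if passed together with
        `use_cache=True`, it is updated and returned in the model output.
    """
    ...
```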

@kpoeppel (Author) commented Jul 9, 2025

There was a leftover xLSTMCache mention in the generation docs files. So all your comments should be integrated now. :)

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, xlstm
