Integrate xlstm cleanly. #35377
Conversation
Force-pushed from 4a4e347 to fe5759b.
Thanks! Why does this PR keep failing inside other models' tests, and how did those failures get into `main`?
Comment on the text-generation-level changes: the `xLSTMCache` class is missing proper documentation (see `MambaCache` for an example) and should be added to `__init__.py`.
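As an illustration only, a minimal sketch of the requested export, assuming a plain re-export; the real `transformers/__init__.py` routes public names through a lazy import structure, so the actual change would follow that mechanism:

```python
# Fragment of transformers/__init__.py (simplified, hypothetical): expose the
# cache class at the package root, next to the other cache classes such as MambaCache.
from .cache_utils import xLSTMCache  # noqa: F401
```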
Super grateful for the PR! In general we're happy to have more architectures, but we want to make sure we are aligned on what can be done!
🤗
```python
xlstm_block_config = xLSTMLargeConfig(
    vocab_size=config.vocab_size,
    embedding_dim=config.embedding_dim,
    num_blocks=config.num_blocks,
    num_heads=config.num_heads,
    use_bias=config.use_bias,
    add_out_norm=config.add_out_norm,
    norm_eps=config.norm_eps,
    norm_reduction_force_float32=config.norm_reduction_force_float32,
    # mlstm_layer
    qk_dim_factor=config.qk_dim_factor,
    v_dim_factor=config.v_dim_factor,
    # mlstm backend
    chunkwise_kernel=config.chunkwise_kernel,
    sequence_kernel=config.sequence_kernel,
    step_kernel=config.step_kernel,
    mode=config.mode,
    chunk_size=config.chunk_size,
    return_last_states=config.return_last_states,
    autocast_kernel_dtype=config.autocast_kernel_dtype,
    eps=config.eps,
    inference_state_dtype=config.inference_state_dtype,
    # feedforward
    ffn_proj_factor=config.ffn_proj_factor,
    ffn_round_up_to_multiple_of=config.ffn_round_up_to_multiple_of,
    # capping
    gate_soft_cap=config.gate_soft_cap,
    output_logit_soft_cap=config.output_logit_soft_cap,
    weight_mode=config.weight_mode,
)
```
We should align `xLSTMLargeConfig` to match the inputs of `mLSTMBlock`.
And then we would not have to do this conversion here.
There are still slight deviations between `xLSTMLargeConfig` and the `xLSTMConfig` in `configuration_xlstm.py`, so I think this conversion is actually necessary.
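If the field names of the two configs were fully aligned, the conversion above could shrink to forwarding matching attributes. A rough sketch of that idea, assuming `xLSTMLargeConfig` is a dataclass (the helper name is hypothetical):

```python
from dataclasses import fields

def to_block_config(config, block_config_cls):
    """Forward only the attributes that the block config actually declares."""
    declared = {f.name for f in fields(block_config_cls)}
    kwargs = {name: getattr(config, name) for name in declared if hasattr(config, name)}
    return block_config_cls(**kwargs)

# Hypothetical usage:
# xlstm_block_config = to_block_config(config, xLSTMLargeConfig)
```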
Some refactoring is needed for the camel casing of the classes!
Is the casing ok as it is now?
Force-pushed from 4545076 to 69b9d5c.
Hey! Sorry for the delay on my side as well. I integrated all your comments. I think the `xLSTMCache` is still necessary (like `MambaCache` or a KV cache); since it has a different structure than those, we need a separate class.
Hey! Sorry for the delay! Here is a new review 🤗 Let me know if something is still unclear!
src/transformers/cache_utils.py (outdated)
````python
class xLSTMCache:
    """
    Cache for the xLSTM model, which has no attention mechanism and therefore no key/value states.

    Arguments:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The batch size with which the model will be used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype

    Example:

    ```python
    >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache

    >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

    >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")

    >>> # Prepare a cache class and pass it to the model's forward
    >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
    >>> outputs.cache_params
    xLSTMCache()
    ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: Optional[str] = None,
        **kwargs,
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.config = config
        # One recurrent state tuple per layer (cell state, normalizer state, max state).
        self.rnn_state = {
            layer: (
                torch.zeros(
                    [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
                torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
            )
            for layer in range(config.num_hidden_layers)
        }

    def reset(self):
        self.rnn_state = {
            layer: (
                torch.zeros_like(self.rnn_state[layer][0]),
                torch.zeros_like(self.rnn_state[layer][1]),
                torch.zeros_like(self.rnn_state[layer][2]),
            )
            for layer in self.rnn_state
        }
````
All right, but then it should be moved to the modeling file instead, not the general `cache_utils`.
```python
for param_name, param in self.named_parameters():
    if "bias" in param_name and param is not None:
        torch.nn.init.zeros_(param)
    elif "weight" in param_name and param is not None and param.ndim > 1:
        small_init_method(self.config.hidden_size)(param)

small_init_method(self.config.hidden_size)(self.embeddings.weight)
torch.nn.init.ones_(self.out_norm.weight)

for block in self.blocks:
    torch.nn.init.ones_(block.mlstm_layer.multihead_norm.weight)
    torch.nn.init.ones_(block.norm_mlstm.weight)
    torch.nn.init.ones_(block.norm_ffn.weight)

    wang_init_method(dim=block.ffn.up_proj_dim, n_layers=self.config.num_hidden_layers)(
        block.ffn.proj_down.weight
    )
    wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(
        block.mlstm_layer.out_proj.weight
    )

    if self.config.weight_mode == "single":
        torch.nn.init.zeros_(block.mlstm_layer.ogate_preact.weight)
        torch.nn.init.zeros_(block.mlstm_layer.igate_preact.weight)
        torch.nn.init.zeros_(block.mlstm_layer.fgate_preact.weight)

        with torch.no_grad():
            # Input-gate biases start strongly negative, forget-gate biases on a 3..6 ramp.
            block.mlstm_layer.igate_preact.bias.copy_(
                -10.0 * torch.ones_like(block.mlstm_layer.igate_preact.bias)
            )
            block.mlstm_layer.fgate_preact.bias.copy_(
                torch.linspace(3.0, 6.0, block.mlstm_layer.fgate_preact.bias.shape[-1]).to(
                    device=block.mlstm_layer.fgate_preact.bias.device,
                    dtype=block.mlstm_layer.fgate_preact.bias.dtype,
                )
            )
    elif self.config.weight_mode == "fused":
        torch.nn.init.zeros_(block.mlstm_layer.ifgate_preact.weight)

        with torch.no_grad():
            num_heads = self.config.num_heads
            bias = block.mlstm_layer.ifgate_preact.bias
            # First `num_heads` entries hold the input-gate biases (set to -10),
            # the remaining entries hold the forget-gate biases (3..6 ramp).
            bias[:num_heads] += -bias[:num_heads] - 10.0 * torch.ones_like(bias[:num_heads])
            bias[num_heads:] += -bias[num_heads:] + torch.linspace(
                3.0, 6.0, bias.shape[-1] - num_heads
            ).to(device=bias.device, dtype=bias.dtype)
```
This function is applied iteratively to each module in the model, so we should not iterate over them again; see how it is usually done in e.g. Llama (each module decides what to do).
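For reference, a minimal sketch of the per-module pattern the reviewer means (Llama-style): `_init_weights` is called once per submodule on the `PreTrainedModel` subclass, so the method only has to handle the module it receives. The `initializer_range` config field is an assumption here.

```python
import torch.nn as nn

def _init_weights(self, module):
    # Called by PreTrainedModel for every submodule; no manual iteration needed.
    std = getattr(self.config, "initializer_range", 0.02)  # assumed config field
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
```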
Since some special `nn.Linear` modules (the gates) need a specific init, I added a utility method that resolves the global name of a module within `xLSTMPreTrainedModel` and uses it for adaptive initialization. I hope this matches better how the init is intended to work; otherwise I would have had to wrap many special modules (including the FFN down-projection), both in the HF code and in the original xLSTM repo.
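A rough sketch of the kind of helper described (the method name is hypothetical): resolve a submodule's qualified name inside the model so that `_init_weights` can branch on it.

```python
def _get_module_name(self, module):
    # Walk named_modules() once and return the dotted path of `module`,
    # e.g. "backbone.blocks.3.mlstm_layer.igate_preact"; empty string if not found.
    for name, candidate in self.named_modules():
        if candidate is module:
            return name
    return ""

# Inside _init_weights (sketch):
# name = self._get_module_name(module)
# if name.endswith("igate_preact"):
#     torch.nn.init.zeros_(module.weight)
```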
Thanks for the next review! I moved the `xLSTMCache` to `modeling_xlstm.py` and resolved all the other issues. However, the `auto_docstring` decorator now fails, probably because `xLSTMCache` is no longer global. Should I switch back to the non-`auto_docstring` docstring, or is there a better way to fix this?
I don't think it has anything to do with the class being public or not! But you can find everything you need about `auto_docstring` in the documentation.
There was a leftover `xLSTMCache` mention in the generation docs files. So all your comments should be integrated now. :)
[For maintainers] Suggested jobs to run (before merge): `run-slow: auto, xlstm`
What does this PR do?
This PR integrates xLSTM via the `xlstm` library, including certain optimizations (potentially using `torch.compile` and CUDA graphs for speed-ups). This enables using `NX-AI/xLSTM-7b` without a special fork of transformers.
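A minimal usage sketch of what the PR enables, assuming the checkpoint loads through the Auto classes once this model is merged:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
model = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b")

inputs = tokenizer("xLSTM is", return_tensors="pt")
# Recurrent state is handled via xLSTMCache internally when use_cache=True.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```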
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker