Commit fe5759b

Feat: Enable longer context window for inference by chunking.
1 parent 869f52a commit fe5759b

2 files changed: 68 additions & 22 deletions

src/transformers/models/xlstm/configuration_xlstm.py

Lines changed: 4 additions & 0 deletions
@@ -114,6 +114,8 @@ class xLSTMConfig(PretrainedConfig):
             EOS token id needed for generation.
         force_bos_token_insert (bool, optional, *optional*, defaults to `True`):
             Whether to force the insertion of a BOS token for prompting.
+        max_inference_chunksize (int, optional, *optional*, defaults to 16384):
+            Limit the chunk size for inference to save memory.

     Example:

@@ -172,6 +174,7 @@ def __init__(
         bos_token_id: int = 0,
         eos_token_id: int = 2,
         force_bos_token_insert: bool = True,
+        max_inference_chunksize: int = 16384,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -209,6 +212,7 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
         self.force_bos_token_insert = force_bos_token_insert
+        self.max_inference_chunksize = max_inference_chunksize

         super().__init__(
             bos_token_id=bos_token_id,
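A minimal usage sketch of the new option, assuming the usual top-level transformers export of `xLSTMConfig`; the value 8192 is an arbitrary illustrative choice, not taken from this commit:

from transformers import xLSTMConfig

# Lower the chunk size to cap peak memory during long-context inference.
# The default introduced by this commit is 16384; sequences longer than
# this are processed chunk by chunk at inference time.
config = xLSTMConfig(max_inference_chunksize=8192)
print(config.max_inference_chunksize)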

src/transformers/models/xlstm/modeling_xlstm.py

Lines changed: 64 additions & 22 deletions
@@ -306,28 +306,60 @@ def forward(
             cache_params = None

         hidden_states = inputs_embeds
-        all_hidden_states = () if output_hidden_states else None
-        for i, xlstm_block in enumerate(self.blocks):
-            if self.gradient_checkpointing and self.training:
-                hidden_states, rnn_state = self._gradient_checkpointing_func(
-                    xlstm_block.__call__,
-                    hidden_states,
-                    cache_params.rnn_state[i] if cache_params is not None else None,
-                )
-            else:
-                hidden_states, rnn_state = xlstm_block(
-                    hidden_states,
-                    state=cache_params.rnn_state[i] if cache_params is not None else None,
-                )
-            if cache_params:
-                for state_idx in range(len(cache_params.rnn_state[i])):
-                    local_rnn_state = rnn_state[state_idx]
-                    local_rnn_state = rnn_state[state_idx]
-                    cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
-                cache_params.rnn_state_initial = False

-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+        if (
+            not self.training
+            and self.config.max_inference_chunksize < hidden_states.shape[1]
+            and not output_hidden_states
+        ):
+            all_hidden_states = None
+            offset = 0
+            with torch.no_grad():
+                if cache_params is None:
+                    cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0])
+                final_state = torch.zeros_like(hidden_states)
+                while offset < hidden_states.shape[1]:
+                    hidden_states_chunk = hidden_states[
+                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
+                    ]
+                    for i, xlstm_block in enumerate(self.blocks):
+                        hidden_states_chunk, rnn_state = xlstm_block(
+                            hidden_states_chunk,
+                            state=cache_params.rnn_state[i],
+                        )
+                        for state_idx in range(len(cache_params.rnn_state[i])):
+                            local_rnn_state = rnn_state[state_idx]
+                            local_rnn_state = rnn_state[state_idx]
+                            cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
+                        cache_params.rnn_state_initial = False
+                    final_state[
+                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
+                    ] = hidden_states_chunk
+                    offset += self.config.max_inference_chunksize
+            hidden_states = final_state
+        else:
+            all_hidden_states = () if output_hidden_states else None
+            for i, xlstm_block in enumerate(self.blocks):
+                if self.gradient_checkpointing and self.training:
+                    hidden_states, rnn_state = self._gradient_checkpointing_func(
+                        xlstm_block.__call__,
+                        hidden_states,
+                        cache_params.rnn_state[i] if cache_params is not None else None,
+                    )
+                else:
+                    hidden_states, rnn_state = xlstm_block(
+                        hidden_states,
+                        state=cache_params.rnn_state[i] if cache_params is not None else None,
+                    )
+                if cache_params:
+                    for state_idx in range(len(cache_params.rnn_state[i])):
+                        local_rnn_state = rnn_state[state_idx]
+                        local_rnn_state = rnn_state[state_idx]
+                        cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
+                    cache_params.rnn_state_initial = False
+
+                if output_hidden_states:
+                    all_hidden_states = all_hidden_states + (hidden_states,)

         if use_cache:
             cache_params.seqlen_offset += inputs_embeds.shape[1]
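The pattern in this hunk relies on the xLSTM blocks being recurrent: a long sequence can be split into fixed-size chunks and fed through the stack chunk by chunk while the per-block state is carried forward, so peak activation memory scales with the chunk size rather than the full sequence length, and the result matches a single full-length pass. A stripped-down sketch of that idea, with generic `blocks`/`states` stand-ins rather than the actual transformers API:

import torch

def chunked_forward(blocks, hidden_states, states, chunk):
    # blocks: recurrent layers with signature (x, state) -> (y, new_state)
    # states: one state object per block, carried forward across chunks
    out = torch.zeros_like(hidden_states)
    seq_len = hidden_states.shape[1]
    with torch.no_grad():
        for offset in range(0, seq_len, chunk):
            x = hidden_states[:, offset : min(offset + chunk, seq_len)]
            for i, block in enumerate(blocks):
                x, states[i] = block(x, states[i])
            out[:, offset : min(offset + chunk, seq_len)] = x
    return out, states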
@@ -507,7 +539,17 @@ def forward(

         logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

-        logits = soft_cap(logits, self.config.output_logit_soft_cap)
+        if not self.training and self.config.max_inference_chunksize < logits.shape[1]:
+            offset = 0
+            with torch.no_grad():
+                while offset < logits.shape[1]:
+                    logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap(
+                        logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])],
+                        self.config.output_logit_soft_cap,
+                    )
+                    offset += self.config.max_inference_chunksize
+        else:
+            logits = soft_cap(logits, self.config.output_logit_soft_cap)

         loss = None
         if labels is not None:
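The same chunking is applied to the output logit soft cap, written back in place so no second full-length logits tensor is materialized. The sketch below only illustrates that in-place, chunk-wise application; the `soft_cap` shown here is the common tanh-based cap and is an assumption, not the function defined in modeling_xlstm.py:

import torch

def soft_cap(values: torch.Tensor, cap: float) -> torch.Tensor:
    # Common tanh-based soft cap: smooth, bounded to (-cap, cap).
    return cap * torch.tanh(values / cap)

def chunked_soft_cap_(logits: torch.Tensor, cap: float, chunk: int) -> torch.Tensor:
    # Apply the cap chunk-wise along the sequence dimension, writing back
    # in place to keep peak memory proportional to the chunk size.
    with torch.no_grad():
        for offset in range(0, logits.shape[1], chunk):
            end = min(offset + chunk, logits.shape[1])
            logits[:, offset:end] = soft_cap(logits[:, offset:end], cap)
    return logits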
