@@ -306,28 +306,60 @@ def forward(
             cache_params = None

         hidden_states = inputs_embeds
-        all_hidden_states = () if output_hidden_states else None
-        for i, xlstm_block in enumerate(self.blocks):
-            if self.gradient_checkpointing and self.training:
-                hidden_states, rnn_state = self._gradient_checkpointing_func(
-                    xlstm_block.__call__,
-                    hidden_states,
-                    cache_params.rnn_state[i] if cache_params is not None else None,
-                )
-            else:
-                hidden_states, rnn_state = xlstm_block(
-                    hidden_states,
-                    state=cache_params.rnn_state[i] if cache_params is not None else None,
-                )
-            if cache_params:
-                for state_idx in range(len(cache_params.rnn_state[i])):
-                    local_rnn_state = rnn_state[state_idx]
-                    cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
-                cache_params.rnn_state_initial = False
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+        if (
+            not self.training
+            and self.config.max_inference_chunksize < hidden_states.shape[1]
+            and not output_hidden_states
+        ):
+            all_hidden_states = None
+            offset = 0
+            with torch.no_grad():
+                if cache_params is None:
+                    cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0])
+                final_state = torch.zeros_like(hidden_states)
+                while offset < hidden_states.shape[1]:
+                    hidden_states_chunk = hidden_states[
+                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
+                    ]
+                    for i, xlstm_block in enumerate(self.blocks):
+                        hidden_states_chunk, rnn_state = xlstm_block(
+                            hidden_states_chunk,
+                            state=cache_params.rnn_state[i],
+                        )
+                        for state_idx in range(len(cache_params.rnn_state[i])):
+                            local_rnn_state = rnn_state[state_idx]
+                            cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
+                        cache_params.rnn_state_initial = False
+                    final_state[
+                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
+                    ] = hidden_states_chunk
+                    offset += self.config.max_inference_chunksize
+            hidden_states = final_state
+        else:
+            all_hidden_states = () if output_hidden_states else None
+            for i, xlstm_block in enumerate(self.blocks):
+                if self.gradient_checkpointing and self.training:
+                    hidden_states, rnn_state = self._gradient_checkpointing_func(
+                        xlstm_block.__call__,
+                        hidden_states,
+                        cache_params.rnn_state[i] if cache_params is not None else None,
+                    )
+                else:
+                    hidden_states, rnn_state = xlstm_block(
+                        hidden_states,
+                        state=cache_params.rnn_state[i] if cache_params is not None else None,
+                    )
+                if cache_params:
+                    for state_idx in range(len(cache_params.rnn_state[i])):
+                        local_rnn_state = rnn_state[state_idx]
+                        cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
+                    cache_params.rnn_state_initial = False
+
+                if output_hidden_states:
+                    all_hidden_states = all_hidden_states + (hidden_states,)

         if use_cache:
             cache_params.seqlen_offset += inputs_embeds.shape[1]
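
The chunked path added above bounds peak activation memory at inference: a prompt longer than config.max_inference_chunksize is processed in fixed-size slices, and each block's recurrent state is copied back into the xLSTMCache between slices, so the next slice continues exactly where the previous one stopped. Because the blocks are recurrent, the carried state summarizes everything left of the slice, and the chunked pass matches a single full-length pass. A minimal sketch of that invariant with a toy recurrence standing in for an xLSTM block (chunked_forward and toy_block are illustrative names, not library API):

import torch

def chunked_forward(block, hidden_states, state, chunk_size):
    # Run a recurrent block over a long sequence in fixed-size slices,
    # carrying the state across chunk boundaries; `block(x, state)` must
    # return (output, new_state).
    out = torch.zeros_like(hidden_states)
    seq_len = hidden_states.shape[1]
    offset = 0
    while offset < seq_len:
        end = min(offset + chunk_size, seq_len)
        out[:, offset:end], state = block(hidden_states[:, offset:end], state)
        offset = end
    return out, state

def toy_block(x, state):
    # Toy recurrence: running sum along the sequence dimension.
    return torch.cumsum(x, dim=1) + state.unsqueeze(1), state + x.sum(dim=1)

x = torch.randn(2, 10, 4)
full, _ = chunked_forward(toy_block, x, torch.zeros(2, 4), chunk_size=10)
chunked, _ = chunked_forward(toy_block, x, torch.zeros(2, 4), chunk_size=3)
assert torch.allclose(full, chunked, atol=1e-5)  # chunking leaves the result unchanged
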
@@ -507,7 +539,17 @@ def forward(

         logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

-        logits = soft_cap(logits, self.config.output_logit_soft_cap)
+        if not self.training and self.config.max_inference_chunksize < logits.shape[1]:
+            offset = 0
+            with torch.no_grad():
+                while offset < logits.shape[1]:
+                    logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap(
+                        logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])],
+                        self.config.output_logit_soft_cap,
+                    )
+                    offset += self.config.max_inference_chunksize
+        else:
+            logits = soft_cap(logits, self.config.output_logit_soft_cap)

         loss = None
         if labels is not None:
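
The same trick is applied to the output soft cap: slices of the logits are overwritten in place under torch.no_grad(), so no second full-size logits tensor is materialized for long prompts. A small sketch of the equivalence, assuming the usual tanh-based cap cap * tanh(x / cap) for soft_cap (soft_cap_chunked_ is an illustrative helper, not part of the module):

import torch

def soft_cap(values: torch.Tensor, cap_value: float) -> torch.Tensor:
    # Assumed definition: standard tanh-based soft capping of logits.
    return cap_value * torch.tanh(values / cap_value)

def soft_cap_chunked_(logits: torch.Tensor, cap: float, chunk_size: int) -> torch.Tensor:
    # Overwrite the logits slice by slice, mirroring the hunk above, so peak
    # memory stays at one chunk rather than a full-size temporary.
    with torch.no_grad():
        for offset in range(0, logits.shape[1], chunk_size):
            end = min(offset + chunk_size, logits.shape[1])
            logits[:, offset:end] = soft_cap(logits[:, offset:end], cap)
    return logits

x = torch.randn(1, 7, 5)
assert torch.allclose(soft_cap(x, 30.0), soft_cap_chunked_(x.clone(), 30.0, chunk_size=3))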