@@ -43,19 +43,19 @@ class LogitsStorage:
 
     def __init__(
             self,
+            *,
             seq_length: int,
             use_device_memory=True,
-            should_exclude_last=False,
+            extra_token_for_overlap_scheduler=False,
             use_chunked_generation_logits=False,
             chunk_size=8
     ):  # logic adapted from HandleGenerationLogits.cpp to use chunked transfer
-        if should_exclude_last:
+        if extra_token_for_overlap_scheduler:
             # Exclude-last logits is used when the overlap scheduler is used, which
             # generates one extra token, so we should make sure there's memory for that extra +1.
             seq_length += 1
         self.seq_length = seq_length
         self.use_device_memory = use_device_memory
-        self._should_exclude_last = should_exclude_last
         self.use_chunked_generation_logits = use_chunked_generation_logits
         self.chunk_size = chunk_size
         self._logits_indices = []
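
The bare `*` makes every parameter keyword-only, so call sites can no longer bind arguments positionally and silently pick up the wrong value after a rename like `should_exclude_last` → `extra_token_for_overlap_scheduler`. A minimal usage sketch against the class in this file (the argument values are illustrative):

```python
# Hypothetical call site; with extra_token_for_overlap_scheduler=True,
# __init__ bumps seq_length by one to leave room for the overlap
# scheduler's extra token.
storage = LogitsStorage(
    seq_length=128,
    use_device_memory=True,
    extra_token_for_overlap_scheduler=True,
    use_chunked_generation_logits=False,
)
assert storage.seq_length == 129  # the +1 slot was reserved
```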
@@ -126,14 +126,14 @@ def append(self, logits: torch.Tensor):
                                non_blocking=True)
         self._logits_indices.append((position, new_position))
 
-    def get(self, all_logits: bool) -> torch.Tensor | None:
+    def get(self, all_logits: bool, exclude_last: bool) -> torch.Tensor | None:
         """Returns the used logits storage if there is any; otherwise returns None.
         When all_logits is True, all set logits are returned; otherwise only the last logits are returned."""
         if self._storage is None:
             return None
 
         try:
-            last = -2 if self._should_exclude_last else -1
+            last = -2 if exclude_last else -1
             start = 0 if all_logits else self._logits_indices[last][0]
             end = self._logits_indices[last][1]
             return self._storage[start:end]
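
The exclude-last decision now happens at read time instead of being baked in at construction time. A self-contained sketch of the index arithmetic in `get()`, with illustrative `(start, end)` pairs:

```python
# Each append() records a (start, end) slice into the flat storage.
logits_indices = [(0, 1), (1, 2), (2, 3)]  # three generation steps

def slice_bounds(all_logits: bool, exclude_last: bool):
    # Mirrors get(): dropping the final step (the overlap scheduler's
    # extra token) means ending at the second-to-last recorded slice.
    last = -2 if exclude_last else -1
    start = 0 if all_logits else logits_indices[last][0]
    end = logits_indices[last][1]
    return start, end

assert slice_bounds(all_logits=True, exclude_last=True) == (0, 2)
assert slice_bounds(all_logits=False, exclude_last=False) == (2, 3)
```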
@@ -175,9 +175,6 @@ def finalize_chunked_transfer(self):
         if self.use_chunked_generation_logits and self._device_fragments:
             self._transfer_chunk_to_host()
 
-    def set_exclude_last(self, should_exclude_last: bool) -> None:
-        self._should_exclude_last = should_exclude_last
-
 
 class LogProbStorage:
     beam_width: int = -1
@@ -225,6 +222,7 @@ class PyResult:
     """PyResult reimplements some features of `bindings.executor.Result` in Python"""
 
     def __init__(self,
+                 *,
                  prompt_len: int,
                  max_new_tokens: int,
                  use_device_memory=True,
@@ -240,16 +238,20 @@ def __init__(self,
             assert chunk_size == 1, "chunk_size must be 1 in streaming mode"
         self._streaming = streaming
         self._chunk_size = chunk_size
+        self._exclude_last_generation_logits = exclude_last_generation_logits
 
         # Note that in the C++ implementation both context logits and generation logits are stored in host memory.
         # Here we only use host memory for generation logits when in chunked mode.
         self._context_logits = LogitsStorage(
-            prompt_len, use_device_memory, use_chunked_generation_logits=False
+            seq_length=prompt_len,
+            use_device_memory=use_device_memory,
+            extra_token_for_overlap_scheduler=False,
+            use_chunked_generation_logits=False
         ) if return_context_logits else None
         self._generation_logits = LogitsStorage(
-            max_new_tokens,
-            use_device_memory,
-            exclude_last_generation_logits,
+            seq_length=max_new_tokens,
+            use_device_memory=use_device_memory,
+            extra_token_for_overlap_scheduler=exclude_last_generation_logits,
             use_chunked_generation_logits=use_chunked_generation_logits,
             chunk_size=self._chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
@@ -263,6 +265,10 @@ def __init__(self,
             for name in additional_outputs
         } if additional_outputs else None
 
+    def set_exclude_last_generation_logits(
+            self, exclude_last_generation_logits: bool):
+        self._exclude_last_generation_logits = exclude_last_generation_logits
+
     def append_context_logits(self, context_logits: torch.Tensor):
         if self._context_logits:
             self._context_logits.append(context_logits)
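
Because the flag now lives on `PyResult` rather than inside `LogitsStorage`, it can be flipped after construction without touching the storage object. A hedged usage sketch, where `py_result` stands for a `PyResult` built with `return_generation_logits=True`:

```python
# Hypothetical flow: flipping the flag only changes the slice that
# the generation_logits property returns; no data is copied or moved.
py_result.set_exclude_last_generation_logits(True)
trimmed = py_result.generation_logits  # drops the extra token's step
py_result.set_exclude_last_generation_logits(False)
full = py_result.generation_logits     # includes every stored step
```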
@@ -309,7 +315,7 @@ def set_log_probs(self, log_probs: list[TokenLogprobs],
     @property
     def context_logits(self) -> torch.Tensor | None:
         if self._context_logits is None or (storage := self._context_logits.get(
-                all_logits=True)) is None:
+                all_logits=True, exclude_last=False)) is None:
             return None
         return storage[:, 0]  # remove beam_width axis for context
 
@@ -320,7 +326,9 @@ def generation_logits(self) -> torch.Tensor | None:
         if not self._generation_logits:
             return None
 
-        storage = self._generation_logits.get(all_logits=not self._streaming)
+        storage = self._generation_logits.get(
+            all_logits=not self._streaming,
+            exclude_last=self._exclude_last_generation_logits)
         if storage is None:
             return None
         return storage.transpose(0, 1)
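
The storage is laid out step-major and the property transposes it to beam-major. A self-contained shape sketch, assuming a `[num_steps, beam_width, vocab_size]` layout (the sizes are illustrative, not the engine's actual dimensions):

```python
import torch

# Assumed layout of the backing storage: [num_steps, beam_width, vocab_size].
storage = torch.zeros(4, 2, 16)
out = storage.transpose(0, 1)
assert out.shape == (2, 4, 16)  # [beam_width, num_steps, vocab_size]
```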
@@ -522,14 +530,14 @@ def __init__(
         self.py_stop_words_list = stop_words_list
 
         self.py_result = PyResult(
-            self.py_prompt_len,
-            self.py_max_new_tokens,
-            return_logits_device_memory,
-            self.streaming,
-            return_log_probs,
-            return_context_logits,
-            return_generation_logits,
-            exclude_last_generation_logits,
+            prompt_len=self.py_prompt_len,
+            max_new_tokens=self.py_max_new_tokens,
+            use_device_memory=return_logits_device_memory,
+            streaming=self.streaming,
+            return_log_probs=return_log_probs,
+            return_context_logits=return_context_logits,
+            return_generation_logits=return_generation_logits,
+            exclude_last_generation_logits=exclude_last_generation_logits,
             use_chunked_generation_logits=self.py_use_chunked_generation_logits,
             chunk_size=self.py_logits_chunk_size,
             additional_outputs=additional_outputs)
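
With `PyResult.__init__` now keyword-only, this long argument list can no longer be misaligned: a stale positional caller fails fast instead of binding values to the wrong parameters. A self-contained sketch of the failure mode the `*` prevents (`make_result` is a stand-in, not the real constructor):

```python
def make_result(*, prompt_len: int, max_new_tokens: int,
                use_device_memory: bool = True):
    # Stand-in for PyResult.__init__; only the keyword-only marker's
    # behavior is demonstrated here.
    return prompt_len, max_new_tokens, use_device_memory

make_result(prompt_len=16, max_new_tokens=32)  # binds as intended
try:
    make_result(16, 32)  # positional call is rejected outright
except TypeError as err:
    print(err)  # "... takes 0 positional arguments but 2 were given"
```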
@@ -543,6 +551,11 @@ def __init__(
         else:
             self._py_embedding_bias_1d = self.embedding_bias
 
+    def set_exclude_last_generation_logits(
+            self, exclude_last_generation_logits: bool):
+        self.py_result.set_exclude_last_generation_logits(
+            exclude_last_generation_logits)
+
     @property
     def cached_tokens(self) -> int:
         return self._cached_tokens
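
The new `LlmRequest` method simply forwards to `PyResult`, so callers only ever touch the request object. A hedged sketch of the resulting call chain, assuming `llm_request` was constructed with `return_generation_logits=True`:

```python
# Hypothetical flow: the caller flips the flag on the request, the
# request delegates to its PyResult, and LogitsStorage.get() later
# consults it via the exclude_last argument.
llm_request.set_exclude_last_generation_logits(True)
final_logits = llm_request.py_result.generation_logits
```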