 import datetime
+import re
 from copy import copy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
@@ -82,9 +83,20 @@ def is_stop_sequence_found(
             ]
         )

-    def strip_stop_sequences(
-        self, sequence: str, stop_sequences: Optional[List[str]]
-    ) -> str:
+    @staticmethod
+    def strip_max_words_sequences(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            splits = sequence.split()
+            if len(splits) > max_words:
+                last_word = splits[-1]
+                # str.rstrip(last_word) would strip a *character set*, not the
+                # trailing word, so cut at the word's last occurrence instead.
+                sequence = sequence[: sequence.rfind(last_word)].rstrip()
+
+        return sequence
+
+    @staticmethod
+    def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
         """Remove the stop sequences from the generated sequences.

         Parameters
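A quick sanity check of the new helper (a hedged sketch; the hosting class is called `SequenceGenerator` here, which this diff does not confirm):

```python
# Behavior sketch for strip_max_words_sequences, assuming it is exposed as a
# staticmethod on the generator class (class name hypothetical).
assert SequenceGenerator.strip_max_words_sequences("one two three", max_words=2) == "one two"
# With max_words=None the sequence passes through untouched.
assert SequenceGenerator.strip_max_words_sequences("one two three", None) == "one two three"
```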
@@ -131,6 +141,7 @@ def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
@@ -148,7 +159,12 @@ def __call__(
             generating the first token.
         max_tokens
             An integer representing the maximum number of tokens that will be generated
-            (per prompt)
+            (per prompt). If both `max_tokens` and `max_words` are passed,
+            generation stops as soon as either limit is reached.
+        max_words
+            An integer representing the maximum number of words that will be generated
+            (per prompt). If both `max_tokens` and `max_words` are passed,
+            generation stops as soon as either limit is reached.
         stop_at
             A string or list of strings at which the text generated will stop
         rng
@@ -203,16 +219,29 @@ def __call__(
             rng=rng,
         )

+        # If max_words is set but max_tokens is not, cap the number of tokens
+        # anyway, so that we bound generation time and do not exceed the context
+        # length when no stop token is met. A high estimate of the average
+        # number of tokens per word in a multilingual setting is 2; we add some
+        # margin and use 3.
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         while True:
             try:
                 last_state = next(states)
-                if max_tokens or stop_sequences:
+                if max_tokens or max_words or stop_sequences:
                     token_ids = last_state.token_ids
                     generated_token_ids = self.get_generated_token_ids(
                         prompt_token_ids, token_ids
                     )
                     if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                         break
+                    if max_words and all(
+                        len(sentence.split()) > max_words
+                        for sentence in self.tokenizer.decode(generated_token_ids)
+                    ):
+                        break
                     if stop_sequences and self.is_stop_sequence_found(
                         self.tokenizer.decode(generated_token_ids), stop_sequences
                     ):
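The word-count break above uses `all(...)` over the decoded batch, so generation only halts early once every sequence has exceeded `max_words`. A standalone sketch of that check (values are illustrative, not from the diff):

```python
decoded = ["one two three four", "one two"]
max_words = 3
# The second sequence is still under budget, so the loop keeps generating.
assert not all(len(s.split()) > max_words for s in decoded)
```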
@@ -224,9 +253,13 @@ def __call__(
         generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)

         generated = self.tokenizer.decode(generated_token_ids)
+        max_words_stripped = [
+            self.strip_max_words_sequences(sequence, max_words)
+            for sequence in generated
+        ]
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
-            for sequence in generated
+            for sequence in max_words_stripped
         ]
         formatted = [self.format_sequence(sequence) for sequence in stripped]
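End to end, word trimming now runs before stop-sequence stripping and formatting. A hedged usage sketch (`generator` stands for any instance exposing the `__call__` shown above):

```python
# max_tokens is left unset, so it is capped internally at 3 * 12 = 36.
result = generator("Write a haiku about autumn:", max_words=12)
# `result` contains at most 12 words per sampled sequence.
```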
@@ -249,6 +282,7 @@ def stream(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Iterator[Union[List[str], str, List[List[str]]]]:
@@ -329,9 +363,12 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             ] * num_samples
             num_generated = 0
             is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
+            is_max_words_at_reached = [False for _ in range(batch_size)] * num_samples
             while True:
-                if (max_tokens and num_generated >= max_tokens) or all(
-                    is_stop_at_reached
+                if (
+                    (max_tokens and num_generated >= max_tokens)
+                    or all(is_stop_at_reached)
+                    or all(is_max_words_at_reached)
                 ):
                     return
                 try:
@@ -341,6 +378,21 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
                     return
                 generated_token_ids = sequence.token_ids[:, -num_generated:]
                 generated_sequences = self.tokenizer.decode(generated_token_ids)
+                if max_words is not None:
+                    is_max_words_at_reached = [
+                        stop or len(generated_sequence.split()) > max_words
+                        for generated_sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
+                    generated_sequences = [
+                        self.strip_max_words_sequences(sequence, max_words)
+                        if stop
+                        else sequence
+                        for sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
                 if stop_sequences:
                     is_stop_at_reached = [
                         stop
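In the streaming path the word-budget flags act as a per-sequence latch: once a sequence trips the budget it stays marked, and its text is re-trimmed on every subsequent yield until the whole batch stops. A minimal sketch of the latch (illustrative values):

```python
flags = [False, False]
decoded = ["one two", "one two three four"]
max_words = 3
flags = [stop or len(s.split()) > max_words for s, stop in zip(decoded, flags)]
assert flags == [False, True]  # the second sequence stays latched from now on
```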
@@ -487,16 +539,38 @@ def _format(self, sequences):
         else:
             return self.format_sequence(sequences)

+    @staticmethod
+    def reconstruct_till_max_words(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            if len(sequence.split()) > max_words:
+                # Each match pairs a word (with any leading whitespace) with the
+                # whitespace that follows it, so truncation preserves spacing.
+                matches = re.findall(r"(\s*\S+)(\s*)", sequence)
+                return "".join(
+                    word + whitespace for word, whitespace in matches[:max_words]
+                ).rstrip()
+
+        return sequence
+
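A worked example of the regex truncation (hedged; the input string is illustrative):

```python
import re

sequence = "one  two   three four"
matches = re.findall(r"(\s*\S+)(\s*)", sequence)
# [('one', '  '), ('two', '   '), ('three', ' '), ('four', '')]
truncated = "".join(w + ws for w, ws in matches[:2]).rstrip()
assert truncated == "one  two"  # inter-word whitespace is preserved
```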
     def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
         **model_specific_params,
     ):
         """Generate text from a prompt or a list of prompts."""

+        # If max_words is set but max_tokens is not, cap the number of tokens
+        # anyway, so that we bound generation time and do not exceed the context
+        # length when no stop token is met. A high estimate of the average
+        # number of tokens per word in a multilingual setting is 2; we add some
+        # margin and use 3.
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
@@ -509,6 +581,13 @@ def __call__(
             **model_specific_params,
         )

+        if isinstance(completions, str):
+            completions = self.reconstruct_till_max_words(completions, max_words)
+        else:
+            completions = [
+                self.reconstruct_till_max_words(seq, max_words) for seq in completions
+            ]
+
         return self._format(completions)

     def stream(
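The adapter handles both call shapes, a bare string for a single prompt and a list for a batch, truncating each before `_format`. A hedged sketch of the batched path (the class name `SequenceGeneratorAdapter` is assumed; this diff does not show it):

```python
completions = ["alpha beta gamma delta", "one two"]
completions = [
    SequenceGeneratorAdapter.reconstruct_till_max_words(seq, max_words=3)
    for seq in completions
]
assert completions == ["alpha beta gamma", "one two"]
```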