1
1
import datetime
2
+ import re
2
3
from copy import copy
3
4
from dataclasses import dataclass
4
5
from typing import TYPE_CHECKING , Any , Iterator , List , Optional , Union
@@ -81,9 +82,18 @@ def is_stop_sequence_found(
81
82
]
82
83
)
83
84
84
- def strip_stop_sequences (
85
- self , sequence : str , stop_sequences : Optional [List [str ]]
86
- ) -> str :
85
+ @staticmethod
86
+ def strip_max_words_sequences (sequence : str , max_words : Optional [int ]) -> str :
87
+ if max_words is not None :
88
+ splits = sequence .split ()
89
+ if len (splits ) > max_words :
90
+ last_word = splits [- 1 ]
91
+ sequence = sequence .rstrip (last_word ).rstrip ()
92
+
93
+ return sequence
94
+
95
+ @staticmethod
96
+ def strip_stop_sequences (sequence : str , stop_sequences : Optional [List [str ]]) -> str :
87
97
"""Remove the stop sequences from the generated sequences.
88
98
89
99
Parameters
@@ -130,6 +140,7 @@ def __call__(
130
140
self ,
131
141
prompts : Union [str , List [str ]],
132
142
max_tokens : Optional [int ] = None ,
143
+ max_words : Optional [int ] = None ,
133
144
stop_at : Optional [Union [str , List [str ]]] = None ,
134
145
rng : Optional ["torch.Generator" ] = None ,
135
146
) -> Union [FormattedOutput , List [FormattedOutput ], List [List [FormattedOutput ]]]:
@@ -147,7 +158,12 @@ def __call__(
147
158
generating the first token.
148
159
max_tokens
149
160
An integer representing maximum number of tokens that will be generated
150
- (per prompt)
161
+ (per prompt). If both `max_tokens` and `max_words` are passed, it will
162
+ stop when the first one is reached
163
+ max_words
164
+ An integer representing maximum number of words that will be generated
165
+ (per prompt). If both `max_tokens` and `max_words` are passed, it will
166
+ stop when the first one is reached
151
167
stop_at
152
168
A string or list of strings at which the text generated will stop
153
169
rng
@@ -202,16 +218,29 @@ def __call__(
202
218
rng = rng ,
203
219
)
204
220
221
+ # If max_words is given but max_tokens is not, cap the number of tokens
+ # so that we reduce generation time and do not exceed the context length
+ # when no stop token is met.
+ # A high estimate of the average number of tokens per word in a
+ # multilingual context is 2; to be safe, we raise it to 3.
226
+ if max_words and max_tokens is None :
227
+ max_tokens = 3 * max_words
228
+
205
229
while True :
206
230
try :
207
231
last_state = next (states )
208
- if max_tokens or stop_sequences :
232
+ if max_tokens or max_words or stop_sequences :
209
233
token_ids = last_state .token_ids
210
234
generated_token_ids = self .get_generated_token_ids (
211
235
prompt_token_ids , token_ids
212
236
)
213
237
if max_tokens and len (generated_token_ids [0 ]) >= max_tokens :
214
238
break
239
+ if max_words and all (
240
+ len (sentence .split ()) > max_words
241
+ for sentence in self .tokenizer .decode (generated_token_ids )
242
+ ):
243
+ break
215
244
if stop_sequences and self .is_stop_sequence_found (
216
245
self .tokenizer .decode (generated_token_ids ), stop_sequences
217
246
):
@@ -223,9 +252,13 @@ def __call__(
223
252
generated_token_ids = self .get_generated_token_ids (prompt_token_ids , token_ids )
224
253
225
254
generated = self .tokenizer .decode (generated_token_ids )
255
+ max_words_stripped = [
256
+ self .strip_max_words_sequences (sequence , max_words )
257
+ for sequence in generated
258
+ ]
226
259
stripped = [
227
260
self .strip_stop_sequences (sequence , stop_sequences )
228
- for sequence in generated
261
+ for sequence in max_words_stripped
229
262
]
230
263
formatted = [self .format_sequence (sequence ) for sequence in stripped ]
231
264
@@ -248,6 +281,7 @@ def stream(
248
281
self ,
249
282
prompts : Union [str , List [str ]],
250
283
max_tokens : Optional [int ] = None ,
284
+ max_words : Optional [int ] = None ,
251
285
stop_at : Optional [Union [str , List [str ]]] = None ,
252
286
rng : Optional ["torch.Generator" ] = None ,
253
287
) -> Iterator [Union [List [str ], str , List [List [str ]]]]:
@@ -328,9 +362,12 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
328
362
] * num_samples
329
363
num_generated = 0
330
364
is_stop_at_reached = [False for _ in range (batch_size )] * num_samples
365
+ is_max_words_at_reached = [False for _ in range (batch_size )] * num_samples
331
366
while True :
332
- if (max_tokens and num_generated >= max_tokens ) or all (
333
- is_stop_at_reached
367
+ if (
368
+ (max_tokens and num_generated >= max_tokens )
369
+ or all (is_stop_at_reached )
370
+ or all (is_max_words_at_reached )
334
371
):
335
372
return
336
373
try :
@@ -340,6 +377,21 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
340
377
return
341
378
generated_token_ids = sequence .token_ids [:, - num_generated :]
342
379
generated_sequences = self .tokenizer .decode (generated_token_ids )
380
+ if max_words is not None :
381
+ is_max_words_at_reached = [
382
+ stop or len (generated_sequence .split ()) > max_words
383
+ for generated_sequence , stop in zip (
384
+ generated_sequences , is_max_words_at_reached
385
+ )
386
+ ]
387
+ generated_sequences = [
388
+ self .strip_max_words_sequences (sequence , max_words )
389
+ if stop
390
+ else sequence
391
+ for sequence , stop in zip (
392
+ generated_sequences , is_max_words_at_reached
393
+ )
394
+ ]
343
395
if stop_sequences :
344
396
is_stop_at_reached = [
345
397
stop
@@ -473,16 +525,36 @@ def _format(self, sequences):
473
525
else :
474
526
return self .format_sequence (sequences )
475
527
528
+ @staticmethod
529
+ def reconstruct_till_max_words (sequence : str , max_words : Optional [int ]) -> str :
530
+ if max_words is not None :
531
+ if len (sequence .split ()) > max_words :
532
+ matches = re .findall (r"(\s*\S+)(\s*)" , sequence )
533
+ return "" .join (
534
+ word + whitespace for word , whitespace in matches [:max_words ]
535
+ ).rstrip ()
536
+
537
+ return sequence
538
+
476
539
def __call__ (
477
540
self ,
478
541
prompts : Union [str , List [str ]],
479
542
max_tokens : Optional [int ] = None ,
543
+ max_words : Optional [int ] = None ,
480
544
stop_at : Optional [Union [str , List [str ]]] = None ,
481
545
seed : Optional [int ] = None ,
482
546
** model_specific_params ,
483
547
):
484
548
"""Generate text from a prompt of list of prompts."""
485
549
550
+ # If max_words is given but max_tokens is not, cap the number of tokens
+ # so that we reduce generation time and do not exceed the context length
+ # when no stop token is met.
+ # A high estimate of the average number of tokens per word in a
+ # multilingual context is 2; to be safe, we raise it to 3.
555
+ if max_words and max_tokens is None :
556
+ max_tokens = 3 * max_words
557
+
486
558
generation_params = self .prepare_generation_parameters (
487
559
max_tokens , stop_at , seed
488
560
)
@@ -495,6 +567,13 @@ def __call__(
495
567
** model_specific_params ,
496
568
)
497
569
570
+ if isinstance (completions , str ):
571
+ completions = self .reconstruct_till_max_words (completions , max_words )
572
+ else :
573
+ completions = [
574
+ self .reconstruct_till_max_words (seq , max_words ) for seq in completions
575
+ ]
576
+
498
577
return self ._format (completions )
499
578
500
579
def stream (
0 commit comments