
Commit cb65326

feat(generator): add maximum number of words limit in generation

1 parent 088f439 · commit cb65326

File tree: 3 files changed (+186 −8 lines)


outlines/generate/api.py (+87 −8)
@@ -1,4 +1,5 @@
 import datetime
+import re
 from copy import copy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
@@ -81,9 +82,18 @@ def is_stop_sequence_found(
             ]
         )

-    def strip_stop_sequences(
-        self, sequence: str, stop_sequences: Optional[List[str]]
-    ) -> str:
+    @staticmethod
+    def strip_max_words_sequences(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            splits = sequence.split()
+            if len(splits) > max_words:
+                last_word = splits[-1]
+                sequence = sequence.rstrip(last_word).rstrip()
+
+        return sequence
+
+    @staticmethod
+    def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
         """Remove the stop sequences from the generated sequences.

         Parameters
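A minimal standalone sketch of the truncation rule above (a reimplementation for illustration only; the name strip_extra_words is not from the diff):

    from typing import Optional

    def strip_extra_words(sequence: str, max_words: Optional[int]) -> str:
        # Words are whitespace-delimited, mirroring sequence.split() above.
        if max_words is not None:
            words = sequence.split()
            if len(words) > max_words:
                # str.rstrip(words[-1]) strips a trailing *character set*, not a
                # suffix; it stops at the whitespace before the last word, and
                # the bare rstrip() then removes that whitespace too.
                sequence = sequence.rstrip(words[-1]).rstrip()
        return sequence

    assert strip_extra_words("one two three", 2) == "one two"
    assert strip_extra_words("one two", 5) == "one two"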
@@ -130,6 +140,7 @@ def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
@@ -147,7 +158,12 @@ def __call__(
             generating the first token.
         max_tokens
             An integer representing maximum number of tokens that will be generated
-            (per prompt)
+            (per prompt). If both `max_tokens` and `max_words` are passed, it will
+            stop when the first one is reached
+        max_words
+            An integer representing maximum number of words that will be generated
+            (per prompt). If both `max_tokens` and `max_words` are passed, it will
+            stop when the first one is reached
         stop_at
             A string or list of strings at which the text generated will stop
         rng
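A hedged usage sketch of the new parameter (the model choice is illustrative, not taken from this diff):

    import outlines

    model = outlines.models.transformers("gpt2")  # any supported model would do
    generator = outlines.generate.text(model)

    # Generation stops once the output exceeds five words (or earlier if
    # max_tokens / stop_at triggers first), and the result is trimmed back
    # to at most five words.
    answer = generator("Describe the sky in one line.", max_words=5)
    assert len(answer.split()) <= 5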
@@ -202,16 +218,29 @@ def __call__(
             rng=rng,
         )

+        # If we have max_words but no max_tokens, let's put a limit on the number of tokens
+        # so that we reduce the generation time and do not exceed context length if
+        # no stop token is met.
+        # A high estimation of average number of tokens per word in a multilanguage
+        # context is 2, let's take some precaution and increase it a bit to 3
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         while True:
             try:
                 last_state = next(states)
-                if max_tokens or stop_sequences:
+                if max_tokens or max_words or stop_sequences:
                     token_ids = last_state.token_ids
                     generated_token_ids = self.get_generated_token_ids(
                         prompt_token_ids, token_ids
                     )
                     if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                         break
+                    if max_words and all(
+                        len(sentence.split()) > max_words
+                        for sentence in self.tokenizer.decode(generated_token_ids)
+                    ):
+                        break
                     if stop_sequences and self.is_stop_sequence_found(
                         self.tokenizer.decode(generated_token_ids), stop_sequences
                     ):
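Note that the break condition above is batch-wide: generation only stops once every decoded sequence has gone past the word budget. A small standalone sketch of that predicate, for illustration:

    from typing import List

    def exceeded_word_budget(decoded: List[str], max_words: int) -> bool:
        # Mirrors the all(...) check above: every sequence must be over budget.
        return all(len(seq.split()) > max_words for seq in decoded)

    assert exceeded_word_budget(["a b c", "d e f g"], 2)
    assert not exceeded_word_budget(["a b c", "d e"], 2)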
@@ -223,9 +252,13 @@ def __call__(
         generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)

         generated = self.tokenizer.decode(generated_token_ids)
+        max_words_stripped = [
+            self.strip_max_words_sequences(sequence, max_words)
+            for sequence in generated
+        ]
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
-            for sequence in generated
+            for sequence in max_words_stripped
         ]
         formatted = [self.format_sequence(sequence) for sequence in stripped]

@@ -248,6 +281,7 @@ def stream(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Iterator[Union[List[str], str, List[List[str]]]]:
@@ -328,9 +362,12 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             ] * num_samples
             num_generated = 0
             is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
+            is_max_words_at_reached = [False for _ in range(batch_size)] * num_samples
             while True:
-                if (max_tokens and num_generated >= max_tokens) or all(
-                    is_stop_at_reached
+                if (
+                    (max_tokens and num_generated >= max_tokens)
+                    or all(is_stop_at_reached)
+                    or all(is_max_words_at_reached)
                 ):
                     return
                 try:
@@ -340,6 +377,21 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
                     return
                 generated_token_ids = sequence.token_ids[:, -num_generated:]
                 generated_sequences = self.tokenizer.decode(generated_token_ids)
+                if max_words is not None:
+                    is_max_words_at_reached = [
+                        stop or len(generated_sequence.split()) > max_words
+                        for generated_sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
+                    generated_sequences = [
+                        self.strip_max_words_sequences(sequence, max_words)
+                        if stop
+                        else sequence
+                        for sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
                 if stop_sequences:
                     is_stop_at_reached = [
                         stop
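In the streaming path, a sample that crosses the word budget is flagged in is_max_words_at_reached and trimmed via strip_max_words_sequences; the generator returns once every sample in the batch is either over budget or past a stop sequence. A hedged caller-side sketch, reusing the generator from the earlier example:

    # Chunks keep arriving until every sample hits a stop sequence or the
    # word budget; over-budget samples are truncated before being yielded.
    for chunk in generator.stream("List some colors:", max_words=8):
        print(chunk, end="", flush=True)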
@@ -473,16 +525,36 @@ def _format(self, sequences):
         else:
             return self.format_sequence(sequences)

+    @staticmethod
+    def reconstruct_till_max_words(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            if len(sequence.split()) > max_words:
+                matches = re.findall(r"(\s*\S+)(\s*)", sequence)
+                return "".join(
+                    word + whitespace for word, whitespace in matches[:max_words]
+                ).rstrip()
+
+        return sequence
+
     def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
         **model_specific_params,
     ):
         """Generate text from a prompt of list of prompts."""

+        # If we have max_words but no max_tokens, let's put a limit on the number of tokens
+        # so that we reduce the generation time and do not exceed context length if
+        # no stop token is met.
+        # A high estimation of average number of tokens per word in a multilanguage
+        # context is 2, let's take some precaution and increase it a bit to 3
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
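The regex in reconstruct_till_max_words, (\s*\S+)(\s*), captures each word together with the whitespace that follows it, so truncation preserves the original spacing between the words it keeps. A small illustration:

    import re

    sequence = "one  two\tthree four"
    matches = re.findall(r"(\s*\S+)(\s*)", sequence)
    # -> [('one', '  '), ('two', '\t'), ('three', ' '), ('four', '')]

    truncated = "".join(word + ws for word, ws in matches[:2]).rstrip()
    assert truncated == "one  two"  # the inner double space survives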
@@ -495,6 +567,13 @@ def __call__(
             **model_specific_params,
         )

+        if isinstance(completions, str):
+            completions = self.reconstruct_till_max_words(completions, max_words)
+        else:
+            completions = [
+                self.reconstruct_till_max_words(seq, max_words) for seq in completions
+            ]
+
         return self._format(completions)

     def stream(

tests/generate/test_generate.py (+11 −0)
@@ -245,6 +245,17 @@ def test_generate_text(request, model_fixture, sampler_name):
     assert isinstance(res, str)


+@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search"))
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_text_max_words(request, model_fixture, sampler_name):
+    max_words = 5
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.text(model, getattr(samplers, sampler_name)())
+    with enforce_not_implemented(model_fixture, sampler_name):
+        res = generator("Write a long sentence", max_words=max_words)
+        assert len(res.split()) <= max_words
+
+
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
 def test_generate_regex(request, model_fixture, pattern):

tests/generate/test_generator.py (+88 −0)
@@ -495,3 +495,91 @@ def test_expand_attention_masks(attention_masks, ancestors, expected_result):
 def test_bias_logits(logits, indices_to_mask, expected):
     masked_logits = bias_logits(logits, indices_to_mask)
     assert torch.equal(masked_logits, expected)
+
+
+def test_generator_max_words():
+    class MockFSM:
+        first_state = 0
+
+        def get_next_state(self, state, next_token_ids):
+            return 4
+
+        def get_next_instruction(self, *_):
+            return Generate([4])
+
+        def is_final_state(self, _):
+            return False  # let's generate tokens forever
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            # Input: "test"
+            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
+
+        def decode(self, tokens):
+            return [" ".join(["test" for _ in tokens[0]])]
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    class sampler:
+        def __init__(self):
+            self.samples = 1
+
+        def __call__(self, biased_logits, *_):
+            return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
+
+    generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
+    result = generator("test", max_words=5)
+    assert result == "test test test test test"
+
+
+def test_generator_max_tokens_from_max_words():
+    class MockFSM:
+        first_state = 0
+
+        def get_next_state(self, state, next_token_ids):
+            return 4
+
+        def get_next_instruction(self, *_):
+            return Generate([4])
+
+        def is_final_state(self, _):
+            return False  # let's generate tokens forever
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            # Input: "test"
+            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
+
+        def decode(self, tokens):
+            return [
+                "123456789"[: len(tokens[0])]
+            ]  # not generating any word separated by white space
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    class sampler:
+        def __init__(self):
+            self.samples = 1
+
+        def __call__(self, biased_logits, *_):
+            return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
+
+    generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
+    result = generator("test", max_words=2)  # should generate max_words * 3 tokens
+    assert result == "123456"
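Why the expected string is "123456": with max_words=2 and no max_tokens, the fallback sets max_tokens = 3 * 2 = 6, and this mock tokenizer decodes to one character per generated token, so the loop stops after exactly six tokens. The arithmetic:

    max_words = 2
    max_tokens = 3 * max_words  # fallback cap applied by the generator
    assert "123456789"[:max_tokens] == "123456"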
