
Commit 238910d

feat(generator): add maximum number of words limit in generation
1 parent 36f1bf2 commit 238910d

File tree

3 files changed: +186 -8 lines changed

outlines/generate/api.py (+87 -8)

@@ -1,4 +1,5 @@
 import datetime
+import re
 from copy import copy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
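
The new `re` import supports the `reconstruct_till_max_words` helper added to `SequenceGeneratorAdapter` further down, which splits a completion into words while preserving the whitespace between them.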
@@ -82,9 +83,18 @@ def is_stop_sequence_found(
             ]
         )

-    def strip_stop_sequences(
-        self, sequence: str, stop_sequences: Optional[List[str]]
-    ) -> str:
+    @staticmethod
+    def strip_max_words_sequences(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            splits = sequence.split()
+            if len(splits) > max_words:
+                last_word = splits[-1]
+                sequence = sequence.rstrip(last_word).rstrip()
+
+        return sequence
+
+    @staticmethod
+    def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
        """Remove the stop sequences from the generated sequences.

        Parameters
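
A note on the new helper: `str.rstrip(last_word)` treats its argument as a set of characters rather than a literal suffix, but because `split()` guarantees `last_word` contains no whitespace, stripping stops at the separating space, so in the common case exactly the final word is removed. A minimal standalone sketch of the same idea (the helper name below is hypothetical, not part of the commit):

    from typing import Optional

    def truncate_to_max_words(sequence: str, max_words: Optional[int]) -> str:
        # Keep at most `max_words` words by splitting and rejoining. Unlike the
        # rstrip trick above, this also handles overshoot by more than one word,
        # but " ".join collapses the original whitespace; the commit's
        # reconstruct_till_max_words (further down) preserves it with a regex.
        if max_words is not None:
            words = sequence.split()
            if len(words) > max_words:
                return " ".join(words[:max_words])
        return sequence

    assert truncate_to_max_words("one two three four", 3) == "one two three"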
@@ -131,6 +141,7 @@ def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
@@ -148,7 +159,12 @@ def __call__(
             generating the first token.
         max_tokens
             An integer representing maximum number of tokens that will be generated
-            (per prompt)
+            (per prompt). If both `max_tokens` and `max_words` are passed,
+            generation stops when the first limit is reached.
+        max_words
+            An integer representing maximum number of words that will be generated
+            (per prompt). If both `max_tokens` and `max_words` are passed,
+            generation stops when the first limit is reached.
         stop_at
             A string or list of strings at which the text generated will stop
         rng
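
For illustration, here is how the new keyword is meant to be used from the public API (a sketch mirroring the test added below; `model` stands for any already-loaded outlines model and is an assumption of this example):

    from outlines import generate, samplers

    # `model` is assumed to be a loaded outlines model (hypothetical here).
    generator = generate.text(model, samplers.multinomial())
    # Both limits may be combined; generation stops at whichever is hit first.
    result = generator("Write a long sentence", max_words=5, max_tokens=50)
    assert len(result.split()) <= 5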
@@ -203,16 +219,29 @@ def __call__(
             rng=rng,
         )

+        # If we have max_words but no max_tokens, put a limit on the number of
+        # tokens so that we reduce the generation time and do not exceed the
+        # context length if no stop token is met.
+        # A high estimate of the average number of tokens per word in a
+        # multilingual context is 2; to be safe, we raise it to 3.
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         while True:
             try:
                 last_state = next(states)
-                if max_tokens or stop_sequences:
+                if max_tokens or max_words or stop_sequences:
                     token_ids = last_state.token_ids
                     generated_token_ids = self.get_generated_token_ids(
                         prompt_token_ids, token_ids
                     )
                     if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                         break
+                    if max_words and all(
+                        len(sentence.split()) > max_words
+                        for sentence in self.tokenizer.decode(generated_token_ids)
+                    ):
+                        break
                     if stop_sequences and self.is_stop_sequence_found(
                         self.tokenizer.decode(generated_token_ids), stop_sequences
                     ):
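
Two details of this hunk are worth spelling out: the word-limit guard only breaks once every decoded sequence in the batch has exceeded `max_words`, and when no `max_tokens` is given the 3-tokens-per-word budget bounds generation even if whitespace never appears. A standalone sketch with made-up data:

    max_words = 4
    max_tokens = None
    if max_words and max_tokens is None:
        max_tokens = 3 * max_words  # coarse budget: 12 tokens for 4 words

    decoded = ["one two three four five", "one two"]
    # all(): the batch keeps generating until every sequence is past the limit.
    print(all(len(s.split()) > max_words for s in decoded))  # False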
@@ -224,9 +253,13 @@ def __call__(
         generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)

         generated = self.tokenizer.decode(generated_token_ids)
+        max_words_stripped = [
+            self.strip_max_words_sequences(sequence, max_words)
+            for sequence in generated
+        ]
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
-            for sequence in generated
+            for sequence in max_words_stripped
         ]
         formatted = [self.format_sequence(sequence) for sequence in stripped]
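
Note the ordering here: word-limit stripping runs first, and stop-sequence stripping is then applied to the already word-truncated text, so a stop string that survives truncation is still removed from the final output.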

@@ -249,6 +282,7 @@ def stream(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Iterator[Union[List[str], str, List[List[str]]]]:
@@ -329,9 +363,12 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             ] * num_samples
             num_generated = 0
             is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
+            is_max_words_at_reached = [False for _ in range(batch_size)] * num_samples
             while True:
-                if (max_tokens and num_generated >= max_tokens) or all(
-                    is_stop_at_reached
+                if (
+                    (max_tokens and num_generated >= max_tokens)
+                    or all(is_stop_at_reached)
+                    or all(is_max_words_at_reached)
                 ):
                     return
                 try:
@@ -341,6 +378,21 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
                     return
                 generated_token_ids = sequence.token_ids[:, -num_generated:]
                 generated_sequences = self.tokenizer.decode(generated_token_ids)
+                if max_words is not None:
+                    is_max_words_at_reached = [
+                        stop or len(generated_sequence.split()) > max_words
+                        for generated_sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
+                    generated_sequences = [
+                        self.strip_max_words_sequences(sequence, max_words)
+                        if stop
+                        else sequence
+                        for sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
                 if stop_sequences:
                     is_stop_at_reached = [
                         stop
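
In streaming mode the limit is tracked per sequence with sticky flags: once a sequence trips `max_words` it stays flagged and its text is truncated on every subsequent yield, while other sequences in the batch continue. A standalone sketch of the flag update with made-up data:

    max_words = 2
    generated_sequences = ["one two three", "one"]
    is_max_words_at_reached = [False, False]

    is_max_words_at_reached = [
        stop or len(seq.split()) > max_words
        for seq, stop in zip(generated_sequences, is_max_words_at_reached)
    ]
    print(is_max_words_at_reached)  # [True, False]: only the first is truncated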
@@ -487,16 +539,36 @@ def _format(self, sequences):
         else:
             return self.format_sequence(sequences)

+    @staticmethod
+    def reconstruct_till_max_words(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            if len(sequence.split()) > max_words:
+                matches = re.findall(r"(\s*\S+)(\s*)", sequence)
+                return "".join(
+                    word + whitespace for word, whitespace in matches[:max_words]
+                ).rstrip()
+
+        return sequence
+
     def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
         **model_specific_params,
     ):
         """Generate text from a prompt of list of prompts."""

+        # If we have max_words but no max_tokens, put a limit on the number of
+        # tokens so that we reduce the generation time and do not exceed the
+        # context length if no stop token is met.
+        # A high estimate of the average number of tokens per word in a
+        # multilingual context is 2; to be safe, we raise it to 3.
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
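
Unlike the `rstrip`-based helper on `SequenceGenerator`, `reconstruct_till_max_words` rebuilds the first `max_words` words while preserving the original inter-word whitespace, which matters for backends that return fully-formed completions. A standalone demo of the regex used above:

    import re

    sequence = "one  two\tthree four"
    matches = re.findall(r"(\s*\S+)(\s*)", sequence)
    # Each tuple pairs a word with the whitespace that follows it.
    truncated = "".join(word + ws for word, ws in matches[:2]).rstrip()
    print(repr(truncated))  # 'one  two' -- the double space is preserved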
@@ -509,6 +581,13 @@ def __call__(
             **model_specific_params,
         )

+        if isinstance(completions, str):
+            completions = self.reconstruct_till_max_words(completions, max_words)
+        else:
+            completions = [
+                self.reconstruct_till_max_words(seq, max_words) for seq in completions
+            ]
+
         return self._format(completions)

     def stream(

tests/generate/test_generate.py (+11)

@@ -232,6 +232,17 @@ def test_generate_text(request, model_fixture, sampler_name):
     assert isinstance(res, str)


+@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search"))
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_text_max_words(request, model_fixture, sampler_name):
+    max_words = 5
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.text(model, getattr(samplers, sampler_name)())
+    with enforce_not_implemented(model_fixture, sampler_name):
+        res = generator("Write a long sentence", max_words=max_words)
+        assert len(res.split()) <= max_words
+
+
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
 def test_generate_regex(request, model_fixture, pattern):
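
The assertion uses `<=` rather than `==` because generation can legitimately stop before the limit, for example when the model emits an end-of-sequence token first.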

tests/generate/test_generator.py (+88)

@@ -495,3 +495,91 @@ def test_expand_attention_masks(attention_masks, ancestors, expected_result):
 def test_bias_logits(logits, indices_to_mask, expected):
     masked_logits = bias_logits(logits, indices_to_mask)
     assert torch.equal(masked_logits, expected)
+
+
+def test_generator_max_words():
+    class MockFSM:
+        first_state = 0
+
+        def get_next_state(self, state, next_token_ids):
+            return 4
+
+        def get_next_instruction(self, *_):
+            return Generate([4])
+
+        def is_final_state(self, _):
+            return False  # let's generate tokens forever
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            # Input: "test"
+            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
+
+        def decode(self, tokens):
+            return [" ".join(["test" for _ in tokens[0]])]
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    class sampler:
+        def __init__(self):
+            self.samples = 1
+
+        def __call__(self, biased_logits, *_):
+            return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
+
+    generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
+    result = generator("test", max_words=5)
+    assert result == "test test test test test"
+
+
+def test_generator_max_tokens_from_max_words():
+    class MockFSM:
+        first_state = 0
+
+        def get_next_state(self, state, next_token_ids):
+            return 4
+
+        def get_next_instruction(self, *_):
+            return Generate([4])
+
+        def is_final_state(self, _):
+            return False  # let's generate tokens forever
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            # Input: "test"
+            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
+
+        def decode(self, tokens):
+            return [
+                "123456789"[: len(tokens[0])]
+            ]  # not generating any word separated by whitespace
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    class sampler:
+        def __init__(self):
+            self.samples = 1
+
+        def __call__(self, biased_logits, *_):
+            return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
+
+    generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
+    result = generator("test", max_words=2)  # should generate max_words * 3 tokens
+    assert result == "123456"
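
As a sanity check of the arithmetic in this second test: `max_words=2` with no `max_tokens` yields a derived budget of `3 * 2 = 6` tokens, and the mock tokenizer decodes those six tokens to `"123456"` with no whitespace, so the word-count guard never fires and it is the token budget that terminates generation.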
