Skip to content

Commit ddd5409

Browse files
committed
Use generate.json in Outlines
And the `output_type` should now be a Pydantic model or a JSON Schema str
1 parent ac77adb commit ddd5409

File tree

3 files changed

+55
-70
lines changed

3 files changed

+55
-70
lines changed

outlines/function.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from outlines.generate.api import SequenceGenerator
1212
from outlines.prompts import Prompt
1313

14+
# Print a deprecation message instead of raising a warning
15+
print("The 'function' module is deprecated and will be removed in a future release.")
16+
1417

1518
# FIXME: This causes all the tests to fail...
1619
warnings.warn(

outlines/outline.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
import ast
1+
import json
22
from dataclasses import dataclass
33

4+
from pydantic import BaseModel
5+
6+
from outlines import generate
7+
48

59
@dataclass
610
class Outline:
@@ -19,34 +23,39 @@ class Outline:
1923
2024
Examples
2125
--------
22-
from outlines import models
26+
from pydantic import BaseModel
27+
from outlines import models, Outline
28+
29+
class OutputModel(BaseModel):
30+
result: int
2331
2432
model = models.transformers("gpt2")
2533
26-
def template(a: int) -> str:
27-
return f"What is 2 times {a}?"
34+
def template(a: int) -> str:
35+
return f"What is 2 times {a}?"
2836
29-
fn = Outline(model, template, int)
37+
fn = Outline(model, template, OutputModel)
3038
31-
result = fn(3)
32-
print(result) # Expected output: 6
39+
result = fn(3)
40+
print(result) # Expected output: OutputModel(result=6)
3341
"""
3442

3543
def __init__(self, model, template, output_type):
36-
self.model = model
44+
if not (isinstance(output_type, str) or issubclass(output_type, BaseModel)):
45+
raise TypeError(
46+
"output_type must be a Pydantic model or a JSON Schema string"
47+
)
3748
self.template = template
3849
self.output_type = output_type
50+
self.generator = generate.json(model, output_type)
3951

4052
def __call__(self, *args):
4153
prompt = self.template(*args)
42-
response = self.model.generate(prompt)
54+
response = self.generator(prompt)
4355
try:
44-
parsed_response = ast.literal_eval(response.strip())
45-
if isinstance(parsed_response, self.output_type):
46-
return parsed_response
47-
else:
48-
raise ValueError(
49-
f"Response type {type(parsed_response)} does not match expected type {self.output_type}"
50-
)
56+
if isinstance(self.output_type, str):
57+
return json.loads(response)
58+
return self.output_type.model_validate_json(response)
5159
except (ValueError, SyntaxError):
60+
# If `outlines.generate.json` works as intended, this error should never be raised.
5261
raise ValueError(f"Unable to parse response: {response.strip()}")

tests/test_outline.py

+27-54
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,41 @@
1-
from unittest.mock import MagicMock
1+
from unittest.mock import Mock, patch
22

3-
import pytest
3+
from pydantic import BaseModel
44

5-
from outlines.outline import Outline
5+
from outlines import Outline
66

77

8-
def test_outline_int_output():
9-
model = MagicMock()
10-
model.generate.return_value = "6"
8+
class OutputModel(BaseModel):
9+
result: int
1110

12-
def template(a: int) -> str:
13-
return f"What is 2 times {a}?"
1411

15-
fn = Outline(model, template, int)
16-
result = fn(3)
17-
assert result == 6
12+
def template(a: int) -> str:
13+
return f"What is 2 times {a}?"
1814

1915

20-
def test_outline_str_output():
21-
model = MagicMock()
22-
model.generate.return_value = "'Hello, world!'"
16+
def test_outline():
17+
mock_model = Mock()
18+
mock_generator = Mock()
19+
mock_generator.return_value = '{"result": 6}'
2320

24-
def template(a: int) -> str:
25-
return f"Say 'Hello, world!' {a} times"
21+
with patch("outlines.generate.json", return_value=mock_generator):
22+
outline_instance = Outline(mock_model, template, OutputModel)
23+
result = outline_instance(3)
2624

27-
fn = Outline(model, template, str)
28-
result = fn(1)
29-
assert result == "Hello, world!"
25+
assert result.result == 6
3026

3127

32-
def test_outline_str_input():
33-
model = MagicMock()
34-
model.generate.return_value = "'Hi, Mark!'"
28+
def test_outline_with_json_schema():
29+
mock_model = Mock()
30+
mock_generator = Mock()
31+
mock_generator.return_value = '{"result": 6}'
3532

36-
def template(a: str) -> str:
37-
return f"Say hi to {a}"
33+
with patch("outlines.generate.json", return_value=mock_generator):
34+
outline_instance = Outline(
35+
mock_model,
36+
template,
37+
'{"type": "object", "properties": {"result": {"type": "integer"}}}',
38+
)
39+
result = outline_instance(3)
3840

39-
fn = Outline(model, template, str)
40-
result = fn(1)
41-
assert result == "Hi, Mark!"
42-
43-
44-
def test_outline_invalid_output():
45-
model = MagicMock()
46-
model.generate.return_value = "not a number"
47-
48-
def template(a: int) -> str:
49-
return f"What is 2 times {a}?"
50-
51-
fn = Outline(model, template, int)
52-
with pytest.raises(ValueError):
53-
fn(3)
54-
55-
56-
def test_outline_mismatched_output_type():
57-
model = MagicMock()
58-
model.generate.return_value = "'Hello, world!'"
59-
60-
def template(a: int) -> str:
61-
return f"What is 2 times {a}?"
62-
63-
fn = Outline(model, template, int)
64-
with pytest.raises(
65-
ValueError,
66-
match="Unable to parse response: 'Hello, world!'",
67-
):
68-
fn(3)
41+
assert result["result"] == 6

0 commit comments

Comments
 (0)