|
1 |
| -from unittest.mock import MagicMock |
| 1 | +from unittest.mock import Mock, patch |
2 | 2 |
|
3 |
| -import pytest |
| 3 | +from pydantic import BaseModel |
4 | 4 |
|
5 |
| -from outlines.outline import Outline |
| 5 | +from outlines import Outline |
6 | 6 |
|
7 | 7 |
|
8 |
| -def test_outline_int_output(): |
9 |
| - model = MagicMock() |
10 |
| - model.generate.return_value = "6" |
| 8 | +class OutputModel(BaseModel): |
| 9 | + result: int |
11 | 10 |
|
12 |
| - def template(a: int) -> str: |
13 |
| - return f"What is 2 times {a}?" |
14 | 11 |
|
15 |
| - fn = Outline(model, template, int) |
16 |
| - result = fn(3) |
17 |
| - assert result == 6 |
| 12 | +def template(a: int) -> str: |
| 13 | + return f"What is 2 times {a}?" |
18 | 14 |
|
19 | 15 |
|
20 |
| -def test_outline_str_output(): |
21 |
| - model = MagicMock() |
22 |
| - model.generate.return_value = "'Hello, world!'" |
| 16 | +def test_outline(): |
| 17 | + mock_model = Mock() |
| 18 | + mock_generator = Mock() |
| 19 | + mock_generator.return_value = '{"result": 6}' |
23 | 20 |
|
24 |
| - def template(a: int) -> str: |
25 |
| - return f"Say 'Hello, world!' {a} times" |
| 21 | + with patch("outlines.generate.json", return_value=mock_generator): |
| 22 | + outline_instance = Outline(mock_model, template, OutputModel) |
| 23 | + result = outline_instance(3) |
26 | 24 |
|
27 |
| - fn = Outline(model, template, str) |
28 |
| - result = fn(1) |
29 |
| - assert result == "Hello, world!" |
| 25 | + assert result.result == 6 |
30 | 26 |
|
31 | 27 |
|
32 |
| -def test_outline_str_input(): |
33 |
| - model = MagicMock() |
34 |
| - model.generate.return_value = "'Hi, Mark!'" |
| 28 | +def test_outline_with_json_schema(): |
| 29 | + mock_model = Mock() |
| 30 | + mock_generator = Mock() |
| 31 | + mock_generator.return_value = '{"result": 6}' |
35 | 32 |
|
36 |
| - def template(a: str) -> str: |
37 |
| - return f"Say hi to {a}" |
| 33 | + with patch("outlines.generate.json", return_value=mock_generator): |
| 34 | + outline_instance = Outline( |
| 35 | + mock_model, |
| 36 | + template, |
| 37 | + '{"type": "object", "properties": {"result": {"type": "integer"}}}', |
| 38 | + ) |
| 39 | + result = outline_instance(3) |
38 | 40 |
|
39 |
| - fn = Outline(model, template, str) |
40 |
| - result = fn(1) |
41 |
| - assert result == "Hi, Mark!" |
42 |
| - |
43 |
| - |
44 |
| -def test_outline_invalid_output(): |
45 |
| - model = MagicMock() |
46 |
| - model.generate.return_value = "not a number" |
47 |
| - |
48 |
| - def template(a: int) -> str: |
49 |
| - return f"What is 2 times {a}?" |
50 |
| - |
51 |
| - fn = Outline(model, template, int) |
52 |
| - with pytest.raises(ValueError): |
53 |
| - fn(3) |
54 |
| - |
55 |
| - |
56 |
| -def test_outline_mismatched_output_type(): |
57 |
| - model = MagicMock() |
58 |
| - model.generate.return_value = "'Hello, world!'" |
59 |
| - |
60 |
| - def template(a: int) -> str: |
61 |
| - return f"What is 2 times {a}?" |
62 |
| - |
63 |
| - fn = Outline(model, template, int) |
64 |
| - with pytest.raises( |
65 |
| - ValueError, |
66 |
| - match="Unable to parse response: 'Hello, world!'", |
67 |
| - ): |
68 |
| - fn(3) |
| 41 | + assert result["result"] == 6 |
0 commit comments