1
- from unittest .mock import Mock , patch
1
+ from unittest .mock import MagicMock , patch
2
2
3
3
import pytest
4
4
from pydantic import BaseModel
5
5
6
6
from outlines import Outline
7
7
8
8
9
+ class IterableMock (MagicMock ):
10
+ def __getattr__ (self , name ):
11
+ result = MagicMock ()
12
+ result .__iter__ .return_value = iter ([])
13
+ return result
14
+
15
+
9
16
class OutputModel (BaseModel ):
10
17
result : int
11
18
@@ -15,8 +22,8 @@ def template(a: int) -> str:
15
22
16
23
17
24
def test_outline ():
18
- mock_model = Mock ()
19
- mock_generator = Mock ()
25
+ mock_model = IterableMock ()
26
+ mock_generator = MagicMock ()
20
27
mock_generator .return_value = '{"result": 6}'
21
28
with patch ("outlines.generate.json" , return_value = mock_generator ):
22
29
outline_instance = Outline (mock_model , template , OutputModel )
@@ -26,8 +33,8 @@ def test_outline():
26
33
27
34
28
35
def test_outline_with_json_schema ():
29
- mock_model = Mock ()
30
- mock_generator = Mock ()
36
+ mock_model = IterableMock ()
37
+ mock_generator = MagicMock ()
31
38
mock_generator .return_value = '{"result": 6}'
32
39
with patch ("outlines.generate.json" , return_value = mock_generator ):
33
40
outline_instance = Outline (
@@ -40,14 +47,14 @@ def test_outline_with_json_schema():
40
47
41
48
42
49
def test_invalid_output_type ():
43
- mock_model = Mock ()
50
+ mock_model = IterableMock ()
44
51
with pytest .raises (TypeError ):
45
52
Outline (mock_model , template , int )
46
53
47
54
48
55
def test_invalid_json_response ():
49
- mock_model = Mock ()
50
- mock_generator = Mock ()
56
+ mock_model = IterableMock ()
57
+ mock_generator = MagicMock ()
51
58
mock_generator .return_value = "invalid json"
52
59
with patch ("outlines.generate.json" , return_value = mock_generator ):
53
60
outline_instance = Outline (mock_model , template , OutputModel )
@@ -56,7 +63,7 @@ def test_invalid_json_response():
56
63
57
64
58
65
def test_invalid_json_schema ():
59
- mock_model = Mock ()
66
+ mock_model = IterableMock ()
60
67
invalid_json_schema = (
61
68
'{"type": "object", "properties": {"result": {"type": "invalid_type"}}}'
62
69
)
0 commit comments