3
3
4
4
import pytest
5
5
from pytz import UTC
6
- import inspect
7
6
8
7
from faculty .clients .experiment import (
9
8
Experiment ,
10
9
ExperimentRun ,
11
10
ExperimentRunStatus ,
12
11
ListExperimentRunsResponse ,
13
12
Metric ,
13
+ Page ,
14
14
Pagination ,
15
15
Param ,
16
16
SingleFilter ,
55
55
deleted_at = DELETED_AT ,
56
56
)
57
57
58
+ FILTER = SingleFilter (
59
+ SingleFilterBy .EXPERIMENT_ID , None , SingleFilterOperator .EQUAL_TO , "2"
60
+ )
61
+
62
+ SORT = [Sort (SortBy .METRIC , "metric_key" , SortOrder .ASC )]
63
+
58
64
EXPERIMENT_RUN = ExperimentRun (
59
65
id = EXPERIMENT_RUN_ID ,
60
66
run_number = EXPERIMENT_RUN_NUMBER ,
70
76
params = [PARAM ],
71
77
metrics = [METRIC ],
72
78
)
73
-
74
- PAGINATION = Pagination (start = 20 , size = 10 , previous = None , next = None )
75
-
79
+ PAGINATION = Pagination (0 , 1 , None , None )
76
80
LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse (
77
81
runs = [EXPERIMENT_RUN ], pagination = PAGINATION
78
82
)
79
-
80
- FILTER = SingleFilter (
81
- SingleFilterBy .EXPERIMENT_ID , None , SingleFilterOperator .EQUAL_TO , "2"
82
- )
83
-
84
- SORT = [Sort (SortBy .METRIC , "metric_key" , SortOrder .ASC )]
85
-
86
-
87
- def test_experiment_run_query (mocker ):
88
-
89
- experiment_client_mock = mocker .MagicMock ()
90
- experiment_client_mock .query_runs .return_value = (
91
- LIST_EXPERIMENT_RUNS_RESPONSE
83
+ EXPECTED_RUNS = [
84
+ FacultyExperimentRun (
85
+ id = EXPERIMENT_RUN_ID ,
86
+ run_number = EXPERIMENT_RUN_NUMBER ,
87
+ name = EXPERIMENT_RUN_NAME ,
88
+ parent_run_id = PARENT_RUN_ID ,
89
+ experiment_id = EXPERIMENT .id ,
90
+ artifact_location = "faculty:" ,
91
+ status = ExperimentRunStatus .RUNNING ,
92
+ started_at = RUN_STARTED_AT ,
93
+ ended_at = RUN_ENDED_AT ,
94
+ deleted_at = DELETED_AT ,
95
+ tags = [TAG ],
96
+ params = [PARAM ],
97
+ metrics = [METRIC ],
92
98
)
93
- mocker . patch ( "faculty.client" , return_value = experiment_client_mock )
99
+ ]
94
100
95
- expected_response = FacultyExperimentRun (
101
+ PAGINATION_MULTIPLE_1 = Pagination (0 , 1 , None , Page (1 , 1 ))
102
+ LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_1 = ListExperimentRunsResponse (
103
+ runs = [EXPERIMENT_RUN ], pagination = PAGINATION_MULTIPLE_1
104
+ )
105
+ EXPERIMENT_RUN_MULTIPLE_2 = ExperimentRun (
106
+ id = 7 ,
107
+ run_number = EXPERIMENT_RUN_NUMBER ,
108
+ name = EXPERIMENT_RUN_NAME ,
109
+ parent_run_id = PARENT_RUN_ID ,
110
+ experiment_id = EXPERIMENT .id ,
111
+ artifact_location = "faculty:" ,
112
+ status = ExperimentRunStatus .RUNNING ,
113
+ started_at = RUN_STARTED_AT ,
114
+ ended_at = RUN_ENDED_AT ,
115
+ deleted_at = DELETED_AT ,
116
+ tags = [TAG ],
117
+ params = [PARAM ],
118
+ metrics = [METRIC ],
119
+ )
120
+ PAGINATION_MULTIPLE_2 = Pagination (1 , 1 , Page (0 , 1 ), None )
121
+ LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_2 = ListExperimentRunsResponse (
122
+ runs = [EXPERIMENT_RUN_MULTIPLE_2 ], pagination = PAGINATION_MULTIPLE_2
123
+ )
124
+ EXPECTED_RUNS_2 = [
125
+ FacultyExperimentRun (
96
126
id = EXPERIMENT_RUN_ID ,
97
127
run_number = EXPERIMENT_RUN_NUMBER ,
98
128
name = EXPERIMENT_RUN_NAME ,
@@ -106,27 +136,65 @@ def test_experiment_run_query(mocker):
106
136
tags = [TAG ],
107
137
params = [PARAM ],
108
138
metrics = [METRIC ],
139
+ ),
140
+ FacultyExperimentRun (
141
+ id = 7 ,
142
+ run_number = EXPERIMENT_RUN_NUMBER ,
143
+ name = EXPERIMENT_RUN_NAME ,
144
+ parent_run_id = PARENT_RUN_ID ,
145
+ experiment_id = EXPERIMENT .id ,
146
+ artifact_location = "faculty:" ,
147
+ status = ExperimentRunStatus .RUNNING ,
148
+ started_at = RUN_STARTED_AT ,
149
+ ended_at = RUN_ENDED_AT ,
150
+ deleted_at = DELETED_AT ,
151
+ tags = [TAG ],
152
+ params = [PARAM ],
153
+ metrics = [METRIC ],
154
+ ),
155
+ ]
156
+
157
+
158
+ @pytest .mark .parametrize (
159
+ "query_runs_side_effects,expected_runs" ,
160
+ [
161
+ [[LIST_EXPERIMENT_RUNS_RESPONSE ], EXPECTED_RUNS ],
162
+ [
163
+ [
164
+ LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_1 ,
165
+ LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_2 ,
166
+ ],
167
+ EXPECTED_RUNS_2 ,
168
+ ],
169
+ ],
170
+ )
171
+ def test_experiment_run_query_single_call (
172
+ mocker , query_runs_side_effects , expected_runs
173
+ ):
174
+ experiment_client_mock = mocker .MagicMock ()
175
+ experiment_client_mock .query_runs = mocker .MagicMock (
176
+ side_effect = query_runs_side_effects
109
177
)
178
+ mocker .patch ("faculty.client" , return_value = experiment_client_mock )
110
179
111
180
response = FacultyExperimentRun .query (PROJECT_ID , FILTER , SORT )
112
181
113
182
assert isinstance (response , ExperimentRunQueryResult )
114
- returned_run = list (response )[0 ]
115
- assert isinstance (returned_run , FacultyExperimentRun )
116
- assert all (
183
+ returned_runs = list (response )
184
+ for expected_run , returned_run in zip (expected_runs , returned_runs ):
185
+ assert isinstance (returned_run , FacultyExperimentRun )
186
+ assert _are_runs_equal (expected_run , returned_run )
187
+
188
+
189
+ def _are_runs_equal (this , that ):
190
+ return all (
117
191
list (
118
192
i == j
119
193
for i , j in (
120
194
list (
121
195
zip (
122
- [
123
- getattr (returned_run , attr )
124
- for attr in returned_run .__dict__ .keys ()
125
- ],
126
- [
127
- getattr (expected_response , attr )
128
- for attr in expected_response .__dict__ .keys ()
129
- ],
196
+ [getattr (this , attr ) for attr in this .__dict__ .keys ()],
197
+ [getattr (that , attr ) for attr in that .__dict__ .keys ()],
130
198
)
131
199
)
132
200
)
0 commit comments