Skip to content

Commit 3cd4ce6

Browse files
Hailey FongElias Benussi
Hailey Fong
authored and
Elias Benussi
committed
Add test for query in experiments.py
1 parent f78c019 commit 3cd4ce6

File tree

2 files changed

+41
-62
lines changed

2 files changed

+41
-62
lines changed

faculty/experiments.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,18 @@ def get_runs():
5454
client = faculty.client("experiment")
5555

5656
response = client.query_runs(project_id, filter, sort)
57-
return map(cls._from_client_model, response.runs)
58-
# yield from map(cls._from_client_model, response.runs)
57+
# return map(cls._from_client_model, response.runs)
58+
yield from map(cls._from_client_model, response.runs)
5959

60-
# while response.pagination.next is not None:
61-
# response = client.query_runs(
62-
# project_id,
63-
# filter,
64-
# sort,
65-
# start=response.pagination.next.start,
66-
# limit=response.pagination.next.limit,
67-
# )
68-
# yield from map(cls._from_client_model, response.runs)
60+
while response.pagination.next is not None:
61+
response = client.query_runs(
62+
project_id,
63+
filter,
64+
sort,
65+
start=response.pagination.next.start,
66+
limit=response.pagination.next.limit,
67+
)
68+
yield from map(cls._from_client_model, response.runs)
6969

7070
# Open question:
7171
# Should we evalutate the entire set of runs before returning the

tests/test_experiments.py

+30-51
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from pytz import UTC
6+
import inspect
67

78
from faculty.clients.experiment import (
89
Experiment,
@@ -18,12 +19,12 @@
1819
Sort,
1920
SortBy,
2021
SortOrder,
21-
Tag
22+
Tag,
2223
)
2324

2425
from faculty.experiments import (
2526
ExperimentRun as FacultyExperimentRun,
26-
ExperimentRunQueryResult
27+
ExperimentRunQueryResult,
2728
)
2829

2930

@@ -70,22 +71,14 @@
7071
metrics=[METRIC],
7172
)
7273

73-
PAGINATION = Pagination(
74-
start=20,
75-
size=10,
76-
previous=None,
77-
next=None,
78-
)
74+
PAGINATION = Pagination(start=20, size=10, previous=None, next=None)
7975

8076
LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse(
8177
runs=[EXPERIMENT_RUN], pagination=PAGINATION
8278
)
8379

8480
FILTER = SingleFilter(
85-
SingleFilterBy.EXPERIMENT_ID,
86-
None,
87-
SingleFilterOperator.EQUAL_TO,
88-
"2"
81+
SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2"
8982
)
9083

9184
SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)]
@@ -94,12 +87,12 @@
9487
def test_experiment_run_query(mocker):
9588

9689
experiment_client_mock = mocker.MagicMock()
97-
experiment_client_mock.query_runs.return_value = LIST_EXPERIMENT_RUNS_RESPONSE
98-
mocker.patch(
99-
"faculty.client", return_value=experiment_client_mock
90+
experiment_client_mock.query_runs.return_value = (
91+
LIST_EXPERIMENT_RUNS_RESPONSE
10092
)
93+
mocker.patch("faculty.client", return_value=experiment_client_mock)
10194

102-
expected_response = ExperimentRun(
95+
expected_response = FacultyExperimentRun(
10396
id=EXPERIMENT_RUN_ID,
10497
run_number=EXPERIMENT_RUN_NUMBER,
10598
name=EXPERIMENT_RUN_NAME,
@@ -112,44 +105,30 @@ def test_experiment_run_query(mocker):
112105
deleted_at=DELETED_AT,
113106
tags=[TAG],
114107
params=[PARAM],
115-
metrics=[METRIC]
108+
metrics=[METRIC],
116109
)
117110

118111
response = FacultyExperimentRun.query(PROJECT_ID, FILTER, SORT)
119112

120113
assert isinstance(response, ExperimentRunQueryResult)
121114
returned_run = list(response)[0]
122-
123-
124-
# assert all(i == j for i, j in list(zip([getattr(l, attr) for attr in dir(l)], [getattr(expected_response, attr) for attr in dir(expected_response)])))
125-
126-
# response_schema_mock = mocker.patch(
127-
# "faculty.clients.experiment.ListExperimentRunsResponseSchema"
128-
# )
129-
# request_schema_mock = mocker.patch(
130-
# "faculty.clients.experiment.QueryRunsSchema"
131-
# )
132-
# dump_mock = request_schema_mock.return_value.dump
133-
134-
# test_filter = SingleFilter(
135-
# SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2"
136-
# )
137-
# test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)]
138-
139-
# client = ExperimentClient(mocker.Mock())
140-
# list_result = client.query_runs(
141-
# PROJECT_ID, filter=test_filter, sort=test_sort, start=20, limit=10
142-
# )
143-
144-
# assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE
145-
146-
# request_schema_mock.assert_called_once_with()
147-
# dump_mock.assert_called_once_with(
148-
# QueryRuns(test_filter, test_sort, Page(20, 10))
149-
# )
150-
# response_schema_mock.assert_called_once_with()
151-
# ExperimentClient._post.assert_called_once_with(
152-
# "/project/{}/run/query".format(PROJECT_ID),
153-
# response_schema_mock.return_value,
154-
# json=dump_mock.return_value,
155-
# )
115+
assert isinstance(returned_run, FacultyExperimentRun)
116+
assert all(
117+
list(
118+
i == j
119+
for i, j in (
120+
list(
121+
zip(
122+
[
123+
getattr(returned_run, attr)
124+
for attr in returned_run.__dict__.keys()
125+
],
126+
[
127+
getattr(expected_response, attr)
128+
for attr in expected_response.__dict__.keys()
129+
],
130+
)
131+
)
132+
)
133+
)
134+
)

0 commit comments

Comments
 (0)