Skip to content

Commit 5ae2d09

Browse files
author
Hailey Fong
committed
Add test for experiments query function
1 parent 7ae4ece commit 5ae2d09

File tree

2 files changed

+166
-4
lines changed

2 files changed

+166
-4
lines changed

faculty/experiments.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,19 @@ def _from_client_model(cls, client_object):
4949
return cls(**client_object._asdict())
5050

5151
@classmethod
52-
def query(cls, project_id, experiment_ids=None):
52+
def query(cls, project_id, filter=None, sort=None):
5353
def get_runs():
5454
client = faculty.client("experiment")
5555

56-
response = client.list_runs(project_id, experiment_ids)
56+
response = client.query_runs(project_id, filter, sort)
57+
print(response)
5758
yield from map(cls._from_client_model, response.runs)
5859

5960
while response.pagination.next is not None:
60-
response = client.list_runs(
61+
response = client.query_runs(
6162
project_id,
62-
experiment_ids,
63+
filter,
64+
sort,
6365
start=response.pagination.next.start,
6466
limit=response.pagination.next.limit,
6567
)

tests/test_experiments.py

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from datetime import datetime
2+
from uuid import uuid4
3+
4+
import pytest
5+
from pytz import UTC
6+
7+
from faculty.clients.experiment import (
8+
Experiment,
9+
ExperimentClient,
10+
ExperimentRun,
11+
ExperimentRunStatus,
12+
ListExperimentRunsResponse,
13+
Metric,
14+
Pagination,
15+
Param,
16+
SingleFilter,
17+
SingleFilterBy,
18+
SingleFilterOperator,
19+
Sort,
20+
SortBy,
21+
SortOrder,
22+
Tag
23+
)
24+
25+
from faculty.experiments import (
26+
ExperimentRun as FacultyExperimentRun,
27+
ExperimentRunQueryResult
28+
)
29+
30+
31+
32+
PROJECT_ID = uuid4()
33+
EXPERIMENT_ID = 661
34+
EXPERIMENT_RUN_ID = uuid4()
35+
EXPERIMENT_RUN_NUMBER = 3
36+
EXPERIMENT_RUN_NAME = "run name"
37+
PARENT_RUN_ID = uuid4()
38+
RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC)
39+
RUN_ENDED_AT = datetime(2018, 3, 10, 11, 39, 15, 110000, tzinfo=UTC)
40+
CREATED_AT = datetime(2018, 3, 10, 11, 32, 6, 247000, tzinfo=UTC)
41+
LAST_UPDATED_AT = datetime(2018, 3, 10, 11, 32, 30, 172000, tzinfo=UTC)
42+
DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC)
43+
TAG = Tag(key="tag-key", value="tag-value")
44+
PARAM = Param(key="param-key", value="param-value")
45+
METRIC_KEY = "metric-key"
46+
METRIC_TIMESTAMP = datetime(2018, 3, 12, 16, 20, 22, 122000, tzinfo=UTC)
47+
METRIC = Metric(key=METRIC_KEY, value=123, timestamp=METRIC_TIMESTAMP)
48+
49+
EXPERIMENT = Experiment(
50+
id=EXPERIMENT_ID,
51+
name="experiment name",
52+
description="experiment description",
53+
artifact_location="https://example.com",
54+
created_at=CREATED_AT,
55+
last_updated_at=LAST_UPDATED_AT,
56+
deleted_at=DELETED_AT,
57+
)
58+
59+
EXPERIMENT_RUN = ExperimentRun(
60+
id=EXPERIMENT_RUN_ID,
61+
run_number=EXPERIMENT_RUN_NUMBER,
62+
name=EXPERIMENT_RUN_NAME,
63+
parent_run_id=PARENT_RUN_ID,
64+
experiment_id=EXPERIMENT.id,
65+
artifact_location="faculty:",
66+
status=ExperimentRunStatus.RUNNING,
67+
started_at=RUN_STARTED_AT,
68+
ended_at=RUN_ENDED_AT,
69+
deleted_at=DELETED_AT,
70+
tags=[TAG],
71+
params=[PARAM],
72+
metrics=[METRIC],
73+
)
74+
75+
PAGINATION = Pagination(
76+
start=20,
77+
size=10,
78+
previous=None,
79+
next=None,
80+
)
81+
82+
LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse(
83+
runs=[EXPERIMENT_RUN], pagination=PAGINATION
84+
)
85+
86+
FILTER = SingleFilter(
87+
SingleFilterBy.EXPERIMENT_ID,
88+
None,
89+
SingleFilterOperator.EQUAL_TO,
90+
"2"
91+
)
92+
93+
SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)]
94+
95+
def test_experiment_run_query(mocker):
96+
97+
experiment_client_mock = mocker.MagicMock()
98+
experiment_client_mock.query_runs = LIST_EXPERIMENT_RUNS_RESPONSE
99+
mocker.patch(
100+
"faculty.client", new=experiment_client_mock
101+
)
102+
103+
expected_response = FacultyExperimentRun(
104+
id=EXPERIMENT_RUN_ID,
105+
run_number=EXPERIMENT_RUN_NUMBER,
106+
name=EXPERIMENT_RUN_NAME,
107+
parent_run_id=PARENT_RUN_ID,
108+
experiment_id=EXPERIMENT.id,
109+
artifact_location="faculty:",
110+
status=ExperimentRunStatus.RUNNING,
111+
started_at=RUN_STARTED_AT,
112+
ended_at=RUN_ENDED_AT,
113+
deleted_at=DELETED_AT,
114+
tags=[TAG],
115+
params=[PARAM],
116+
metrics=[METRIC]
117+
)
118+
119+
print("hello")
120+
121+
response = FacultyExperimentRun.query(PROJECT_ID, FILTER, SORT)
122+
123+
print(response)
124+
assert isinstance(response, ExperimentRunQueryResult)
125+
# l = list(response)
126+
# l = l[0]
127+
# assert all(i==j for i,j in list(zip([getattr(l, attr) for attr in dir(l)],
128+
# [getattr(expected_response, attr) for attr in dir(expected_response)])))
129+
130+
131+
# response_schema_mock = mocker.patch(
132+
# "faculty.clients.experiment.ListExperimentRunsResponseSchema"
133+
# )
134+
# request_schema_mock = mocker.patch(
135+
# "faculty.clients.experiment.QueryRunsSchema"
136+
# )
137+
# dump_mock = request_schema_mock.return_value.dump
138+
139+
# test_filter = SingleFilter(
140+
# SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2"
141+
# )
142+
# test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)]
143+
144+
# client = ExperimentClient(mocker.Mock())
145+
# list_result = client.query_runs(
146+
# PROJECT_ID, filter=test_filter, sort=test_sort, start=20, limit=10
147+
# )
148+
149+
# assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE
150+
151+
# request_schema_mock.assert_called_once_with()
152+
# dump_mock.assert_called_once_with(
153+
# QueryRuns(test_filter, test_sort, Page(20, 10))
154+
# )
155+
# response_schema_mock.assert_called_once_with()
156+
# ExperimentClient._post.assert_called_once_with(
157+
# "/project/{}/run/query".format(PROJECT_ID),
158+
# response_schema_mock.return_value,
159+
# json=dump_mock.return_value,
160+
# )

0 commit comments

Comments
 (0)