Skip to content

Commit c36e1cf

Browse files
committed
Refactoring, add filter helper and add tests
1 parent 7b1467d commit c36e1cf

File tree

3 files changed

+520
-199
lines changed

3 files changed

+520
-199
lines changed

faculty/experiment.py

+247-25
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,21 @@
2020
from faculty._util.resolvers import resolve_project_id
2121
from faculty.clients.experiment import ExperimentClient
2222

23-
24-
class QueryResult(object):
25-
def __init__(self, iterable):
26-
self.iterable = iterable
27-
28-
def __iter__(self):
29-
return iter(self.iterable)
30-
31-
32-
class ExperimentRunQueryResult(QueryResult):
33-
def as_dataframe(self):
34-
records = []
35-
for run in self:
36-
row = {
37-
"Experiment ID": run.experiment_id,
38-
"Run ID": run.id,
39-
"Status": run.status.value,
40-
"Started At": run.started_at,
41-
}
42-
for metric in run.metrics:
43-
row[metric.key] = row[metric.value]
44-
records.append(row)
45-
return pandas.DataFrame(records)
23+
from faculty.clients.experiment import (
24+
ComparisonOperator,
25+
DeletedAtFilter,
26+
ExperimentIdFilter,
27+
MetricFilter,
28+
ParamFilter,
29+
RunIdFilter,
30+
TagFilter,
31+
)
4632

4733

4834
@attrs
4935
class ExperimentRun(object):
36+
"""A single run of an experiment."""
37+
5038
id = attrib()
5139
run_number = attrib()
5240
experiment_id = attrib()
@@ -67,6 +55,52 @@ def _from_client_model(cls, client_object):
6755

6856
@classmethod
6957
def query(cls, project=None, filter=None, sort=None, **session_config):
58+
"""Query the platform for experiment runs.
59+
60+
Parameters
61+
----------
62+
project : str, UUID, or None
63+
The name or ID of a project. If ``None`` is passed (the default),
64+
the project will be inferred from the runtime context.
65+
filter : a filter object from ``faculty.clients.experiment``
66+
Condition(s) to filter experiment runs by. ``FilterBy`` provides a
67+
convenience interface for constructing filter objects.
68+
sort : a sequence of sort objects from ``faculty.clients.experiment``
69+
Condition(s) to sort experiment runs by.
70+
**session_config
71+
Configuration options to build the session with.
72+
73+
Returns
74+
-------
75+
ExperimentRunList
76+
77+
Examples
78+
--------
79+
Get all experiment runs in the current project:
80+
81+
>>> ExperimentRun.query()
82+
ExperimentRunList([ExperimentRun(...)])
83+
84+
Get all experiment runs in a named project:
85+
86+
>>> ExperimentRun.query("my project")
87+
ExperimentRunList([ExperimentRun(...)])
88+
89+
Filter experiment runs by experiment ID:
90+
91+
>>> ExperimentRun.query(filter=FilterBy.experiment_id() == 2)
92+
ExperimentRunList([ExperimentRun(...)])
93+
94+
Filter experiment runs by a more complex condition:
95+
96+
>>> filter = (
97+
... FilterBy.experiment_id().one_of([2, 3, 4]) &
98+
... (FilterBy.metric("accuracy") > 0.9) &
99+
... (FilterBy.param("alpha") < 0.3)
100+
... )
101+
>>> ExperimentRun.query("my project", filter)
102+
ExperimentRunList([ExperimentRun(...)])
103+
"""
70104

71105
session = get_session(**session_config)
72106
project_id = resolve_project_id(session, project)
@@ -75,7 +109,6 @@ def _get_runs():
75109
client = ExperimentClient(session)
76110

77111
response = client.query_runs(project_id, filter, sort)
78-
# return map(cls._from_client_model, response.runs)
79112
yield from map(cls._from_client_model, response.runs)
80113

81114
while response.pagination.next is not None:
@@ -88,4 +121,193 @@ def _get_runs():
88121
)
89122
yield from map(cls._from_client_model, response.runs)
90123

91-
return ExperimentRunQueryResult(list(_get_runs()))
124+
return ExperimentRunList(_get_runs())
125+
126+
127+
class ExperimentRunList(list):
    """A list of experiment runs.

    This collection is a subclass of ``list``, and so supports all its
    functionality, but adds the ``as_dataframe`` method which returns a
    representation of the contained ExperimentRuns as a ``pandas.DataFrame``.
    """

    def __repr__(self):
        """Return the plain list representation wrapped in the class name."""
        return "{}({})".format(
            self.__class__.__name__, super(ExperimentRunList, self).__repr__()
        )

    def as_dataframe(self):
        """Get the experiment runs as a pandas DataFrame.

        Each run becomes one row. Columns are a two-level
        ``pandas.MultiIndex``: the fixed run attributes use an empty string
        as their second level, while params and metrics are grouped under
        ``"params"`` and ``"metrics"`` with the param/metric key as the
        second level. The ``"params"`` and ``"metrics"`` groups are only
        present when at least one run has a param or metric, respectively.

        An empty list yields an empty DataFrame that still carries the fixed
        run-attribute columns.

        Returns
        -------
        pandas.DataFrame
        """

        fixed_columns = [
            "experiment_id",
            "run_id",
            "run_number",
            "status",
            "started_at",
            "ended_at",
        ]

        records = []
        for run in self:
            row = {
                ("experiment_id", ""): run.experiment_id,
                ("run_id", ""): run.id,
                ("run_number", ""): run.run_number,
                ("status", ""): run.status.value,
                ("started_at", ""): run.started_at,
                ("ended_at", ""): run.ended_at,
            }
            for param in run.params:
                row[("params", param.key)] = param.value
            for metric in run.metrics:
                row[("metrics", metric.key)] = metric.value
            records.append(row)

        if not records:
            # MultiIndex.from_tuples cannot infer the number of levels from
            # an empty frame, so build the empty result explicitly.
            return pandas.DataFrame(
                columns=pandas.MultiIndex.from_tuples(
                    [(name, "") for name in fixed_columns]
                )
            )

        df = pandas.DataFrame(records)
        df.columns = pandas.MultiIndex.from_tuples(df.columns)

        # Reorder columns and return. The optional groups are only selected
        # when present; asking for a missing top-level label would raise
        # KeyError for runs without any params or metrics.
        available_groups = set(df.columns.get_level_values(0))
        column_order = fixed_columns + [
            group
            for group in ("params", "metrics")
            if group in available_groups
        ]
        return df[column_order]
180+
181+
182+
class _FilterBuilder(object):
    """Build filter objects using Python comparison operators.

    A builder wraps a filter constructor together with any leading
    arguments for it (for example the key of a tag, param or metric).
    Comparing the builder with a value does not return a boolean; it calls
    the constructor with the stored arguments followed by the matching
    ``ComparisonOperator`` member and the compared value, and returns the
    result.

    Parameters
    ----------
    constructor : callable
        Called as ``constructor(*constructor_args, operator, value)``.
    *constructor_args
        Leading positional arguments passed to ``constructor`` on every
        build.
    """

    # NOTE: defining __eq__ without __hash__ makes instances unhashable in
    # Python 3. Builders are meant to be compared immediately, not stored
    # in sets or used as dict keys.

    def __init__(self, constructor, *constructor_args):
        self.constructor = constructor
        self.constructor_args = constructor_args

    def _build(self, *args):
        """Call the constructor with the stored arguments followed by args."""
        return self.constructor(*(self.constructor_args + args))

    def defined(self, value=True):
        """Build a filter using the ``DEFINED`` operator and ``value``."""
        return self._build(ComparisonOperator.DEFINED, value)

    def __eq__(self, value):
        return self._build(ComparisonOperator.EQUAL_TO, value)

    def __ne__(self, value):
        return self._build(ComparisonOperator.NOT_EQUAL_TO, value)

    def __gt__(self, value):
        return self._build(ComparisonOperator.GREATER_THAN, value)

    def __ge__(self, value):
        return self._build(ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value)

    def __lt__(self, value):
        return self._build(ComparisonOperator.LESS_THAN, value)

    def __le__(self, value):
        return self._build(ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value)

    def one_of(self, values):
        """Build a filter matching any of the given values.

        One equality filter is built per value and the results are combined
        with the ``|`` operator of the built filter objects.

        Parameters
        ----------
        values : iterable
            The values to match. Any iterable is accepted (list, tuple, set,
            generator, ...); it is consumed once.

        Returns
        -------
        The equality filter for a single value, or the ``|`` combination of
        the equality filters when several values are given.

        Raises
        ------
        ValueError
            If ``values`` is empty.
        """
        # Materialise first so that non-subscriptable iterables work too.
        candidates = list(values)
        if not candidates:
            raise ValueError("Must provide at least one value")

        combined = self == candidates[0]
        for candidate in candidates[1:]:
            combined |= self == candidate
        return combined
220+
221+
222+
class FilterBy(object):
    """Convenience interface for constructing experiment run filters.

    Each static method returns a ``_FilterBuilder``. Applying a comparison
    operator (``==``, ``!=``, ``>``, ``>=``, ``<``, ``<=``) or one of the
    builder methods (``defined``, ``one_of``) to it produces the filter
    object itself.

    The examples below compare against ``None`` with ``==`` / ``!=`` on
    purpose: ``is`` cannot be overloaded in Python, so it would not build a
    filter.
    """

    @staticmethod
    def experiment_id():
        """Filter by experiment ID.

        Examples
        --------
        Get runs for experiment 4:

        >>> FilterBy.experiment_id() == 4
        """
        return _FilterBuilder(ExperimentIdFilter)

    @staticmethod
    def run_id():
        """Filter by run ID.

        Examples
        --------
        Get the run with a specified ID:

        >>> FilterBy.run_id() == "945f1d96-9937-4b95-aa3f-addcdd1c8749"
        """
        return _FilterBuilder(RunIdFilter)

    @staticmethod
    def deleted_at():
        """Filter by run deletion time.

        Examples
        --------
        Get runs deleted more than ten minutes ago:

        >>> from datetime import datetime, timedelta
        >>> FilterBy.deleted_at() < datetime.now() - timedelta(minutes=10)

        Get non-deleted runs:

        >>> FilterBy.deleted_at() == None
        """
        return _FilterBuilder(DeletedAtFilter)

    @staticmethod
    def tag(key):
        """Filter by run tag.

        Parameters
        ----------
        key
            The key of the tag to filter on.

        Examples
        --------
        Get runs with a particular tag:

        >>> FilterBy.tag("key") == "value"

        Get runs where a tag is set, with any value:

        >>> FilterBy.tag("key") != None
        """
        return _FilterBuilder(TagFilter, key)

    @staticmethod
    def param(key):
        """Filter by parameter.

        Parameters
        ----------
        key
            The key of the parameter to filter on.

        Examples
        --------
        Get runs with a particular parameter value:

        >>> FilterBy.param("key") == "value"

        Params also support filtering by numeric value:

        >>> FilterBy.param("alpha") > 0.2
        """
        return _FilterBuilder(ParamFilter, key)

    @staticmethod
    def metric(key):
        """Filter by metric.

        Parameters
        ----------
        key
            The key of the metric to filter on.

        Examples
        --------
        Get runs with matching metric values:

        >>> FilterBy.metric("accuracy") > 0.9

        To filter a range of values, combine them with ``&``:

        >>> (
        ...     (FilterBy.metric("accuracy") > 0.8) &
        ...     (FilterBy.metric("accuracy") < 0.9)
        ... )
        """
        return _FilterBuilder(MetricFilter, key)

0 commit comments

Comments
 (0)