20
20
from faculty ._util .resolvers import resolve_project_id
21
21
from faculty .clients .experiment import ExperimentClient
22
22
23
-
24
- class QueryResult (object ):
25
- def __init__ (self , iterable ):
26
- self .iterable = iterable
27
-
28
- def __iter__ (self ):
29
- return iter (self .iterable )
30
-
31
-
32
- class ExperimentRunQueryResult (QueryResult ):
33
- def as_dataframe (self ):
34
- records = []
35
- for run in self :
36
- row = {
37
- "Experiment ID" : run .experiment_id ,
38
- "Run ID" : run .id ,
39
- "Status" : run .status .value ,
40
- "Started At" : run .started_at ,
41
- }
42
- for metric in run .metrics :
43
- row [metric .key ] = row [metric .value ]
44
- records .append (row )
45
- return pandas .DataFrame (records )
23
+ from faculty .clients .experiment import (
24
+ ComparisonOperator ,
25
+ DeletedAtFilter ,
26
+ ExperimentIdFilter ,
27
+ MetricFilter ,
28
+ ParamFilter ,
29
+ RunIdFilter ,
30
+ TagFilter ,
31
+ )
46
32
47
33
48
34
@attrs
49
35
class ExperimentRun (object ):
36
+ """A single run of an experiment."""
37
+
50
38
id = attrib ()
51
39
run_number = attrib ()
52
40
experiment_id = attrib ()
@@ -67,6 +55,52 @@ def _from_client_model(cls, client_object):
67
55
68
56
@classmethod
69
57
def query (cls , project = None , filter = None , sort = None , ** session_config ):
58
+ """Query the platform for experiment runs.
59
+
60
+ Parameters
61
+ ----------
62
+ project : str, UUID, or None
63
+ The name or ID of a project. If ``None`` is passed (the default),
64
+ the project will be inferred from the runtime context.
65
+ filter : a filter object from ``faculty.clients.experiment``
66
+ Condition(s) to filter experiment runs by. ``FilterBy`` provides a
67
+ convenience interface for constructing filter objects.
68
+ sort : a sequence of sort objects from ``faculty.clients.experiment``
69
+ Condition(s) to sort experiment runs by.
70
+ **session_config
71
+ Configuration options to build the session with.
72
+
73
+ Returns
74
+ -------
75
+ ExperimentRunList
76
+
77
+ Examples
78
+ --------
79
+ Get all experiment runs in the current project:
80
+
81
+ >>> ExperimentRun.query()
82
+ ExperimentRunList([ExperimentRun(...)])
83
+
84
+ Get all experiment runs in a named project:
85
+
86
+ >>> ExperimentRun.query("my project")
87
+ ExperimentRunList([ExperimentRun(...)])
88
+
89
+ Filter experiment runs by experiment ID:
90
+
91
+ >>> ExperimentRun.query(filter=FilterBy.experiment_id() == 2)
92
+ ExperimentRunList([ExperimentRun(...)])
93
+
94
+ Filter experiment runs by a more complex condition:
95
+
96
+ >>> filter = (
97
+ ... FilterBy.experiment_id().one_of([2, 3, 4]) &
98
+ ... (FilterBy.metric("accuracy") > 0.9) &
99
+ ... (FilterBy.param("alpha") < 0.3)
100
+ ... )
101
+ >>> ExperimentRun.query("my project", filter)
102
+ ExperimentRunList([ExperimentRun(...)])
103
+ """
70
104
71
105
session = get_session (** session_config )
72
106
project_id = resolve_project_id (session , project )
@@ -75,7 +109,6 @@ def _get_runs():
75
109
client = ExperimentClient (session )
76
110
77
111
response = client .query_runs (project_id , filter , sort )
78
- # return map(cls._from_client_model, response.runs)
79
112
yield from map (cls ._from_client_model , response .runs )
80
113
81
114
while response .pagination .next is not None :
@@ -88,4 +121,193 @@ def _get_runs():
88
121
)
89
122
yield from map (cls ._from_client_model , response .runs )
90
123
91
- return ExperimentRunQueryResult (list (_get_runs ()))
124
+ return ExperimentRunList (_get_runs ())
125
+
126
+
127
+ class ExperimentRunList (list ):
128
+ """A list of experiment runs.
129
+
130
+ This collection is a subclass of ``list``, and so supports all its
131
+ functionality, but adds the ``as_dataframe`` method which returns a
132
+ representation of the contained ExperimentRuns as a ``pandas.DataFrame``.
133
+ """
134
+
135
+ def __repr__ (self ):
136
+ return "{}({})" .format (
137
+ self .__class__ .__name__ , super (ExperimentRunList , self ).__repr__ ()
138
+ )
139
+
140
+ def as_dataframe (self ):
141
+ """Get the experiment runs as a pandas DataFrame.
142
+
143
+ Returns
144
+ -------
145
+ pandas.DataFrame
146
+ """
147
+
148
+ records = []
149
+ for run in self :
150
+ row = {
151
+ ("experiment_id" , "" ): run .experiment_id ,
152
+ ("run_id" , "" ): run .id ,
153
+ ("run_number" , "" ): run .run_number ,
154
+ ("status" , "" ): run .status .value ,
155
+ ("started_at" , "" ): run .started_at ,
156
+ ("ended_at" , "" ): run .ended_at ,
157
+ }
158
+ for param in run .params :
159
+ row [("params" , param .key )] = param .value
160
+ for metric in run .metrics :
161
+ row [("metrics" , metric .key )] = metric .value
162
+ records .append (row )
163
+
164
+ df = pandas .DataFrame (records )
165
+ df .columns = pandas .MultiIndex .from_tuples (df .columns )
166
+
167
+ # Reorder columns and return
168
+ return df [
169
+ [
170
+ "experiment_id" ,
171
+ "run_id" ,
172
+ "run_number" ,
173
+ "status" ,
174
+ "started_at" ,
175
+ "ended_at" ,
176
+ "params" ,
177
+ "metrics" ,
178
+ ]
179
+ ]
180
+
181
+
182
+ class _FilterBuilder (object ):
183
+ def __init__ (self , constructor , * constructor_args ):
184
+ self .constructor = constructor
185
+ self .constructor_args = constructor_args
186
+
187
+ def _build (self , * args ):
188
+ return self .constructor (* (self .constructor_args + args ))
189
+
190
+ def defined (self , value = True ):
191
+ return self ._build (ComparisonOperator .DEFINED , value )
192
+
193
+ def __eq__ (self , value ):
194
+ return self ._build (ComparisonOperator .EQUAL_TO , value )
195
+
196
+ def __ne__ (self , value ):
197
+ return self ._build (ComparisonOperator .NOT_EQUAL_TO , value )
198
+
199
+ def __gt__ (self , value ):
200
+ return self ._build (ComparisonOperator .GREATER_THAN , value )
201
+
202
+ def __ge__ (self , value ):
203
+ return self ._build (ComparisonOperator .GREATER_THAN_OR_EQUAL_TO , value )
204
+
205
+ def __lt__ (self , value ):
206
+ return self ._build (ComparisonOperator .LESS_THAN , value )
207
+
208
+ def __le__ (self , value ):
209
+ return self ._build (ComparisonOperator .LESS_THAN_OR_EQUAL_TO , value )
210
+
211
+ def one_of (self , values ):
212
+ try :
213
+ first , remaining = values [0 ], values [1 :]
214
+ except IndexError :
215
+ raise ValueError ("Must provide at least one value" )
216
+ filter = self == first
217
+ for val in remaining :
218
+ filter |= self == val
219
+ return filter
220
+
221
+
222
+ class FilterBy (object ):
223
+ @staticmethod
224
+ def experiment_id ():
225
+ """Filter by experiment ID.
226
+
227
+ Examples
228
+ --------
229
+ Get runs for experiment 4:
230
+
231
+ >>> FilterBy.experiment_id() == 4
232
+ """
233
+ return _FilterBuilder (ExperimentIdFilter )
234
+
235
+ @staticmethod
236
+ def run_id ():
237
+ """Filter by run ID.
238
+
239
+ Examples
240
+ --------
241
+ Get the run with a specified ID:
242
+
243
+ >>> FilterBy.run_id() == "945f1d96-9937-4b95-aa3f-addcdd1c8749"
244
+ """
245
+ return _FilterBuilder (RunIdFilter )
246
+
247
+ @staticmethod
248
+ def deleted_at ():
249
+ """Filter by run deletion time.
250
+
251
+ Examples
252
+ --------
253
+ Get runs deleted more than ten minutes ago:
254
+
255
+ >>> from datetime import datetime, timedelta
256
+ >>> FilterBy.deleted_at() < datetime.now() - timedelta(minutes=10)
257
+
258
+ Get non-deleted runs:
259
+
260
+ >>> FilterBy.deleted_at() == None
261
+ """
262
+ return _FilterBuilder (DeletedAtFilter )
263
+
264
+ @staticmethod
265
+ def tag (key ):
266
+ """Filter by run tag.
267
+
268
+ Examples
269
+ --------
270
+ Get runs with a particular tag:
271
+
272
+ >>> FilterBy.tag("key") == "value"
273
+
274
+ Get runs where a tag is set, with any value:
275
+
276
+ >>> FilterBy.tag("key") != None
277
+ """
278
+ return _FilterBuilder (TagFilter , key )
279
+
280
+ @staticmethod
281
+ def param (key ):
282
+ """Filter by parameter.
283
+
284
+ Examples
285
+ --------
286
+ Get runs with a particular parameter value:
287
+
288
+ >>> FilterBy.param("key") == "value"
289
+
290
+ Params also support filtering by numeric value:
291
+
292
+ >>> FilterBy.param("alpha") > 0.2
293
+ """
294
+ return _FilterBuilder (ParamFilter , key )
295
+
296
+ @staticmethod
297
+ def metric (key ):
298
+ """Filter by metric.
299
+
300
+ Examples
301
+ --------
302
+ Get runs with matching metric values:
303
+
304
+ >>> FilterBy.metric("accuracy") > 0.9
305
+
306
+ To filter a range of values, combine them with ``&``:
307
+
308
+ >>> (
309
+ ... (FilterBy.metric("accuracy") > 0.8 ) &
310
+ ... (FilterBy.metric("accuracy") > 0.9)
311
+ ... )
312
+ """
313
+ return _FilterBuilder (MetricFilter , key )
0 commit comments