Skip to content

Commit e2edc85

Browse files
committed
Add csv export endpoint to everest data api
1 parent 18bd199 commit e2edc85

File tree

11 files changed

+454
-0
lines changed

11 files changed

+454
-0
lines changed

src/everest/api/everest_data_api.py

+175
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from pathlib import Path
23

34
import polars
@@ -269,3 +270,177 @@ def summary_values(self, batches=None, keys=None):
269270
@property
270271
def output_folder(self):
271272
return self._config.output_dir
273+
274+
def export_dataframes(
275+
self,
276+
) -> tuple[polars.DataFrame, polars.DataFrame, polars.DataFrame]:
277+
batch_dfs_to_join = {}
278+
realization_dfs_to_join = {}
279+
perturbation_dfs_to_join = {}
280+
281+
batch_ids = [b.batch_id for b in self._ever_storage.data.batches]
282+
all_controls = self._ever_storage.data.controls["control_name"].to_list()
283+
284+
def _try_append_df(
285+
batch_id: int,
286+
df: polars.DataFrame | None,
287+
target: dict[str, list[polars.DataFrame]],
288+
):
289+
if df is not None:
290+
if batch_id not in target:
291+
target[batch.batch_id] = []
292+
293+
target[batch_id].append(df)
294+
295+
def try_append_batch_dfs(batch_id: int, *dfs: polars.DataFrame):
296+
for df_ in dfs:
297+
_try_append_df(batch_id, df_, batch_dfs_to_join)
298+
299+
def try_append_realization_dfs(batch_id: int, *dfs: polars.DataFrame):
300+
for df_ in dfs:
301+
_try_append_df(batch_id, df_, realization_dfs_to_join)
302+
303+
def try_append_perturbation_dfs(batch_id: int, *dfs: polars.DataFrame):
304+
for df_ in dfs:
305+
_try_append_df(batch_id, df_, perturbation_dfs_to_join)
306+
307+
def pivot_gradient(df: polars.DataFrame) -> polars.DataFrame:
308+
pivoted_ = df.pivot(on="control_name", index="batch_id", separator=" wrt ")
309+
return pivoted_.rename(
310+
{
311+
col: f"grad({col})"
312+
for col in pivoted_.columns
313+
if col != "batch_id" and col not in all_controls
314+
}
315+
)
316+
317+
for batch in self._ever_storage.data.batches:
318+
try_append_perturbation_dfs(
319+
batch.batch_id,
320+
batch.perturbation_objectives,
321+
batch.perturbation_constraints,
322+
)
323+
324+
try_append_realization_dfs(
325+
batch.batch_id,
326+
batch.realization_objectives,
327+
batch.realization_controls,
328+
batch.realization_constraints,
329+
)
330+
331+
if batch.batch_objective_gradient is not None:
332+
try_append_batch_dfs(
333+
batch.batch_id, pivot_gradient(batch.batch_objective_gradient)
334+
)
335+
336+
if batch.batch_constraint_gradient is not None:
337+
try_append_batch_dfs(
338+
batch.batch_id,
339+
pivot_gradient(batch.batch_constraint_gradient),
340+
)
341+
342+
try_append_batch_dfs(
343+
batch.batch_id, batch.batch_objectives, batch.batch_constraints
344+
)
345+
346+
def _join_by_batch(
347+
dfs: dict[int, list[polars.DataFrame]], on: list[str]
348+
) -> list[polars.DataFrame]:
349+
"""
350+
Creates one dataframe per batch, with one column per input/output,
351+
including control, objective, constraint, gradient value.
352+
"""
353+
dfs_to_concat_ = []
354+
for batch_id in batch_ids:
355+
if batch_id not in dfs:
356+
continue
357+
358+
batch_df_ = dfs[batch_id][0]
359+
for bdf_ in dfs[batch_id][1:]:
360+
if set(all_controls).issubset(set(bdf_.columns)) and set(
361+
all_controls
362+
).issubset(set(batch_df_.columns)):
363+
bdf_ = bdf_.drop(all_controls)
364+
365+
batch_df_ = batch_df_.join(
366+
bdf_,
367+
on=on,
368+
)
369+
370+
dfs_to_concat_.append(batch_df_)
371+
372+
return dfs_to_concat_
373+
374+
batch_dfs_to_concat = _join_by_batch(batch_dfs_to_join, on=["batch_id"])
375+
batch_df = polars.concat(batch_dfs_to_concat, how="diagonal")
376+
377+
realization_dfs_to_concat = _join_by_batch(
378+
realization_dfs_to_join, on=["batch_id", "realization", "simulation_id"]
379+
)
380+
realization_df = polars.concat(realization_dfs_to_concat, how="diagonal")
381+
382+
perturbation_dfs_to_concat = _join_by_batch(
383+
perturbation_dfs_to_join, on=["batch_id", "realization", "perturbation"]
384+
)
385+
perturbation_df = polars.concat(perturbation_dfs_to_concat, how="diagonal")
386+
387+
pert_real_df = polars.concat([realization_df, perturbation_df], how="diagonal")
388+
389+
pert_real_df = pert_real_df.select(
390+
"batch_id",
391+
"realization",
392+
"perturbation",
393+
*list(
394+
set(pert_real_df.columns) - {"batch_id", "realization", "perturbation"}
395+
),
396+
)
397+
398+
# Avoid name collisions when joining with simulations
399+
batch_df_renamed = batch_df.rename(
400+
{
401+
col: f"batch_{col}"
402+
for col in batch_df.columns
403+
if col != "batch_id" and not col.startswith("grad")
404+
}
405+
)
406+
combined_df = pert_real_df.join(
407+
batch_df_renamed, on="batch_id", how="full", coalesce=True
408+
)
409+
410+
def _sort_df(df: polars.DataFrame, index: list[str]):
411+
sorted_cols = index + sorted(set(df.columns) - set(index))
412+
df_ = df.select(sorted_cols).sort(by=index)
413+
return df_
414+
415+
return (
416+
_sort_df(
417+
combined_df,
418+
["batch_id", "realization", "simulation_id", "perturbation"],
419+
),
420+
_sort_df(
421+
pert_real_df,
422+
[
423+
"batch_id",
424+
"realization",
425+
"perturbation",
426+
"simulation_id",
427+
],
428+
),
429+
_sort_df(batch_df, ["batch_id", "total_objective_value"]),
430+
)
431+
432+
@property
433+
def everest_csv(self):
434+
export_filename = (
435+
self._config.export.csv_output_filepath
436+
if self._config.export is not None
437+
else f"{self._config.config_file}.csv"
438+
)
439+
440+
full_path = os.path.join(self.output_folder, export_filename)
441+
442+
if not os.path.exists(full_path):
443+
combined_df, _, _ = self.export_dataframes()
444+
combined_df.write_csv(full_path)
445+
446+
return os.path.join(self.output_folder, export_filename)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
batch_id,distance,grad(distance wrt point_x-0),grad(distance wrt point_x-1),grad(distance wrt point_x-2),grad(distance.total wrt point_x-0),grad(distance.total wrt point_x-1),grad(distance.total wrt point_x-2),merit_value,point_x-0,point_x-1,point_x-2,total_objective_value,x-0_coord,x-0_coord.violation
2+
0,-1.6875,-0.4907662687030493,-0.5019101422144723,0.5042454858798938,-0.4907662687030493,-0.5019101422144723,0.5042454858798938,,0.9999923242971123,-0.00002866170271420768,-0.00003482388228645279,-1.6875,0.15,0.0
3+
1,-1.6928177326917648,-0.5220932558524333,-0.2857933053410886,0.6486969102170087,-0.5220932558524333,-0.2857933053410886,0.6486969102170087,389.85565356542713,1.0000197666303703,7.318438233973908e-6,0.000054286921875937026,-1.6928177326917648,0.15904500484466552,0.0
4+
2,-1.735619194805622,-0.6428129143568377,-0.005219520850941156,0.7315266224689232,-0.6428129143568377,-0.005219520850941156,0.7315266224689232,381.4612304677212,0.9999868957126514,2.561012506432357e-6,0.000025109042352951607,-1.735619194805622,0.21863600015640258,0.0
5+
3,-1.6236748024821281,-0.4727620137163322,-0.07854769964884697,0.5153976529252893,-0.4727620137163322,-0.07854769964884697,0.5153976529252893,29.58547833176726,1.0000243522797598,-0.000045940653105887436,7.74341493223104e-6,-1.6236748024821281,0.1411780059337616,0.0
6+
4,-1.539792999625206,-0.3036942605860327,0.023143745090389523,0.25834193323625815,-0.3036942605860327,0.023143745090389523,0.25834193323625815,13.246954160900696,0.9999717043148817,8.450350216384595e-6,-5.568694767103104e-6,-1.539792999625206,0.050914996862411493,0.0
7+
5,-1.5256189703941345,-0.21656121993381236,-0.005788275557673861,0.21722039756315908,-0.21656121993381236,-0.005788275557673861,0.21722039756315908,11.750325264919065,0.9999855616128173,-0.000012794222306632448,3.542883847480084e-7,-1.5256189703941345,0.014527001976966852,0.0

0 commit comments

Comments
 (0)