Skip to content

Commit

Permalink
Add csv export
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Jan 21, 2025
1 parent 0cd49c8 commit 2e61f85
Show file tree
Hide file tree
Showing 11 changed files with 442 additions and 6 deletions.
176 changes: 173 additions & 3 deletions src/everest/api/everest_data_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

import polars
Expand Down Expand Up @@ -268,7 +269,176 @@ def summary_values(self, batches=None, keys=None):
def output_folder(self):
    """Return the output directory taken from the stored configuration."""
    config = self._config
    return config.output_dir

def export_dataframes(
    self,
) -> tuple[polars.DataFrame, polars.DataFrame, polars.DataFrame]:
    """Assemble the stored optimization results into three dataframes.

    For every batch in ``self._ever_storage.data.batches`` the available
    per-batch, per-realization and per-perturbation frames are joined into
    one frame per batch, and the per-batch results are then concatenated
    diagonally (missing columns are filled with nulls).

    Returns:
        A tuple ``(combined_df, pert_real_df, batch_df)`` where

        * ``combined_df`` is the realization/perturbation frame joined with
          the batch-level frame on ``batch_id`` (batch-level columns other
          than gradients are prefixed with ``batch_``),
        * ``pert_real_df`` is the diagonal concatenation of the realization
          and perturbation frames,
        * ``batch_df`` is the batch-level frame, including pivoted
          gradient columns named ``grad(<function> wrt <control>)``.

        Each frame has its index columns first, the remaining columns in
        alphabetical order, and rows sorted by the index columns.
    """
    batch_dfs_to_join: dict[int, list[polars.DataFrame]] = {}
    realization_dfs_to_join: dict[int, list[polars.DataFrame]] = {}
    perturbation_dfs_to_join: dict[int, list[polars.DataFrame]] = {}

    batches = self._ever_storage.data.batches
    batch_ids = [b.batch_id for b in batches]
    all_controls = self._ever_storage.data.controls["control_name"].to_list()
    # Built once; used for the subset checks when joining frames.
    control_set = set(all_controls)

    def _collect(
        target: dict[int, list[polars.DataFrame]],
        batch_id: int,
        *dfs: polars.DataFrame | None,
    ) -> None:
        """Append every non-None frame to ``target[batch_id]``.

        The key is derived from the ``batch_id`` argument only, so the
        helper does not depend on any variable of the enclosing loop.
        No entry is created when all frames are None.
        """
        present = [df_ for df_ in dfs if df_ is not None]
        if present:
            target.setdefault(batch_id, []).extend(present)

    def pivot_gradient(df: polars.DataFrame) -> polars.DataFrame:
        """Pivot a gradient frame to one row per batch.

        Pivoted columns are wrapped as ``grad(...)``; ``batch_id`` and
        columns named exactly like a control are left untouched.
        """
        pivoted_ = df.pivot(on="control_name", index="batch_id", separator=" wrt ")
        return pivoted_.rename(
            {
                col: f"grad({col})"
                for col in pivoted_.columns
                if col != "batch_id" and col not in all_controls
            }
        )

    for batch in batches:
        _collect(
            perturbation_dfs_to_join,
            batch.batch_id,
            batch.perturbation_objectives,
            batch.perturbation_constraints,
        )

        _collect(
            realization_dfs_to_join,
            batch.batch_id,
            batch.realization_objectives,
            batch.realization_controls,
            batch.realization_constraints,
        )

        # Gradients must be pivoted before they can be joined on batch_id.
        if batch.batch_objective_gradient is not None:
            _collect(
                batch_dfs_to_join,
                batch.batch_id,
                pivot_gradient(batch.batch_objective_gradient),
            )

        if batch.batch_constraint_gradient is not None:
            _collect(
                batch_dfs_to_join,
                batch.batch_id,
                pivot_gradient(batch.batch_constraint_gradient),
            )

        _collect(
            batch_dfs_to_join,
            batch.batch_id,
            batch.batch_objectives,
            batch.batch_constraints,
        )

    def _join_by_batch(
        dfs: dict[int, list[polars.DataFrame]], on: list[str]
    ) -> list[polars.DataFrame]:
        """
        Creates one dataframe per batch, with one column per input/output,
        including control, objective, constraint, gradient value.
        """
        dfs_to_concat_ = []
        for batch_id in batch_ids:
            if batch_id not in dfs:
                continue

            batch_df_, *remaining_ = dfs[batch_id]
            for bdf_ in remaining_:
                # When both sides already carry every control column, drop
                # them from the right-hand frame before joining.
                if control_set.issubset(bdf_.columns) and control_set.issubset(
                    batch_df_.columns
                ):
                    bdf_ = bdf_.drop(all_controls)

                batch_df_ = batch_df_.join(
                    bdf_,
                    on=on,
                )

            dfs_to_concat_.append(batch_df_)

        return dfs_to_concat_

    batch_dfs_to_concat = _join_by_batch(batch_dfs_to_join, on=["batch_id"])
    batch_df = polars.concat(batch_dfs_to_concat, how="diagonal")

    realization_dfs_to_concat = _join_by_batch(
        realization_dfs_to_join, on=["batch_id", "realization", "simulation_id"]
    )
    realization_df = polars.concat(realization_dfs_to_concat, how="diagonal")

    perturbation_dfs_to_concat = _join_by_batch(
        perturbation_dfs_to_join, on=["batch_id", "realization", "perturbation"]
    )
    perturbation_df = polars.concat(perturbation_dfs_to_concat, how="diagonal")

    pert_real_df = polars.concat([realization_df, perturbation_df], how="diagonal")

    # Put the key columns first; the remaining columns keep their existing
    # (deterministic) order instead of going through an unordered set.
    key_cols = ["batch_id", "realization", "perturbation"]
    pert_real_df = pert_real_df.select(
        *key_cols,
        *[col for col in pert_real_df.columns if col not in key_cols],
    )

    # Avoid name collisions when joining with simulations
    batch_df_renamed = batch_df.rename(
        {
            col: f"batch_{col}"
            for col in batch_df.columns
            if col != "batch_id" and not col.startswith("grad")
        }
    )
    combined_df = pert_real_df.join(
        batch_df_renamed, on="batch_id", how="full", coalesce=True
    )

    def _sort_df(df: polars.DataFrame, index: list[str]) -> polars.DataFrame:
        """Order columns as ``index`` + remaining sorted by name; sort rows by ``index``."""
        sorted_cols = index + sorted(set(df.columns) - set(index))
        df_ = df.select(sorted_cols).sort(by=index)
        return df_

    return (
        _sort_df(
            combined_df,
            ["batch_id", "realization", "simulation_id", "perturbation"],
        ),
        _sort_df(
            pert_real_df,
            [
                "batch_id",
                "realization",
                "perturbation",
                "simulation_id",
            ],
        ),
        _sort_df(batch_df, ["batch_id", "total_objective_value"]),
    )

@property
def csv_export(self):
    """Print the storage's data object and return an empty dict."""
    storage_data = self._ever_storage.data
    print(storage_data)
    return {}
def everest_csv(self):
    """Return the path of the CSV export, writing the file if it is missing.

    The file name comes from ``self._config.export.csv_output_filepath``
    when an export section is configured, otherwise it is the config file
    name with ``.csv`` appended. The path is resolved inside
    ``self.output_folder``. An existing file is left untouched.
    """
    export_config = self._config.export
    if export_config is not None:
        export_filename = export_config.csv_output_filepath
    else:
        export_filename = f"{self._config.config_file}.csv"

    full_path = os.path.join(self.output_folder, export_filename)

    # Only build the (potentially large) dataframes when nothing is on disk.
    if not os.path.exists(full_path):
        combined_df = self.export_dataframes()[0]
        combined_df.write_csv(full_path)

    return full_path
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
batch_id,distance,grad(distance wrt point_x-0),grad(distance wrt point_x-1),grad(distance wrt point_x-2),grad(distance.total wrt point_x-0),grad(distance.total wrt point_x-1),grad(distance.total wrt point_x-2),merit_value,point_x-0,point_x-1,point_x-2,total_objective_value,x-0_coord,x-0_coord.violation
0,-1.6875,-0.4907662687030493,-0.5019101422144723,0.5042454858798938,-0.4907662687030493,-0.5019101422144723,0.5042454858798938,,0.9999923242971123,-0.00002866170271420768,-0.00003482388228645279,-1.6875,0.15,0.0
1,-1.6928177326917648,-0.5220932558524333,-0.2857933053410886,0.6486969102170087,-0.5220932558524333,-0.2857933053410886,0.6486969102170087,389.85565356542713,1.0000197666303703,7.318438233973908e-6,0.000054286921875937026,-1.6928177326917648,0.15904500484466552,0.0
2,-1.735619194805622,-0.6428129143568377,-0.005219520850941156,0.7315266224689232,-0.6428129143568377,-0.005219520850941156,0.7315266224689232,381.4612304677212,0.9999868957126514,2.561012506432357e-6,0.000025109042352951607,-1.735619194805622,0.21863600015640258,0.0
3,-1.6236748024821281,-0.4727620137163322,-0.07854769964884697,0.5153976529252893,-0.4727620137163322,-0.07854769964884697,0.5153976529252893,29.58547833176726,1.0000243522797598,-0.000045940653105887436,7.74341493223104e-6,-1.6236748024821281,0.1411780059337616,0.0
4,-1.539792999625206,-0.3036942605860327,0.023143745090389523,0.25834193323625815,-0.3036942605860327,0.023143745090389523,0.25834193323625815,13.246954160900696,0.9999717043148817,8.450350216384595e-6,-5.568694767103104e-6,-1.539792999625206,0.050914996862411493,0.0
5,-1.5256189703941345,-0.21656121993381236,-0.005788275557673861,0.21722039756315908,-0.21656121993381236,-0.005788275557673861,0.21722039756315908,11.750325264919065,0.9999855616128173,-0.000012794222306632448,3.542883847480084e-7,-1.5256189703941345,0.014527001976966852,0.0
Loading

0 comments on commit 2e61f85

Please sign in to comment.