diff --git a/gridfm_datakit/perturbations/load_perturbation.py b/gridfm_datakit/perturbations/load_perturbation.py index 0d2b2e335..e151d94dd 100644 --- a/gridfm_datakit/perturbations/load_perturbation.py +++ b/gridfm_datakit/perturbations/load_perturbation.py @@ -567,6 +567,101 @@ def __call__( return load_profiles +class PrecomputedProfile(LoadScenarioGeneratorBase): + """Reads precomputed bus-demand scenarios and returns them as (n_buses, n_scenarios, 2). + + CSV/XLSX columns: + - load_scenario : int scenario index (0..S-1) + - load : int BUS INDEX (0..n_buses-1) in *continuous indexing* used by Network + (i.e., after mapping to 0..n_buses-1) + - p_mw, q_mvar : floats + """ + + def __init__(self, scenario_file: str): + self.scenario_file = scenario_file + + def _read(self) -> pd.DataFrame: + p = self.scenario_file + if p.lower().endswith((".xlsx", ".xls")): + return pd.read_excel(p) + return pd.read_csv(p) + + def __call__( + self, + net, # type: Network + n_scenarios: int, + scenario_log: str, + max_iter: int, # unused, kept for interface compatibility + seed: int, + ) -> np.ndarray: + df = self._read() + + required = {"load_scenario", "load", "p_mw", "q_mvar"} + missing = required - set(df.columns) + if missing: + raise ValueError( + f"Scenario file must contain columns {sorted(required)}; missing {sorted(missing)}. " + f"Got {list(df.columns)}" + ) + + df["load_scenario"] = df["load_scenario"].astype(int) + df["load"] = df["load"].astype(int) + + n_buses = int(np.asarray(net.buses).shape[0]) + + # Scenario index validation (0..n_scenarios-1) + min_scenario = int(df["load_scenario"].min()) + max_scenario = int(df["load_scenario"].max()) + if min_scenario < 0 or max_scenario >= n_scenarios: + raise ValueError( + "Scenario file contains out-of-range scenario indices in column 'load_scenario'. " + f"Expected 0..{n_scenarios - 1}, got min={min_scenario}, max={max_scenario}." + ) + + # Bus index validation (continuous indices 0..n_buses-1) + min_bus = int(df["load"].min()) + max_bus = int(df["load"].max()) + if min_bus < 0 or max_bus >= n_buses: + raise ValueError( + "Scenario file contains out-of-range bus indices in column 'load'. " + f"Expected 0..{n_buses - 1}, got min={min_bus}, max={max_bus}." + ) + + # uniqueness of (load_scenario, load) pairs + dup_mask = df.duplicated(subset=["load_scenario", "load"]) + if dup_mask.any(): + dup_count = int(dup_mask.sum()) + raise ValueError( + f"Scenario file contains {dup_count} duplicate (load_scenario, load) pairs. " + "Each pair must be unique." + ) + + # check: require all scenario-bus pairs present + expected = n_buses * n_scenarios + actual = len(df) + if actual != expected: + raise ValueError( + f"Scenario file must contain exactly {expected} rows " + f"({n_buses} buses x {n_scenarios} scenarios); got {actual}." + ) + + # Allocate output: (n_buses, n_scenarios, 2) + out = np.zeros((n_buses, n_scenarios, 2), dtype=float) + + s = df["load_scenario"].to_numpy(dtype=int) + b = df["load"].to_numpy(dtype=int) + out[b, s, 0] = df["p_mw"].to_numpy(dtype=float) + out[b, s, 1] = df["q_mvar"].to_numpy(dtype=float) + + if scenario_log: + with open(scenario_log, "a") as f: + f.write( + f"precomputed_profile: scenarios={n_scenarios}, buses={n_buses}, " + f"path={self.scenario_file}\n" + ) + + return out + if __name__ == "__main__": """ diff --git a/gridfm_datakit/utils/param_handler.py b/gridfm_datakit/utils/param_handler.py index cd1fca724..a275e0041 100644 --- a/gridfm_datakit/utils/param_handler.py +++ b/gridfm_datakit/utils/param_handler.py @@ -3,6 +3,7 @@ LoadScenarioGeneratorBase, LoadScenariosFromAggProfile, Powergraph, + PrecomputedProfile, ) from typing import Dict, Any import warnings @@ -173,6 +174,7 @@ def get_load_scenario_generator(args: NestedNamespace) -> LoadScenarioGeneratorB Note: Currently supports 'agg_load_profile' and 'powergraph' generator types. """ + print("args.generator: ", args.generator) if args.generator == "agg_load_profile": return LoadScenariosFromAggProfile( args.agg_profile, @@ -194,8 +196,11 @@ def get_load_scenario_generator(args: NestedNamespace) -> LoadScenarioGeneratorB f"The following arguments are not used by the powergraph generator: {unused_args}", UserWarning, ) + - return Powergraph(args.agg_profile) + if args.generator == "precomputed_profile": + print("precomputed_profile being used") + return PrecomputedProfile(args.scenario_file) def initialize_topology_generator(