From be87a85e254a484438c7458e0cdfa9bbe948ee3f Mon Sep 17 00:00:00 2001 From: medha-14 Date: Mon, 28 Jul 2025 13:27:38 +0530 Subject: [PATCH] save and load parameters --- .../expression_tree/operations/serialise.py | 122 ++++++++++++++++++ .../test_serialisation/test_serialisation.py | 38 ++++++ 2 files changed, 160 insertions(+) diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 29a0ba968d..e5875c969e 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib +import inspect import json import numbers import re @@ -547,6 +548,72 @@ def load_custom_model(filename, battery_model=None): return model + @staticmethod + def save_parameters(parameters: dict, filename=None): + """ + Serializes a dictionary of parameters to a JSON file. + The values can be numbers, PyBaMM symbols, or callables. + + Parameters + ---------- + parameters : dict + A dictionary of parameter names and values. + Values can be numeric, PyBaMM symbols, or callables. + + filename : str, optional + If given, saves the serialized parameters to this file. + """ + parameter_values_dict = {} + + for k, v in parameters.items(): + if callable(v): + parameter_values_dict[k] = Serialise.convert_symbol_to_json( + Serialise.convert_function_to_symbolic_expression(v, k) + ) + else: + parameter_values_dict[k] = Serialise.convert_symbol_to_json(v) + + if filename is not None: + with open(filename, "w") as f: + json.dump(parameter_values_dict, f, indent=4) + + @staticmethod + def load_parameters(filename): + """ + Load a JSON file of parameters (either from Serialise.save_parameters + or from a standard pybamm.ParameterValues.save), and return a + pybamm.ParameterValues object. + + - If a value is a dict with a "type" key, deserialize it as a PyBaMM symbol. + - Otherwise (float, int, bool, str, list, dict-without-type), leave it as-is. + """ + with open(filename) as f: + raw_dict = json.load(f) + + deserialized = {} + for key, val in raw_dict.items(): + if isinstance(val, dict) and "type" in val: + deserialized[key] = Serialise.convert_symbol_from_json(val) + + elif isinstance(val, list): + deserialized[key] = val + + elif isinstance(val, (numbers.Number | bool)): + deserialized[key] = val + + elif isinstance(val, str): + deserialized[key] = val + + elif isinstance(val, dict): + deserialized[key] = val + + else: + raise ValueError( + f"Unsupported parameter format for key '{key}': {val!r}" + ) + + return pybamm.ParameterValues(deserialized) + # Helper functions def _get_pybamm_class(self, snippet: dict): @@ -711,6 +778,48 @@ def _convert_options(self, d): else: return d + def convert_function_to_symbolic_expression(func, name=None, param_map=None): + """ + Converts a Python function to a PyBaMM symbolic expression, allowing + mapping of parameter names to actual PyBaMM parameter names. + + Parameters + ---------- + func : callable + The Python function to convert + + name : str, optional + The name of the function to use in the symbolic expression. If not provided, + the function's own name is used. + + param_map : dict, optional + Mapping from argument names in the function to PyBaMM parameter names. + For example: {"T": "Ambient temperature [K]"} + + Returns + ------- + pybamm.Symbol + The PyBaMM symbolic expression + """ + sig = inspect.signature(func) + # func_name = name or func.__name__ + + param_names = list(sig.parameters.keys()) + + # Use mapped names if provided + sym_inputs = [] + for param in param_names: + mapped_name = ( + func.param_map[param] + if param_map and param in func.param_map + else param + ) + sym_inputs.append(pybamm.Parameter(mapped_name)) + + sym_output = func(*sym_inputs) + + return sym_output + @staticmethod def convert_symbol_to_json(symbol): """ @@ -790,6 +899,12 @@ def convert_symbol_to_json(symbol): "interpolator": symbol.interpolator, "entries_string": symbol.entries_string, } + elif isinstance(symbol, pybamm.InputParameter): # <-- ADDED BLOCK + json_dict = { + "type": "InputParameter", + "name": symbol.name, + "domain": symbol.domain, + } elif isinstance(symbol, pybamm.Variable): json_dict = { @@ -911,6 +1026,7 @@ def convert_symbol_from_json(json_data): Scalar(0x21569ea463d7fb2, 42.0, children=[], domains={}) """ + if isinstance(json_data, float | int | bool): return json_data @@ -938,6 +1054,12 @@ def convert_symbol_from_json(json_data): interpolator=json_data["interpolator"], entries_string=json_data["entries_string"], ) + elif symbol_type == "InputParameter": + name = json_data["name"] + domain = json_data.get("domain", {}) + + return pybamm.InputParameter(name, domain=domain) + elif symbol_type == "FunctionParameter": diff_variable = json_data["diff_variable"] if diff_variable is not None: diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 9ddacd8794..7175fe08cd 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -1344,3 +1344,41 @@ def test_plotting_serialised_models(self): sim.plot(show_plot=False) os.remove(f"{name}.json") + + def test_models_and_params(self): + import numpy as np + + import pybamm + from pybamm.expression_tree.operations.serialise import Serialise + + param = pybamm.ParameterValues( + "Marquis2019" + ) # or another chemistry like "Chen2020" + + Serialise.save_parameters(param, filename="my_params.json") + + model = pybamm.lithium_ion.SPM() + Serialise.save_custom_model(model, "my_model.json") + model2 = Serialise.load_custom_model( + "my_model.json", battery_model=pybamm.lithium_ion.BaseModel() + ) + + # Load parameter values + param_dict = Serialise.load_parameters("my_params.json") + param3 = pybamm.ParameterValues(param_dict) + + param3.update({"Electrode height [m]": "[input]"}) + + # Load the model + model2 = Serialise.load_custom_model( + "my_model.json", battery_model=pybamm.lithium_ion.BaseModel() + ) + + # ✅ Process the model AFTER update() so the necessary .input_* keys get created + param3.process_model(model2) + + # Create simulation + t_eval = np.linspace(0, 60, 11) + inputs = {"Electrode height [m]": 0.2} + sim = pybamm.Simulation(model=model2, parameter_values=param3) + sim.solve(t_eval=t_eval, inputs=inputs)