122 changes: 122 additions & 0 deletions src/pybamm/expression_tree/operations/serialise.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
import inspect
import json
import numbers
import re
@@ -547,6 +548,72 @@ def load_custom_model(filename, battery_model=None):

return model

@staticmethod
def save_parameters(parameters: dict, filename=None):
"""
Serializes a dictionary of parameters to a JSON file.
The values can be numbers, PyBaMM symbols, or callables.

Parameters
----------
parameters : dict
A dictionary of parameter names and values.
Values can be numeric, PyBaMM symbols, or callables.

filename : str, optional
If given, saves the serialized parameters to this file.
"""
parameter_values_dict = {}

for k, v in parameters.items():
if callable(v):
parameter_values_dict[k] = Serialise.convert_symbol_to_json(
Serialise.convert_function_to_symbolic_expression(v, k)
)
else:
parameter_values_dict[k] = Serialise.convert_symbol_to_json(v)

if filename is not None:
with open(filename, "w") as f:
json.dump(parameter_values_dict, f, indent=4)
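For reference, a minimal usage sketch (the parameter names, the toy diffusivity function, and the file name are illustrative; it assumes convert_symbol_to_json passes plain numbers through, which is what the test below relies on):

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise

def diffusivity(T):
    # callable values are turned into symbolic expressions before saving
    return 1e-14 * pybamm.exp(-5000 / T)

params = {
    "Ambient temperature [K]": 298.15,
    "Negative particle diffusivity [m2.s-1]": diffusivity,
}
Serialise.save_parameters(params, filename="example_params.json")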

@staticmethod
def load_parameters(filename):
"""
Load a JSON file of parameters (either from Serialise.save_parameters
or from a standard pybamm.ParameterValues.save), and return a
pybamm.ParameterValues object.

- If a value is a dict with a "type" key, deserialize it as a PyBaMM symbol.
- Otherwise (float, int, bool, str, list, dict-without-type), leave it as-is.
"""
with open(filename) as f:
raw_dict = json.load(f)

deserialized = {}
for key, val in raw_dict.items():
if isinstance(val, dict) and "type" in val:
deserialized[key] = Serialise.convert_symbol_from_json(val)

elif isinstance(val, (bool, numbers.Number, str, list, dict)):
    # plain JSON values (and dicts without a "type" key) are kept as-is
    deserialized[key] = val

else:
raise ValueError(
f"Unsupported parameter format for key '{key}': {val!r}"
)

return pybamm.ParameterValues(deserialized)
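The corresponding load step might look like this (same illustrative file as above; note that load_parameters already returns a pybamm.ParameterValues, so no extra wrapping is needed):

parameter_values = Serialise.load_parameters("example_params.json")
# numeric entries round-trip as numbers; the callable comes back as an
# expression tree built on pybamm.Parameter("T")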

# Helper functions

def _get_pybamm_class(self, snippet: dict):
@@ -711,6 +778,48 @@ def _convert_options(self, d):
else:
return d

@staticmethod
def convert_function_to_symbolic_expression(func, name=None, param_map=None):
"""
Converts a Python function to a PyBaMM symbolic expression, allowing
mapping of parameter names to actual PyBaMM parameter names.

Parameters
----------
func : callable
The Python function to convert

name : str, optional
The name of the function to use in the symbolic expression. If not provided,
the function's own name is used.

param_map : dict, optional
Mapping from argument names in the function to PyBaMM parameter names.
For example: {"T": "Ambient temperature [K]"}

Returns
-------
pybamm.Symbol
The PyBaMM symbolic expression
"""
sig = inspect.signature(func)
# note: the "name" argument is not currently used when building the expression

param_names = list(sig.parameters.keys())

# Use mapped names if provided
sym_inputs = []
for param in param_names:
mapped_name = (
    param_map[param] if param_map and param in param_map else param
)
sym_inputs.append(pybamm.Parameter(mapped_name))

sym_output = func(*sym_inputs)

return sym_output
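A short sketch of how param_map is intended to be used (the function and names are made up for illustration):

def conductivity(T):
    return 0.05 * T

expr = Serialise.convert_function_to_symbolic_expression(
    conductivity,
    param_map={"T": "Ambient temperature [K]"},
)
# expr is 0.05 * pybamm.Parameter("Ambient temperature [K]")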

@staticmethod
def convert_symbol_to_json(symbol):
"""
@@ -790,6 +899,12 @@ def convert_symbol_to_json(symbol):
"interpolator": symbol.interpolator,
"entries_string": symbol.entries_string,
}
elif isinstance(symbol, pybamm.InputParameter):
json_dict = {
"type": "InputParameter",
"name": symbol.name,
"domain": symbol.domain,
}

elif isinstance(symbol, pybamm.Variable):
json_dict = {
@@ -911,6 +1026,7 @@ def convert_symbol_from_json(json_data):
Scalar(0x21569ea463d7fb2, 42.0, children=[], domains={})

"""

if isinstance(json_data, float | int | bool):
return json_data

@@ -938,6 +1054,12 @@ def convert_symbol_from_json(json_data):
interpolator=json_data["interpolator"],
entries_string=json_data["entries_string"],
)
elif symbol_type == "InputParameter":
name = json_data["name"]
domain = json_data.get("domain", {})

return pybamm.InputParameter(name, domain=domain)

elif symbol_type == "FunctionParameter":
diff_variable = json_data["diff_variable"]
if diff_variable is not None:
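As a quick illustration of the new InputParameter support, a round trip through the two helpers would look roughly like this (the parameter name is chosen to match the test below):

h = pybamm.InputParameter("Electrode height [m]")
json_dict = Serialise.convert_symbol_to_json(h)
# {"type": "InputParameter", "name": "Electrode height [m]", "domain": []}
restored = Serialise.convert_symbol_from_json(json_dict)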
38 changes: 38 additions & 0 deletions tests/unit/test_serialisation/test_serialisation.py
@@ -1344,3 +1344,41 @@ def test_plotting_serialised_models(self):
sim.plot(show_plot=False)

os.remove(f"{name}.json")

def test_models_and_params(self):
import numpy as np

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise

param = pybamm.ParameterValues(
"Marquis2019"
) # or another chemistry like "Chen2020"

Serialise.save_parameters(param, filename="my_params.json")

model = pybamm.lithium_ion.SPM()
Serialise.save_custom_model(model, "my_model.json")
model2 = Serialise.load_custom_model(
"my_model.json", battery_model=pybamm.lithium_ion.BaseModel()
)

# Load parameter values (load_parameters already returns a ParameterValues object)
param3 = Serialise.load_parameters("my_params.json")

param3.update({"Electrode height [m]": "[input]"})

# Process the model after update() so the "[input]" placeholder is applied
Member
@medha-14 Could you clean up the comments and refine the test functions more? It seems like this was generated through an LLM; I'm not against that, I just want to be sure that we keep the code readable and that it works as intended if you haven't tested it. I think you can add a short docstring above each test function instead of comments. Also, if you see any test functions that are irrelevant or not needed, feel free to remove them.

Member
I have the same suggestion here. I am also not against the use of AI models to generate code, but you must understand what the generated code does, both at a microscopic level (say, what's happening in a specific test) and at a macroscopic one (what feature is being implemented, does the API match what we need, are we reusing functions we've previously defined, is the AI system being used confabulating code, etc.).

param3.process_model(model2)

# Create simulation
t_eval = np.linspace(0, 60, 11)
inputs = {"Electrode height [m]": 0.2}
sim = pybamm.Simulation(model=model2, parameter_values=param3)
sim.solve(t_eval=t_eval, inputs=inputs)