diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index e269831966..b50e0dff7c 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -58,6 +58,7 @@ from .expression_tree.operations.jacobian import Jacobian from .expression_tree.operations.convert_to_casadi import CasadiConverter from .expression_tree.operations.unpack_symbols import SymbolUnpacker +from .expression_tree.operations.serialise import Serialise # Model classes from .models.base_model import BaseModel diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 4cf5b29c7a..e5875c969e 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -1,14 +1,19 @@ from __future__ import annotations import importlib +import inspect import json +import numbers import re from datetime import datetime +from enum import Enum import numpy as np import pybamm +SUPPORTED_SCHEMA_VERSION = "1.0" + class Serialise: """ @@ -237,6 +242,378 @@ def load_model( """ ) + @staticmethod + def _json_encoder(obj): + if isinstance(obj, Enum): + return obj.name + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.integer): + return int(obj) + else: + raise TypeError(f"Object of type {type(obj)} is not JSON serializable.") + + @staticmethod + def save_custom_model(model, filename=None): + """ + Saves a custom (non-discretised) PyBaMM model to a JSON file. Works for user-defined models that are subclasses of BaseModel. + + This includes symbolic expressions for rhs, algebraic, initial and boundary + conditions, events, and variables. Useful for storing or sharing models + before discretisation. + + Parameters + ---------- + model : :class:`pybamm.BaseModel` + The custom symbolic model to be saved. + filename : str, optional + The desired name of the JSON file. If not provided, a name will be + generated from the model name and current datetime.
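(Reviewer note, not part of the patch.) The _json_encoder above is only invoked by json.dump for objects the standard encoder rejects; a minimal sketch of that fallback behaviour, using the patch's own class:

import json
import numpy as np
from pybamm.expression_tree.operations.serialise import Serialise

payload = {"bounds": np.array([0.0, 1.0]), "count": np.int64(3)}
# json.dumps only calls the `default` hook for types it cannot encode natively,
# so ndarrays become lists and numpy scalars become plain Python floats/ints.
print(json.dumps(payload, default=Serialise._json_encoder))
# -> {"bounds": [0.0, 1.0], "count": 3}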
+ + Example + ------- + >>> import pybamm + >>> model = pybamm.lithium_ion.BasicDFN() + >>> from pybamm.expression_tree.operations.serialise import Serialise + >>> Serialise.save_custom_model(model, "basicdfn_model") + + """ + required_attrs = [ + "rhs", + "algebraic", + "initial_conditions", + "boundary_conditions", + "events", + "variables", + ] + missing = [attr for attr in required_attrs if not hasattr(model, attr)] + if missing: + raise AttributeError(f"Model is missing required sections: {missing}") + + try: + SCHEMA_VERSION = "1.0" + model_json = { + "pybamm_version": pybamm.__version__, + "schema_version": SCHEMA_VERSION, + "name": getattr(model, "name", "unnamed_model"), + "options": getattr(model, "options", {}), + "rhs": [ + ( + Serialise.convert_symbol_to_json(variable), + Serialise.convert_symbol_to_json(rhs_expression), + ) + for variable, rhs_expression in getattr(model, "rhs", {}).items() + ], + "algebraic": [ + ( + Serialise.convert_symbol_to_json(variable), + Serialise.convert_symbol_to_json(algebraic_expression), + ) + for variable, algebraic_expression in getattr( + model, "algebraic", {} + ).items() + ], + "initial_conditions": [ + ( + Serialise.convert_symbol_to_json(variable), + Serialise.convert_symbol_to_json(initial_value), + ) + for variable, initial_value in getattr( + model, "initial_conditions", {} + ).items() + ], + "boundary_conditions": [ + ( + Serialise.convert_symbol_to_json(variable), + { + side: [ + Serialise.convert_symbol_to_json(expression), + boundary_type, + ] + for side, (expression, boundary_type) in conditions.items() + }, + ) + for variable, conditions in getattr( + model, "boundary_conditions", {} + ).items() + ], + "events": [ + { + "name": event.name, + "expression": Serialise.convert_symbol_to_json( + event.expression + ), + "event_type": event.event_type, + } + for event in getattr(model, "events", []) + ], + "variables": { + str(variable_name): Serialise.convert_symbol_to_json(expression) + for variable_name, expression in getattr( + model, "variables", {} + ).items() + }, + } + + if filename is None: + safe_name = re.sub(r"[^\w\-_.]", "_", model.name or "unnamed_model") + timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + filename = f"{safe_name}_{timestamp}.json" + else: + # Clean provided filename and ensure .json + base = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", filename) + filename = base if base.endswith(".json") else f"{base}.json" + + try: + with open(filename, "w") as f: + json.dump(model_json, f, indent=2, default=Serialise._json_encoder) + except OSError as file_err: + raise OSError( + f"Failed to write model JSON to file '{filename}': {file_err}" + ) from file_err + + except Exception as e: + raise ValueError(f"Failed to save custom model: {e}") from e + + @staticmethod + def _create_symbol_key(symbol_json: dict) -> str: + """ + Given the JSON‐dict for a symbol, return a unique, hashable key. + We just sort the dict keys and dump to a string. + """ + return json.dumps(symbol_json, sort_keys=True) + + @staticmethod + def load_custom_model(filename, battery_model=None): + """ + Loads a custom (symbolic) PyBaMM model from a JSON file. + + Reconstructs a model saved using `save_custom_model`, including its rhs, + algebraic equations, initial and boundary conditions, events, and variables. + Returns a fully symbolic model ready for further processing or discretisation. + + Parameters + ---------- + filename : str + Path to the JSON file containing the saved model. 
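(Reviewer note, not part of the patch.) _create_symbol_key canonicalises the symbol JSON with sorted keys, which is what lets two structurally identical variables collapse to a single dictionary entry; a small illustrative sketch:

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise

v1 = pybamm.Variable("c", domain="negative electrode")
v2 = pybamm.Variable("c", domain="negative electrode")
# Distinct Variable objects serialise to identical JSON, and dumping with
# sort_keys=True turns that JSON into a stable, hashable string key.
key1 = Serialise._create_symbol_key(Serialise.convert_symbol_to_json(v1))
key2 = Serialise._create_symbol_key(Serialise.convert_symbol_to_json(v2))
assert key1 == key2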
+ battery_model : :class:`pybamm.BaseModel`, optional + An optional existing model instance to populate. If not provided, a new + :class:`pybamm.BaseModel` is created. + + Returns + ------- + :class:`pybamm.BaseModel` or subclass + The reconstructed symbolic PyBaMM model. + + Example + ------- + >>> import pybamm + >>> model = pybamm.lithium_ion.BasicDFN() + >>> from pybamm.expression_tree.operations.serialise import Serialise + >>> Serialise.save_custom_model(model, "basicdfn_model") + >>> loaded_model = Serialise.load_custom_model("basicdfn_model.json", battery_model=pybamm.lithium_ion.BaseModel()) + + """ + try: + with open(filename) as file: + model_data = json.load(file) + except FileNotFoundError as err: + raise FileNotFoundError(f"Could not find file: {filename}") from err + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in file '{filename}': {e!s}") from e + + # check missing secitons in json + required = [ + "name", + "rhs", + "initial_conditions", + "algebraic", + "boundary_conditions", + "events", + "variables", + ] + missing = [k for k in required if k not in model_data] + if missing: + raise KeyError(f"Missing required model sections: {missing}") + + schema_version = model_data.get("schema_version", SUPPORTED_SCHEMA_VERSION) + if schema_version != SUPPORTED_SCHEMA_VERSION: + raise ValueError( + f"Unsupported schema version: {schema_version}. " + f"Expected: {SUPPORTED_SCHEMA_VERSION}" + ) + + model = battery_model if battery_model is not None else pybamm.BaseModel() + + model.name = model_data["name"] + + model.schema_version = schema_version + + all_variable_keys = ( + [lhs_json for lhs_json, _ in model_data["rhs"]] + + [lhs_json for lhs_json, _ in model_data["initial_conditions"]] + + [lhs_json for lhs_json, _ in model_data["algebraic"]] + + [variable_json for variable_json, _ in model_data["boundary_conditions"]] + ) + + symbol_map = {} + for variable_json in all_variable_keys: + try: + symbol = Serialise.convert_symbol_from_json(variable_json) + key = Serialise._create_symbol_key(variable_json) + symbol_map[key] = symbol + except Exception as e: + raise ValueError( + f"Failed to process symbol key for variable {variable_json}: {e!s}" + ) from e + + model.rhs = {} + for lhs_json, rhs_expr_json in model_data["rhs"]: + try: + lhs = symbol_map[Serialise._create_symbol_key(lhs_json)] + rhs = Serialise.convert_symbol_from_json(rhs_expr_json) + model.rhs[lhs] = rhs + except Exception as e: + raise ValueError( + f"Failed to convert rhs entry for {lhs_json}: {e!s}" + ) from e + + model.algebraic = {} + for lhs_json, algebraic_expr_json in model_data["algebraic"]: + try: + lhs = symbol_map[Serialise._create_symbol_key(lhs_json)] + rhs = Serialise.convert_symbol_from_json(algebraic_expr_json) + model.algebraic[lhs] = rhs + except Exception as e: + raise ValueError( + f"Failed to convert algebraic entry for {lhs_json}: {e!s}" + ) from e + + model.initial_conditions = {} + for lhs_json, initial_value_json in model_data["initial_conditions"]: + try: + lhs = symbol_map[Serialise._create_symbol_key(lhs_json)] + rhs = Serialise.convert_symbol_from_json(initial_value_json) + model.initial_conditions[lhs] = rhs + except Exception as e: + raise ValueError( + f"Failed to convert initial condition entry for {lhs_json}: {e!s}" + ) from e + + model.boundary_conditions = {} + for variable_json, condition_dict in model_data["boundary_conditions"]: + try: + variable = symbol_map[Serialise._create_symbol_key(variable_json)] + sides = {} + for side, (expression_json, 
boundary_type) in condition_dict.items(): + try: + expr = Serialise.convert_symbol_from_json(expression_json) + sides[side] = (expr, boundary_type) + except Exception as e: + raise ValueError( + f"Failed to convert boundary expression for variable {variable_json} on side '{side}': {e!s}" + ) from e + model.boundary_conditions[variable] = sides + except Exception as e: + raise ValueError( + f"Failed to convert boundary condition entry for variable {variable_json}: {e!s}" + ) from e + + model.events = [] + for event_data in model_data["events"]: + try: + name = event_data["name"] + expr = Serialise.convert_symbol_from_json(event_data["expression"]) + event_type = event_data["event_type"] + model.events.append(pybamm.Event(name, expr, event_type)) + except Exception as e: + raise ValueError( + f"Failed to convert event '{event_data.get('name', 'UNKNOWN')}': {e!s}" + ) from e + + model.variables = {} + for variable_name, expression_json in model_data["variables"].items(): + try: + key = Serialise._create_symbol_key(expression_json) + symbol = symbol_map.get(key) + if symbol is None: + symbol = Serialise.convert_symbol_from_json(expression_json) + model.variables[variable_name] = symbol + except Exception as e: + raise ValueError( + f"Failed to convert variable '{variable_name}': {e!s}" + ) from e + + return model + + @staticmethod + def save_parameters(parameters: dict, filename=None): + """ + Serializes a dictionary of parameters to a JSON file. + The values can be numbers, PyBaMM symbols, or callables. + + Parameters + ---------- + parameters : dict + A dictionary of parameter names and values. + Values can be numeric, PyBaMM symbols, or callables. + + filename : str, optional + If given, saves the serialized parameters to this file. + """ + parameter_values_dict = {} + + for k, v in parameters.items(): + if callable(v): + parameter_values_dict[k] = Serialise.convert_symbol_to_json( + Serialise.convert_function_to_symbolic_expression(v, k) + ) + else: + parameter_values_dict[k] = Serialise.convert_symbol_to_json(v) + + if filename is not None: + with open(filename, "w") as f: + json.dump(parameter_values_dict, f, indent=4) + + @staticmethod + def load_parameters(filename): + """ + Load a JSON file of parameters (either from Serialise.save_parameters + or from a standard pybamm.ParameterValues.save), and return a + pybamm.ParameterValues object. + + - If a value is a dict with a "type" key, deserialize it as a PyBaMM symbol. + - Otherwise (float, int, bool, str, list, dict-without-type), leave it as-is. 
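(Reviewer note, not part of the patch.) A short round-trip sketch for save_parameters / load_parameters; the parameter names and the functional form below are illustrative only, and the last step assumes pybamm.ParameterValues accepts symbolic values, as the round-trip test at the end of this diff relies on:

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise

def electrolyte_conductivity(T):
    # hypothetical linear dependence, used only to exercise the callable branch
    return 0.05 + 1e-4 * T

params = {
    "Ambient temperature [K]": 298.15,
    "Electrolyte conductivity [S.m-1]": electrolyte_conductivity,
}
Serialise.save_parameters(params, filename="example_params.json")
# Plain numbers come back unchanged; the callable comes back as a symbolic
# expression in pybamm.Parameter("T"), wrapped in a ParameterValues object.
values = Serialise.load_parameters("example_params.json")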
+ """ + with open(filename) as f: + raw_dict = json.load(f) + + deserialized = {} + for key, val in raw_dict.items(): + if isinstance(val, dict) and "type" in val: + deserialized[key] = Serialise.convert_symbol_from_json(val) + + elif isinstance(val, list): + deserialized[key] = val + + elif isinstance(val, (numbers.Number | bool)): + deserialized[key] = val + + elif isinstance(val, str): + deserialized[key] = val + + elif isinstance(val, dict): + deserialized[key] = val + + else: + raise ValueError( + f"Unsupported parameter format for key '{key}': {val!r}" + ) + + return pybamm.ParameterValues(deserialized) + # Helper functions def _get_pybamm_class(self, snippet: dict): @@ -360,6 +737,7 @@ def _reconstruct_pybamm_dict(self, obj: dict): {"rod": {SpatialVariable(name="spat_var"): {"min":0.0, "max":2.0} } } + """ def recurse(obj): @@ -399,3 +777,354 @@ def _convert_options(self, d): return tuple(self._convert_options(item) for item in d) else: return d + + def convert_function_to_symbolic_expression(func, name=None, param_map=None): + """ + Converts a Python function to a PyBaMM symbolic expression, allowing + mapping of parameter names to actual PyBaMM parameter names. + + Parameters + ---------- + func : callable + The Python function to convert + + name : str, optional + The name of the function to use in the symbolic expression. If not provided, + the function's own name is used. + + param_map : dict, optional + Mapping from argument names in the function to PyBaMM parameter names. + For example: {"T": "Ambient temperature [K]"} + + Returns + ------- + pybamm.Symbol + The PyBaMM symbolic expression + """ + sig = inspect.signature(func) + # func_name = name or func.__name__ + + param_names = list(sig.parameters.keys()) + + # Use mapped names if provided + sym_inputs = [] + for param in param_names: + mapped_name = ( + func.param_map[param] + if param_map and param in func.param_map + else param + ) + sym_inputs.append(pybamm.Parameter(mapped_name)) + + sym_output = func(*sym_inputs) + + return sym_output + + @staticmethod + def convert_symbol_to_json(symbol): + """ + Recursively converts a PyBaMM symbolic expression into a JSON-serializable format. + + Supports most PyBaMM symbol types, including scalars, variables, parameters, + operators, broadcasts, and interpolants. + + Parameters + ---------- + symbol : pybamm.Symbol or compatible type + The expression or object to convert. + + Returns + ------- + dict + A JSON-compatible representation of the input. 
+ + Examples + -------- + >>> import pybamm + >>> from pybamm.expression_tree.operations.serialise import Serialise + >>> s = pybamm.Scalar(5) + >>> Serialise.convert_symbol_to_json(s) + {'type': 'Scalar', 'value': 5.0} + + >>> v = pybamm.Variable("c") + >>> Serialise.convert_symbol_to_json(v) + {'type': 'Variable', 'name': 'c', 'domains': {'primary': [], 'secondary': [], 'tertiary': [], 'quaternary': []}, 'bounds': [{'type': 'Scalar', 'value': -inf}, {'type': 'Scalar', 'value': inf}]} + """ + + if isinstance(symbol, numbers.Number | list): + return symbol + elif isinstance(symbol, pybamm.Time): + return {"type": "Time"} + + elif isinstance(symbol, pybamm.Parameter): + return {"type": "Parameter", "name": symbol.name} + + elif isinstance(symbol, pybamm.Scalar): + return {"type": "Scalar", "value": symbol.value} + + elif isinstance(symbol, pybamm.PrimaryBroadcast): + json_dict = { + "type": "PrimaryBroadcast", + "children": [Serialise.convert_symbol_to_json(symbol.child)], + "broadcast_domain": symbol.broadcast_domain, + } + elif isinstance(symbol, pybamm.FunctionParameter): + input_names = symbol.input_names + inputs = { + input_names[i]: Serialise.convert_symbol_to_json(symbol.orphans[i]) + for i in range(len(input_names)) + } + dv = symbol.diff_variable + if dv is not None: + dv_json = Serialise.convert_symbol_to_json(dv) + else: + dv_json = None + json_dict = { + "type": symbol.__class__.__name__, + "inputs": inputs, + "diff_variable": dv_json, + "name": symbol.name, + "domains": symbol.domains, + } + + elif isinstance(symbol, pybamm.Interpolant): + json_dict = { + "type": symbol.__class__.__name__, + "x": [x.tolist() for x in symbol.x], + "y": symbol.y.tolist(), + "children": [ + Serialise.convert_symbol_to_json(c) for c in symbol.children + ], + "name": symbol.name, + "interpolator": symbol.interpolator, + "entries_string": symbol.entries_string, + } + elif isinstance(symbol, pybamm.InputParameter): # <-- ADDED BLOCK + json_dict = { + "type": "InputParameter", + "name": symbol.name, + "domain": symbol.domain, + } + + elif isinstance(symbol, pybamm.Variable): + json_dict = { + "type": "Variable", + "name": symbol.name, + "domains": symbol.domains, + "bounds": [ + Serialise.convert_symbol_to_json(symbol.bounds[0]), + Serialise.convert_symbol_to_json(symbol.bounds[1]), + ], + } + elif isinstance(symbol, pybamm.ConcatenationVariable): + json_dict = { + "type": "ConcatenationVariable", + "name": symbol.name, + "children": [ + Serialise.convert_symbol_to_json(child) for child in symbol.children + ], + } + elif isinstance(symbol, pybamm.FullBroadcast): + json_dict = { + "type": "FullBroadcast", + "children": [Serialise.convert_symbol_to_json(symbol.child)], + "domains": symbol.domains, + } + elif isinstance(symbol, pybamm.SecondaryBroadcast): + json_dict = { + "type": "SecondaryBroadcast", + "children": [Serialise.convert_symbol_to_json(symbol.child)], + "broadcast_domain": symbol.broadcast_domain, + } + elif isinstance(symbol, pybamm.SpatialVariable): + json_dict = { + "type": "SpatialVariable", + "name": symbol.name, + "domains": symbol.domains, + "coord_sys": symbol.coord_sys, + } + elif isinstance(symbol, pybamm.IndefiniteIntegral): + integration_var = ( + symbol.integration_variable[0] + if isinstance(symbol.integration_variable, list) + else symbol.integration_variable + ) + json_dict = { + "type": "IndefiniteIntegral", + "children": [Serialise.convert_symbol_to_json(symbol.child)], + "integration_variable": Serialise.convert_symbol_to_json( + integration_var + ), + } + elif 
isinstance(symbol, pybamm.BoundaryValue): + json_dict = { + "type": "BoundaryValue", + "side": symbol.side, + "children": [Serialise.convert_symbol_to_json(symbol.orphans[0])], + } + elif isinstance(symbol, pybamm.SpecificFunction): + if symbol.__class__ == pybamm.SpecificFunction: + raise NotImplementedError("SpecificFunction is not supported directly") + json_dict = { + "type": symbol.__class__.__name__, + "children": [ + Serialise.convert_symbol_to_json(c) for c in symbol.children + ], + } + + elif isinstance(symbol, pybamm.UnaryOperator | pybamm.BinaryOperator): + json_dict = { + "type": symbol.__class__.__name__, + "children": [ + Serialise.convert_symbol_to_json(c) for c in symbol.children + ], + } + + elif isinstance(symbol, pybamm.Symbol): + # Generic fallback for other symbols with children + json_dict = { + "type": symbol.__class__.__name__, + "domains": symbol.domains, + "children": [ + Serialise.convert_symbol_to_json(c) for c in symbol.children + ], + } + if hasattr(symbol, "name"): + json_dict["name"] = symbol.name + + else: + raise ValueError( + f"Error processing '{symbol.name}'. Unknown symbol type: {type(symbol)}" + ) + return json_dict + + @staticmethod + def convert_symbol_from_json(json_data): + """ + Recursively reconstructs a PyBaMM symbolic expression from a JSON dictionary. + + Supports all major PyBaMM symbol types, including :class:`pybamm.Scalar`, :class:`pybamm.Variable`, :class:`pybamm.Parameter`, :class:`pybamm.Operator`, :class:`pybamm.FunctionParameter`, :class:`pybamm.Broadcast`, and :class:`pybamm.Interpolant`. + + + Parameters + ---------- + json_data : dict + A JSON-serialized representation of a PyBaMM expression, produced + by `Serialise.convert_symbol_to_json`. + + Returns + ------- + pybamm.Symbol or primitive + The reconstructed PyBaMM symbolic expression or a primitive (float, int, bool). 
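(Reviewer note, not part of the patch.) A quick round-trip sketch through the two converters, restricted to symbol types handled above:

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise

c = pybamm.Variable("c", domain="negative electrode")
expr = pybamm.grad(c) * pybamm.Parameter("D") + pybamm.t
as_json = Serialise.convert_symbol_to_json(expr)
rebuilt = Serialise.convert_symbol_from_json(as_json)
# Operators are rebuilt by class name, so the tree shape survives the round trip.
assert isinstance(rebuilt, pybamm.Addition)
assert isinstance(rebuilt.children[0], pybamm.Multiplication)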
+ + Examples + -------- + >>> import pybamm + >>> from pybamm.expression_tree.operations.serialise import Serialise + >>> json_expr = {'type': 'Scalar', 'value': 42} + >>> Serialise.convert_symbol_from_json(json_expr) # doctest: +SKIP + Scalar(0x21569ea463d7fb2, 42.0, children=[], domains={}) + + """ + + if isinstance(json_data, float | int | bool): + return json_data + + if isinstance(json_data, str): + raise ValueError(f"Unexpected raw string in JSON: {json_data}") + + if json_data is None: + return None + + symbol_type = json_data.get("type") + + if symbol_type == "Parameter": + return pybamm.Parameter( + json_data["name"], + ) + elif symbol_type == "Scalar": + return pybamm.Scalar(json_data["value"]) + + elif symbol_type == "Interpolant": + return pybamm.Interpolant( + [np.array(x) for x in json_data["x"]], + np.array(json_data["y"]), + [Serialise.convert_symbol_from_json(c) for c in json_data["children"]], + name=json_data["name"], + interpolator=json_data["interpolator"], + entries_string=json_data["entries_string"], + ) + elif symbol_type == "InputParameter": + name = json_data["name"] + domain = json_data.get("domain", {}) + + return pybamm.InputParameter(name, domain=domain) + + elif symbol_type == "FunctionParameter": + diff_variable = json_data["diff_variable"] + if diff_variable is not None: + diff_variable = Serialise.convert_symbol_from_json(diff_variable) + return pybamm.FunctionParameter( + json_data["name"], + { + k: Serialise.convert_symbol_from_json(v) + for k, v in json_data["inputs"].items() + }, + diff_variable=diff_variable, + ) + elif symbol_type == "PrimaryBroadcast": + child = Serialise.convert_symbol_from_json(json_data["children"][0]) + domain = json_data["broadcast_domain"] + return pybamm.PrimaryBroadcast(child, domain) + elif symbol_type == "FullBroadcast": + child = Serialise.convert_symbol_from_json(json_data["children"][0]) + domains = json_data["domains"] + return pybamm.FullBroadcast(child, "broadcast", domains) + elif symbol_type == "SecondaryBroadcast": + child = Serialise.convert_symbol_from_json(json_data["children"][0]) + domain = json_data["broadcast_domain"] + return pybamm.SecondaryBroadcast(child, domain) + elif symbol_type == "BoundaryValue": + child = Serialise.convert_symbol_from_json(json_data["children"][0]) + side = json_data["side"] + return pybamm.BoundaryValue(child, side) + elif symbol_type == "Time": + return pybamm.t + elif symbol_type == "Variable": + bounds = tuple( + Serialise.convert_symbol_from_json(b) + for b in json_data.get("bounds", [-float("inf"), float("inf")]) + ) + return pybamm.Variable( + json_data["name"], + domains=json_data["domains"], + bounds=bounds, + ) + elif symbol_type == "SpatialVariable": + return pybamm.SpatialVariable( + json_data["name"], + coord_sys=json_data.get("coord_sys", "cartesian"), + domains=json_data.get("domains"), + ) + elif symbol_type == "IndefiniteIntegral": + child = Serialise.convert_symbol_from_json(json_data["children"][0]) + integration_var_json = json_data["integration_variable"] + integration_variable = Serialise.convert_symbol_from_json( + integration_var_json + ) + if not isinstance(integration_variable, pybamm.SpatialVariable): + raise TypeError( + f"Expected SpatialVariable, got {type(integration_variable)}" + ) + return pybamm.IndefiniteIntegral(child, [integration_variable]) + elif symbol_type == "Symbol": + return pybamm.Symbol( + json_data["name"], + domains=json_data.get("domains", {}), + ) + elif "children" in json_data: + return getattr(pybamm, symbol_type)( + 
*[Serialise.convert_symbol_from_json(c) for c in json_data["children"]] + ) + else: + raise ValueError(f"Unhandled symbol type or malformed entry: {json_data}") diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 6f965c4bf6..7175fe08cd 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -4,14 +4,21 @@ import json import os +import re from datetime import datetime +from unittest.mock import mock_open, patch import numpy as np import pytest from numpy import testing import pybamm -from pybamm.expression_tree.operations.serialise import Serialise +from pybamm.expression_tree.operations.serialise import ( + SUPPORTED_SCHEMA_VERSION, + Serialise, +) +from pybamm.models.full_battery_models.lithium_ion.basic_dfn import BasicDFN +from pybamm.models.full_battery_models.lithium_ion.basic_spm import BasicSPM def scalar_var_dict(mocker): @@ -591,3 +598,787 @@ def test_serialised_model_plotting(self): # check dynamic plot loads new_solution.plot(show_plot=False) + + # testing custom models serilaisation and deserialisation + def test_serialise_scalar(self): + S = pybamm.Scalar(2.718) + j = Serialise.convert_symbol_to_json(S) + S2 = Serialise.convert_symbol_from_json(j) + assert isinstance(S2, pybamm.Scalar) + assert S2.value == pytest.approx(2.718) + + def test_serialise_time(self): + t = pybamm.Time() + j = Serialise.convert_symbol_to_json(t) + t2 = Serialise.convert_symbol_from_json(j) + assert isinstance(t2, pybamm.Time) + + def test_convert_symbol_to_json_with_number_and_list(self): + for val in (0, 3.14, -7, True): + out = Serialise.convert_symbol_to_json(val) + assert out is val or out == val + + sample = [1, 2, 3, "foo", 4.5] + out = Serialise.convert_symbol_to_json(sample) + assert out is sample + + def test_convert_symbol_from_json_with_primitives(self): + assert Serialise.convert_symbol_from_json(3.14) == 3.14 + assert Serialise.convert_symbol_from_json(42) == 42 + assert Serialise.convert_symbol_from_json(True) is True + + def test_convert_symbol_from_json_with_none(self): + assert Serialise.convert_symbol_from_json(None) is None + + def test_convert_symbol_from_json_unexpected_string(self): + with pytest.raises(ValueError, match=r"Unexpected raw string in JSON: foo"): + Serialise.convert_symbol_from_json("foo") + + def test_numpy_array_conversion(self): + arr = np.array([1, 2, 3]) + assert Serialise._json_encoder(arr) == [1, 2, 3] + + def test_numpy_float_conversion(self): + val1 = np.float32(2.71) + result1 = Serialise._json_encoder(val1) + assert result1 == float(val1) + assert isinstance(result1, float) + + val2 = np.float64(3.14) + result2 = Serialise._json_encoder(val2) + assert result2 == float(val2) + assert isinstance(result2, float) + + def test_numpy_int_conversion(self): + val1 = np.int32(42) + result1 = Serialise._json_encoder(val1) + assert result1 == int(val1) + assert isinstance(result1, int) + + val2 = np.int64(123) + result2 = Serialise._json_encoder(val2) + assert result2 == int(val2) + assert isinstance(result2, int) + + def test_unsupported_type_raises(self): + class Dummy: + pass + + with pytest.raises(TypeError, match="is not JSON serializable"): + Serialise._json_encoder(Dummy()) + + def test_create_symbol_key(self): + var1 = pybamm.Variable("x", bounds=(0, 1)) + var2 = pybamm.Variable("x", bounds=(0, 2)) + + json1 = Serialise.convert_symbol_to_json(var1) + json2 = Serialise.convert_symbol_to_json(var2) + + key1 = 
Serialise._create_symbol_key(json1) + key2 = Serialise._create_symbol_key(json2) + + assert isinstance(key1, str) + assert isinstance(key2, str) + assert key1 != key2 + + def test_primary_broadcast_serialisation(self): + child = pybamm.Scalar(42) + symbol = pybamm.PrimaryBroadcast(child, "negative electrode") + json_dict = Serialise.convert_symbol_to_json(symbol) + symbol2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(symbol2, pybamm.PrimaryBroadcast) + assert symbol2.broadcast_domain == ["negative electrode"] + assert isinstance(symbol2.orphans[0], pybamm.Scalar) + assert symbol2.orphans[0].value == 42 + + def test_interpolant_serialisation(self): + x = np.linspace(0, 1, 5) + y = np.array([0, 1, 4, 9, 16]) + child = pybamm.Variable("z") + interp = pybamm.Interpolant( + x, y, child, name="test_interplot", interpolator="linear" + ) + json_dict = Serialise.convert_symbol_to_json(interp) + interp2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(interp2, pybamm.Interpolant) + assert interp2.name == "test_interplot" + assert interp2.interpolator == "linear" + assert isinstance(interp2.x[0], np.ndarray) + assert isinstance(interp2.y, np.ndarray) + assert interp2.children[0].name == "z" + + def test_variable_serialisation(self): + var = pybamm.Variable("var", domain="separator") + json_dict = Serialise.convert_symbol_to_json(var) + var2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(var2, pybamm.Variable) + assert var2.name == "var" + assert var2.domains["primary"] == ["separator"] + assert var2.bounds[0].value == -float("inf") + assert var2.bounds[1].value == float("inf") + + def test_concatenation_variable_serialisation(self): + var1 = pybamm.Variable("a", domain="negative electrode") + var2 = pybamm.Variable("a", domain="separator") + var3 = pybamm.Variable("a", domain="positive electrode") + concat_var = pybamm.ConcatenationVariable(var1, var2, var3, name="conc_var") + json_dict = Serialise.convert_symbol_to_json(concat_var) + concat_var2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(concat_var2, pybamm.ConcatenationVariable) + assert concat_var2.name == "a" + assert len(concat_var2.children) == 3 + domains = [child.domains["primary"] for child in concat_var2.children] + assert domains == [ + ["negative electrode"], + ["separator"], + ["positive electrode"], + ] + + def test_full_broadcast_serialisation(self): + child = pybamm.Scalar(5) + fb = pybamm.FullBroadcast( + child, + "negative electrode", + {"primary": ["negative electrode"], "secondary": ["current collector"]}, + ) + json_dict = Serialise.convert_symbol_to_json(fb) + fb2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(fb2, pybamm.FullBroadcast) + assert fb2.broadcast_domain == ["negative electrode"] + assert fb2.domains["primary"] == ["negative electrode"] + assert fb2.domains["secondary"] == ["current collector"] + assert isinstance(fb2.child, pybamm.Scalar) + assert fb2.child.value == 5 + + def test_secondary_broadcast_serialisation(self): + child = pybamm.Variable("c", domain="negative electrode") + sb = pybamm.SecondaryBroadcast(child, "current collector") + + json_dict = Serialise.convert_symbol_to_json(sb) + sb2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(sb2, pybamm.SecondaryBroadcast) + assert sb2.broadcast_domain == ["current collector"] + assert sb2.child.name == "c" + assert sb2.child.domain == ["negative electrode"] + + def test_spatial_variable_serialisation(self): + sv = 
pybamm.SpatialVariable( + "x", domain="negative electrode", coord_sys="cartesian" + ) + json_dict = Serialise.convert_symbol_to_json(sv) + sv2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(sv2, pybamm.SpatialVariable) + assert sv2.name == "x" + assert sv2.domains["primary"] == ["negative electrode"] + assert sv2.coord_sys == "cartesian" + + def test_boundary_value_serialisation(self): + var = pybamm.SpatialVariable("x", domain="electrode") + bv = pybamm.BoundaryValue(var, "left") + json_dict = Serialise.convert_symbol_to_json(bv) + bv2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(bv2, pybamm.BoundaryValue) + assert bv2.side == "left" + assert isinstance(bv2.orphans[0], pybamm.SpatialVariable) + assert bv2.orphans[0].name == "x" + + def test_specific_function_not_supported(self): + def dummy_func(x): + return x + + symbol = pybamm.SpecificFunction(dummy_func, pybamm.Scalar(1)) + with pytest.raises( + NotImplementedError, match="SpecificFunction is not supported directly" + ): + Serialise.convert_symbol_to_json(symbol) + + def test_unary_operator_serialisation(self): + expr = pybamm.Negate(pybamm.Scalar(5)) + json_dict = Serialise.convert_symbol_to_json(expr) + expr2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(expr2, pybamm.Negate) + assert isinstance(expr2.child, pybamm.Scalar) + assert expr2.child.value == 5 + + def test_binary_operator_serialisation(self): + expr = pybamm.Addition(pybamm.Scalar(2), pybamm.Scalar(3)) + json_dict = Serialise.convert_symbol_to_json(expr) + expr2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(expr2, pybamm.Addition) + values = [c.value for c in expr2.children] + assert values == [2, 3] + + def test_function_parameter_with_diff_variable_serialisation(self): + x = pybamm.Variable("x") + diff_var = pybamm.Variable("r") + func_param = pybamm.FunctionParameter("my_func", {"x": x}, diff_var) + + json_dict = Serialise.convert_symbol_to_json(func_param) + assert "diff_variable" in json_dict + assert json_dict["diff_variable"]["type"] == "Variable" + assert json_dict["diff_variable"]["name"] == "r" + + expr2 = Serialise.convert_symbol_from_json(json_dict) + assert isinstance(expr2, pybamm.FunctionParameter) + assert expr2.diff_variable.name == "r" + assert expr2.name == "my_func" + assert list(expr2.input_names) == ["x"] + + def test_indefinite_integral_serialisation(self): + x = pybamm.SpatialVariable("x", domain="negative electrode") + ind_int = pybamm.IndefiniteIntegral(x, x) + + json_dict = Serialise.convert_symbol_to_json(ind_int) + assert json_dict["type"] == "IndefiniteIntegral" + + assert ( + isinstance(json_dict["children"], list) and len(json_dict["children"]) == 1 + ) + child_json = json_dict["children"][0] + assert child_json["type"] == "SpatialVariable" + assert child_json["name"] == "x" + + int_var_json = json_dict["integration_variable"] + assert int_var_json["type"] == "SpatialVariable" + assert int_var_json["name"] == "x" + + expr2 = Serialise.convert_symbol_from_json(json_dict) + assert isinstance(expr2, pybamm.IndefiniteIntegral) + assert isinstance(expr2.child, pybamm.SpatialVariable) + + assert expr2.child.name == "x" + assert isinstance(expr2.integration_variable, list) + assert len(expr2.integration_variable) == 1 + assert isinstance(expr2.integration_variable[0], pybamm.SpatialVariable) + assert expr2.integration_variable[0].name == "x" + + def test_symbol_fallback_serialisation(self): + var = pybamm.Variable("v", domain="electrode") + diff = 
pybamm.Gradient(var) + json_dict = Serialise.convert_symbol_to_json(diff) + diff2 = Serialise.convert_symbol_from_json(json_dict) + + assert isinstance(diff2, pybamm.Gradient) + assert isinstance(diff2.children[0], pybamm.Variable) + assert diff2.children[0].name == "v" + assert diff2.children[0].domains["primary"] == ["electrode"] + + def test_unhandled_symbol_type_error(self): + class NotSymbol: + def __init__(self): + self.name = "not_a_symbol" + + dummy = NotSymbol() + with pytest.raises(ValueError) as e: + Serialise.convert_symbol_to_json(dummy) + + assert "Error processing 'not_a_symbol'. Unknown symbol type:" in str(e.value) + + def test_deserialising_unhandled_type(self): + unhandled_json = {"type": "NotARealSymbol", "foo": "bar"} + with pytest.raises( + ValueError, + match=r"Unhandled symbol type or malformed entry: .*NotARealSymbol", + ): + Serialise.convert_symbol_from_json(unhandled_json) + + unhandled_json2 = {"a": 1, "b": 2} + with pytest.raises( + ValueError, match=r"Unhandled symbol type or malformed entry: .*" + ): + Serialise.convert_symbol_from_json(unhandled_json2) + + def test_file_write_raises_ioerror(self): + # testing behaviour when file system is read-only to raise exception + model = pybamm.lithium_ion.SPM() + + with patch("builtins.open", mock_open()) as file: + file.side_effect = OSError("file system is read-only") + + with pytest.raises( + ValueError, + match="Failed to save custom model: Failed to write model JSON to file", + ): + Serialise.save_custom_model(model, "readonly_test") + + def test_symbol_conversion_failure_raises_value_error(self): + model = pybamm.BaseModel() + model.name = "TestModel" + model.rhs = {pybamm.Variable("c"): pybamm.Variable("c")} + + with patch.object( + Serialise, + "convert_symbol_to_json", + side_effect=Exception("conversion failed"), + ): + with pytest.raises( + ValueError, match="Failed to save custom model: conversion failed" + ): + Serialise.save_custom_model(model, "conversion_fail") + + def test_unsupported_schema_version(self): + unhandled_schema_json = { + "schema_version": "9.9", # Unsupported + "pybamm_version": pybamm.__version__, + "name": "BadModel", + "rhs": [], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [], + "events": [], + "variables": {}, + } + + file = "model.json" + + with open(file, "w") as f: + json.dump(unhandled_schema_json, f) + + try: + with pytest.raises(ValueError, match="Unsupported schema version: 9.9"): + Serialise.load_custom_model(file, battery_model=pybamm.BaseModel()) + finally: + os.remove(file) + + def test_model_has_correct_schema_version(self): + model = BasicDFN() + filename = "test_scehma_version" + + Serialise.save_custom_model(model, filename=filename) + loaded_model = Serialise.load_custom_model( + f"{filename}.json", battery_model=pybamm.lithium_ion.BaseModel() + ) + + try: + assert hasattr(loaded_model, "schema_version") + assert loaded_model.schema_version == SUPPORTED_SCHEMA_VERSION + finally: + # Clean up + os.remove(f"{filename}.json") + + def test_load_invalid_json(self): + invalid_json = "{ invalid json" + with patch("builtins.open", mock_open(read_data=invalid_json)): + with pytest.raises(ValueError) as e: + Serialise.load_custom_model("invalid_json.json") + assert "Invalid JSON in file" in str(e.value) + + def test_load_custom_model_file_not_found(self): + with pytest.raises(FileNotFoundError) as e: + Serialise.load_custom_model("non_existent_file.json") + assert "Could not find file" in str(e.value) + + def 
test_invalid_symbol_key_raises_value_error(self): + # Malformed LHS (invalid symbol type) + bad_lhs = {"not_a_valid_symbol": 123} + rhs_expr = {"type": "Scalar", "value": 1.0} + + model_json = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadSymbolKeyModel", + "rhs": [[bad_lhs, rhs_expr]], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [], + "events": [], + "variables": {}, + } + + file = "model.json" + + with open(file, "w") as f: + json.dump(model_json, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to process symbol key for variable" in msg + assert "unhandled symbol type or malformed entry" in msg + os.remove(file) + + def test_save_raises_for_missing_sections(self): + class DummyModelMissing: + # e.g. only has rhs and algebraic + def __init__(self): + self.rhs = {} + self.algebraic = {} + + m = DummyModelMissing() + with pytest.raises(AttributeError) as e: + Serialise.save_custom_model(m, filename="irrelevant") + msg = str(e.value) + assert "missing required sections" in msg.lower() + assert any( + section in msg for section in ["initial_conditions", "events", "variables"] + ) + + def test_model_with_missing_json_sections(self): + model_json = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadModel", + "algebraic": [], + "initial_conditions": [], + } + file = "model1.json" + + with open(file, "w") as f: + json.dump(model_json, f) + + with pytest.raises(KeyError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + for missing_section in ["rhs", "boundary_conditions", "events", "variables"]: + assert missing_section in msg, ( + f"Error message should mention missing '{missing_section}'" + ) + os.remove(file) + + def test_invalid_rhs_entry_raises_value_error(self): + # Build JSON with all required keys, but rhs has a bad entry + good_lhs = { + "type": "Variable", + "name": "x", + "domains": {}, + } + bad_rhs = {"this_will_fail": True} + + # 2) Build JSON with all required keys + model_json = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadModel", + # One valid pair in RHS + "rhs": [[good_lhs, bad_rhs]], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [], + "events": [], + "variables": {}, + } + file = "model2.json" + + with open(file, "w") as f: + json.dump(model_json, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to convert rhs" in msg + assert "unhandled symbol type or malformed entry" in msg + os.remove(file) + + def test_invalid_algebraic_entry_raises_value_error(self): + # Build JSON with all required keys, but rhs has a bad entry + good_lhs = { + "type": "Variable", + "name": "x", + "domains": {}, + } + bad_rhs = {"this_will_fail": True} + + # 2) Build JSON with all required keys + model_json = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadModel", + # One valid pair in RHS + "rhs": [], + "algebraic": [[good_lhs, bad_rhs]], + "initial_conditions": [], + "boundary_conditions": [], + "events": [], + "variables": {}, + } + file = "model3.json" + + with open(file, "w") as f: + json.dump(model_json, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to convert algebraic" in msg + assert "unhandled symbol type or malformed 
entry" in msg + os.remove(file) + + def test_invalid_initial_conditions_entry_raises_value_error(self): + # Build JSON with all required keys, but rhs has a bad entry + good_lhs = { + "type": "Variable", + "name": "x", + "domains": {}, + } + bad_rhs = {"this_will_fail": True} + + # 2) Build JSON with all required keys + model_json = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadModel", + # One valid pair in RHS + "rhs": [], + "algebraic": [], + "initial_conditions": [[good_lhs, bad_rhs]], + "boundary_conditions": [], + "events": [], + "variables": {}, + } + file = "model4.json" + + with open(file, "w") as f: + json.dump(model_json, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to convert initial condition" in msg + assert "unhandled symbol type or malformed entry" in msg + os.remove(file) + + def test_invalid_boundary_conditions_raise_value_error(self): + good_variable = { + "type": "Variable", + "name": "x", + "domains": {}, + } + + # Malformed RHS: missing tuple structure + bad_condition_dict = { + "left": { + "this_is_not_valid": True + }, # Should be (expression_json, boundary_type) + } + + model_json = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadBoundaryModel", + "rhs": [], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [[good_variable, bad_condition_dict]], + "events": [], + "variables": {}, + "all_variable_keys": [good_variable], + } + + file = "model5.json" + + with open(file, "w") as f: + json.dump(model_json, f) + + # Expect the load to raise a ValueError + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to convert boundary" in msg + assert " not enough values to unpack" in msg + os.remove(file) + + # Valid variable + variable_json = { + "type": "Variable", + "name": "c", + "domains": {}, + } + + invalid_expression_json = "not_a_valid_expression" + + condition_dict = {"left": (invalid_expression_json, "Dirichlet")} + + model_data = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadBoundaryExpressionModel", + "rhs": [], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [[variable_json, condition_dict]], + "events": [], + "variables": {}, + "all_variable_keys": [variable_json], + } + + model_file = "bad_boundary_expr.json" + with open(model_file, "w") as f: + json.dump(model_data, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(model_file)) + + msg = str(e.value) + assert "Failed to convert boundary expression for variable" in msg + assert "left" in msg + assert "not_a_valid_expression" in msg or "Invalid" in msg + os.remove(model_file) + + def test_event_conversion_failure(self): + model_data = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadEventModel", + "rhs": [], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [], + "variables": {}, + "events": [ + { + "name": "Bad Event", + "expression": {"bad": "structure"}, # malformed + "event_type": "termination", + } + ], + } + + file = "bad_event_model.json" + with open(file, "w") as f: + json.dump(model_data, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to convert event 'bad event'" in msg + assert "unhandled symbol type or malformed entry" in msg 
+ os.remove(file) + + def test_variable_conversion_failure(tmp_path): + model_data = { + "schema_version": "1.0", + "pybamm_version": pybamm.__version__, + "name": "BadVariableModel", + "rhs": [], + "algebraic": [], + "initial_conditions": [], + "boundary_conditions": [], + "events": [], + "variables": {"Bad Variable": {"bad": "structure"}}, + } + + file = "bad_variable_model.json" + with open(file, "w") as f: + json.dump(model_data, f) + + with pytest.raises(ValueError) as e: + Serialise.load_custom_model(str(file)) + + msg = str(e.value).lower() + assert "failed to convert variable 'bad variable'" in msg + assert "unhandled symbol type or malformed entry" in msg + os.remove(file) + + def test_save_and_load_custom_model(self): + model = pybamm.BaseModel(name="test_model") + a = pybamm.Variable("a", domain="electrode") + b = pybamm.Variable("b", domain="electrode") + model.rhs = {a: b} + model.initial_conditions = {a: pybamm.Scalar(1)} + model.algebraic = {} + model.boundary_conditions = {a: {"left": (pybamm.Scalar(0), "Dirichlet")}} + model.events = [pybamm.Event("terminal", pybamm.Scalar(1) - b, "TERMINATION")] + model.variables = {"a": a, "b": b} + + # save model + Serialise.save_custom_model(model, filename="test_model") + + # check json exists + assert os.path.exists("test_model.json") + + # saving with defualt filename + Serialise().save_custom_model(model) + pattern = r"test_model_\d{4}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}\.json" + matched = [f for f in os.listdir(".") if re.fullmatch(pattern, f)] + assert matched + + for f in matched: + os.remove(f) + + # load model + loaded_model = Serialise.load_custom_model("test_model.json") + os.remove("test_model.json") + + assert loaded_model.name == "test_model" + assert isinstance(loaded_model.rhs, dict) + assert next(iter(loaded_model.rhs.keys())).name == "a" + assert next(iter(loaded_model.rhs.values())).name == "b" + + def test_plotting_serialised_models(self): + models = [ + BasicSPM(), + BasicDFN(), + pybamm.lithium_ion.SPM(), + pybamm.lithium_ion.DFN(), + ] + filenames = ["basic_spm", "basic_dfn", "spm", "dfn"] + + for model, name in zip(models, filenames, strict=True): + # Save the model + Serialise.save_custom_model(model, filename=name) + + # Load the model + loaded_model = Serialise.load_custom_model( + f"{name}.json", battery_model=pybamm.lithium_ion.BaseModel() + ) + + sim = pybamm.Simulation(loaded_model) + sim.solve([0, 3600]) + sim.plot(show_plot=False) + + os.remove(f"{name}.json") + + def test_models_and_params(self): + import numpy as np + + import pybamm + from pybamm.expression_tree.operations.serialise import Serialise + + param = pybamm.ParameterValues( + "Marquis2019" + ) # or another chemistry like "Chen2020" + + Serialise.save_parameters(param, filename="my_params.json") + + model = pybamm.lithium_ion.SPM() + Serialise.save_custom_model(model, "my_model.json") + model2 = Serialise.load_custom_model( + "my_model.json", battery_model=pybamm.lithium_ion.BaseModel() + ) + + # Load parameter values + param_dict = Serialise.load_parameters("my_params.json") + param3 = pybamm.ParameterValues(param_dict) + + param3.update({"Electrode height [m]": "[input]"}) + + # Load the model + model2 = Serialise.load_custom_model( + "my_model.json", battery_model=pybamm.lithium_ion.BaseModel() + ) + + # ✅ Process the model AFTER update() so the necessary .input_* keys get created + param3.process_model(model2) + + # Create simulation + t_eval = np.linspace(0, 60, 11) + inputs = {"Electrode height [m]": 0.2} + sim = 
pybamm.Simulation(model=model2, parameter_values=param3) + sim.solve(t_eval=t_eval, inputs=inputs)
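(Reviewer note, not part of the patch.) Since the __init__.py hunk at the top of this diff re-exports the class, the helpers are also reachable as pybamm.Serialise; a small usage sketch with an illustrative filename:

import pybamm

model = pybamm.lithium_ion.SPM()
pybamm.Serialise.save_custom_model(model, "spm_custom")
loaded = pybamm.Serialise.load_custom_model(
    "spm_custom.json", battery_model=pybamm.lithium_ion.BaseModel()
)
assert loaded.name == model.name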