From 4a24e0a1365830d905fbabdaa4bf9b4af1564d92 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Sun, 19 Jan 2025 09:34:38 +0530 Subject: [PATCH 01/14] added plot functionality --- src/pybamm/plotting/quick_plot.py | 198 +++++++++++++++++++----------- 1 file changed, 129 insertions(+), 69 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index babfd2e761..9ef02cae71 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -109,6 +109,7 @@ def __init__( spatial_unit="um", variable_limits="fixed", n_t_linear=100, + x_axis="Time", ): solutions = self.preprocess_solutions(solutions) @@ -168,52 +169,72 @@ def __init__( else: raise ValueError(f"spatial unit '{spatial_unit}' not recognized") - # Time parameters - self.ts_seconds = [solution.t for solution in solutions] - min_t = np.min([t[0] for t in self.ts_seconds]) - max_t = np.max([t[-1] for t in self.ts_seconds]) - - hermite_interp = all(sol.hermite_interpolation for sol in solutions) - - def t_sample(sol): - if hermite_interp and n_t_linear > 2: - # Linearly spaced time points - t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1] - t_plot = np.union1d(sol.t, t_linspace) - else: - t_plot = sol.t - return t_plot - - ts_seconds = [] - for sol in solutions: - # Sample time points for each sub-solution - t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions] - ts_seconds.append(np.concatenate(t_sol)) - self.ts_seconds = ts_seconds - - # Set timescale - if time_unit is None: - # defaults depend on how long the simulation is - if max_t >= 3600: - time_scaling_factor = 3600 # time in hours + # Set time or discharge capacity as x-axis + if x_axis == "Discharge Capacity [A.h]": + print("yay") + # Use discharge capacity as x-axis + self.x_axis = "Discharge capacity [A.h]" + + # Extract discharge capacities for all solutions + discharge_capacities = [ + solution["Discharge capacity [A.h]"].entries for solution in solutions + ] + self.dc_values = discharge_capacities # Store as the x-axis values + + # Set discharge capacity range + self.min_dc = min(dc[0] for dc in discharge_capacities) + self.max_dc = max(dc[-1] for dc in discharge_capacities) + + # Scaling and unit specific to discharge capacity + self.dc_scaling_factor = 1 # No scaling needed for discharge capacity + self.dc_unit = "A.h" + else: + # Default to time + self.ts_seconds = [solution.t for solution in solutions] + min_t = np.min([t[0] for t in self.ts_seconds]) + max_t = np.max([t[-1] for t in self.ts_seconds]) + + hermite_interp = all(sol.hermite_interpolation for sol in solutions) + + def t_sample(sol): + if hermite_interp and n_t_linear > 2: + # Linearly spaced time points + t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1] + t_plot = np.union1d(sol.t, t_linspace) + else: + t_plot = sol.t + return t_plot + + ts_seconds = [] + for sol in solutions: + # Sample time points for each sub-solution + t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions] + ts_seconds.append(np.concatenate(t_sol)) + self.ts_seconds = ts_seconds + + # Set timescale + if time_unit is None: + # defaults depend on how long the simulation is + if max_t >= 3600: + time_scaling_factor = 3600 # time in hours + self.time_unit = "h" + else: + time_scaling_factor = 1 # time in seconds + self.time_unit = "s" + elif time_unit == "seconds": + time_scaling_factor = 1 + self.time_unit = "s" + elif time_unit == "minutes": + time_scaling_factor = 60 + self.time_unit = "min" + elif time_unit == "hours": + time_scaling_factor = 3600 self.time_unit = "h" else: - time_scaling_factor = 1 # time in seconds - self.time_unit = "s" - elif time_unit == "seconds": - time_scaling_factor = 1 - self.time_unit = "s" - elif time_unit == "minutes": - time_scaling_factor = 60 - self.time_unit = "min" - elif time_unit == "hours": - time_scaling_factor = 3600 - self.time_unit = "h" - else: - raise ValueError(f"time unit '{time_unit}' not recognized") - self.time_scaling_factor = time_scaling_factor - self.min_t = min_t / time_scaling_factor - self.max_t = max_t / time_scaling_factor + raise ValueError(f"time unit '{time_unit}' not recognized") + self.time_scaling_factor = time_scaling_factor + self.min_t = min_t / time_scaling_factor + self.max_t = max_t / time_scaling_factor # Prepare dictionary of variables # output_variables is a list of strings or lists, e.g. @@ -520,8 +541,12 @@ def plot(self, t, dynamic=False): variable_handles = [] # Set labels for the first subplot only (avoid repetition) if variable_lists[0][0].dimensions == 0: - # 0D plot: plot as a function of time, indicating time t with a line - ax.set_xlabel(f"Time [{self.time_unit}]") + if self.x_axis == "Time": + # 0D plot: plot as a function of time, indicating time t with a line + ax.set_xlabel(f"Time [{self.time_unit}]") + elif self.x_axis == "Discharge capacity [A.h]": + ax.set_xlabel(f"Discharge Capacity [{self.dc_unit}]") + for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): if len(variable_list) == 1: @@ -531,13 +556,24 @@ def plot(self, t, dynamic=False): # multiple variables -> use linestyle to differentiate # variables (color differentiates models) linestyle = self.linestyles[j] - full_t = self.ts_seconds[i] - (self.plots[key][i][j],) = ax.plot( - full_t / self.time_scaling_factor, - variable(full_t), - color=self.colors[i], - linestyle=linestyle, - ) + + if self.x_axis[:4] == "Time": + full_t = self.ts_seconds[i] + (self.plots[key][i][j],) = ax.plot( + full_t / self.time_scaling_factor, + variable(full_t), + color=self.colors[i], + linestyle=linestyle, + ) + elif self.x_axis == "Discharge capacity [A.h]": + full_dc = self.dc_values[i] + (self.plots[key][i][j],) = ax.plot( + full_dc / self.dc_scaling_factor, + variable(full_dc), + color=self.colors[i], + linestyle=linestyle, + ) + variable_handles.append(self.plots[key][0][j]) solution_handles.append(self.plots[key][i][0]) y_min, y_max = ax.get_ylim() @@ -668,13 +704,13 @@ def plot(self, t, dynamic=False): def dynamic_plot(self, show_plot=True, step=None): """ - Generate a dynamic plot with a slider to control the time. + Generate a dynamic plot with a slider to control the x-axis. Parameters ---------- step : float, optional For notebook mode, size of steps to allow in the slider. Defaults to 1/100th - of the total time. + of the total range (time or discharge capacity). show_plot : bool, optional Whether to show the plots. Default is True. Set to False if you want to only display the plot after plt.show() has been called. @@ -683,29 +719,53 @@ def dynamic_plot(self, show_plot=True, step=None): if pybamm.is_notebook(): # pragma: no cover import ipywidgets as widgets - step = step or self.max_t / 100 - widgets.interact( - lambda t: self.plot(t, dynamic=False), - t=widgets.FloatSlider( - min=self.min_t, max=self.max_t, step=step, value=self.min_t - ), - continuous_update=False, - ) + # Determine step size based on x-axis + if self.x_axis == "Discharge capacity [A.h]": + step = step or (self.max_dc - self.min_dc) / 100 + widgets.interact( + lambda dc: self.plot(dc, dynamic=False), + dc=widgets.FloatSlider( + min=self.min_dc, + max=self.max_dc, + step=step, + value=self.min_dc, + ), + continuous_update=False, + ) + else: # Default to time + step = step or self.max_t / 100 + widgets.interact( + lambda t: self.plot(t, dynamic=False), + t=widgets.FloatSlider( + min=self.min_t, + max=self.max_t, + step=step, + value=self.min_t, + ), + continuous_update=False, + ) else: plt = import_optional_dependency("matplotlib.pyplot") Slider = import_optional_dependency("matplotlib.widgets", "Slider") - # create an initial plot at time self.min_t - self.plot(self.min_t, dynamic=True) + # Set initial x-axis values and slider + if self.x_axis == "Discharge capacity [A.h]": + self.plot(self.min_dc, dynamic=True) + ax_label = f"Discharge capacity [{self.time_unit}]" # Update time_unit to relevant unit + ax_min, ax_max, val_init = self.min_dc, self.max_dc, self.min_dc + else: # Default to time + self.plot(self.min_t, dynamic=True) + ax_label = f"Time [{self.time_unit}]" + ax_min, ax_max, val_init = self.min_t, self.max_t, self.min_t axcolor = "lightgoldenrodyellow" ax_slider = plt.axes([0.315, 0.02, 0.37, 0.03], facecolor=axcolor) self.slider = Slider( ax_slider, - f"Time [{self.time_unit}]", - self.min_t, - self.max_t, - valinit=self.min_t, + ax_label, + ax_min, + ax_max, + valinit=val_init, color="#1f77b4", ) self.slider.on_changed(self.slider_update) From 62bea1e78788ec236eb03ce34d1e09991685bb80 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Sun, 19 Jan 2025 10:19:00 +0530 Subject: [PATCH 02/14] few changes --- src/pybamm/plotting/quick_plot.py | 151 ++++++++++++++++-------------- 1 file changed, 83 insertions(+), 68 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 9ef02cae71..2d0fbe7313 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -175,66 +175,61 @@ def __init__( # Use discharge capacity as x-axis self.x_axis = "Discharge capacity [A.h]" - # Extract discharge capacities for all solutions discharge_capacities = [ solution["Discharge capacity [A.h]"].entries for solution in solutions ] - self.dc_values = discharge_capacities # Store as the x-axis values + self.dc_values = discharge_capacities - # Set discharge capacity range self.min_dc = min(dc[0] for dc in discharge_capacities) self.max_dc = max(dc[-1] for dc in discharge_capacities) - # Scaling and unit specific to discharge capacity - self.dc_scaling_factor = 1 # No scaling needed for discharge capacity self.dc_unit = "A.h" - else: - # Default to time - self.ts_seconds = [solution.t for solution in solutions] - min_t = np.min([t[0] for t in self.ts_seconds]) - max_t = np.max([t[-1] for t in self.ts_seconds]) - - hermite_interp = all(sol.hermite_interpolation for sol in solutions) - - def t_sample(sol): - if hermite_interp and n_t_linear > 2: - # Linearly spaced time points - t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1] - t_plot = np.union1d(sol.t, t_linspace) - else: - t_plot = sol.t - return t_plot - - ts_seconds = [] - for sol in solutions: - # Sample time points for each sub-solution - t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions] - ts_seconds.append(np.concatenate(t_sol)) - self.ts_seconds = ts_seconds - - # Set timescale - if time_unit is None: - # defaults depend on how long the simulation is - if max_t >= 3600: - time_scaling_factor = 3600 # time in hours - self.time_unit = "h" - else: - time_scaling_factor = 1 # time in seconds - self.time_unit = "s" - elif time_unit == "seconds": - time_scaling_factor = 1 - self.time_unit = "s" - elif time_unit == "minutes": - time_scaling_factor = 60 - self.time_unit = "min" - elif time_unit == "hours": - time_scaling_factor = 3600 + + # Default to time + self.ts_seconds = [solution.t for solution in solutions] + min_t = np.min([t[0] for t in self.ts_seconds]) + max_t = np.max([t[-1] for t in self.ts_seconds]) + + hermite_interp = all(sol.hermite_interpolation for sol in solutions) + + def t_sample(sol): + if hermite_interp and n_t_linear > 2: + # Linearly spaced time points + t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1] + t_plot = np.union1d(sol.t, t_linspace) + else: + t_plot = sol.t + return t_plot + ts_seconds = [] + for sol in solutions: + # Sample time points for each sub-solution + t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions] + ts_seconds.append(np.concatenate(t_sol)) + self.ts_seconds = ts_seconds + + # Set timescale + if time_unit is None: + # defaults depend on how long the simulation is + if max_t >= 3600: + time_scaling_factor = 3600 # time in hours self.time_unit = "h" else: - raise ValueError(f"time unit '{time_unit}' not recognized") - self.time_scaling_factor = time_scaling_factor - self.min_t = min_t / time_scaling_factor - self.max_t = max_t / time_scaling_factor + time_scaling_factor = 1 # time in seconds + self.time_unit = "s" + elif time_unit == "seconds": + time_scaling_factor = 1 + self.time_unit = "s" + elif time_unit == "minutes": + time_scaling_factor = 60 + self.time_unit = "min" + elif time_unit == "hours": + time_scaling_factor = 3600 + self.time_unit = "h" + else: + raise ValueError(f"time unit '{time_unit}' not recognized") + self.time_scaling_factor = time_scaling_factor + self.min_t = min_t / time_scaling_factor + self.max_t = max_t / time_scaling_factor # Prepare dictionary of variables # output_variables is a list of strings or lists, e.g. @@ -435,8 +430,12 @@ def reset_axis(self): self.axis_limits = {} for key, variable_lists in self.variables.items(): if variable_lists[0][0].dimensions == 0: - x_min = self.min_t - x_max = self.max_t + if self.x_axis == "Discharge capacity [A.h]": + x_min = self.min_dc + x_max = self.max_dc + else: + x_min = self.min_t + x_max = self.max_t elif variable_lists[0][0].dimensions == 1: x_min = self.first_spatial_variable[key][0] x_max = self.first_spatial_variable[key][-1] @@ -458,22 +457,38 @@ def reset_axis(self): # Get min and max variable values if self.variable_limits[key] == "fixed": - # fixed variable limits: calculate "globlal" min and max + # fixed variable limits: calculate "global" min and max spatial_vars = self.spatial_variable_dict[key] - var_min = np.min( - [ - ax_min(var(self.ts_seconds[i], **spatial_vars)) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - var_max = np.max( - [ - ax_max(var(self.ts_seconds[i], **spatial_vars)) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) + if self.x_axis == "Discharge capacity [A.h]": + var_min = np.min( + [ + ax_min(var(self.dc_values[i], **spatial_vars)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + var_max = np.max( + [ + ax_max(var(self.dc_values[i], **spatial_vars)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + else: + var_min = np.min( + [ + ax_min(var(self.ts_seconds[i], **spatial_vars)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + var_max = np.max( + [ + ax_max(var(self.ts_seconds[i], **spatial_vars)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) if np.isnan(var_min) or np.isnan(var_max): raise ValueError( "The variable limits are set to 'fixed' but the min and max " @@ -568,7 +583,7 @@ def plot(self, t, dynamic=False): elif self.x_axis == "Discharge capacity [A.h]": full_dc = self.dc_values[i] (self.plots[key][i][j],) = ax.plot( - full_dc / self.dc_scaling_factor, + full_dc, variable(full_dc), color=self.colors[i], linestyle=linestyle, From 5e0b742835e3017c14a200f128cbe5b368c8b78d Mon Sep 17 00:00:00 2001 From: medha-14 Date: Sun, 19 Jan 2025 11:12:46 +0530 Subject: [PATCH 03/14] minor fix --- src/pybamm/plotting/quick_plot.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 2d0fbe7313..f2c5bb6fba 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -169,22 +169,21 @@ def __init__( else: raise ValueError(f"spatial unit '{spatial_unit}' not recognized") - # Set time or discharge capacity as x-axis + # set x_axis + self.x_axis = x_axis + if x_axis == "Discharge Capacity [A.h]": - print("yay") # Use discharge capacity as x-axis - self.x_axis = "Discharge capacity [A.h]" - discharge_capacities = [ solution["Discharge capacity [A.h]"].entries for solution in solutions ] - self.dc_values = discharge_capacities + self.dc_values = discharge_capacities self.min_dc = min(dc[0] for dc in discharge_capacities) self.max_dc = max(dc[-1] for dc in discharge_capacities) self.dc_unit = "A.h" - + # Default to time self.ts_seconds = [solution.t for solution in solutions] min_t = np.min([t[0] for t in self.ts_seconds]) @@ -200,6 +199,7 @@ def t_sample(sol): else: t_plot = sol.t return t_plot + ts_seconds = [] for sol in solutions: # Sample time points for each sub-solution From 213d84e7a40fcafd4fa26e43c6c7158e7c7ca16e Mon Sep 17 00:00:00 2001 From: medha-14 Date: Wed, 5 Feb 2025 14:58:44 +0530 Subject: [PATCH 04/14] simplified variable names --- src/pybamm/plotting/quick_plot.py | 174 ++++++++++++------------------ 1 file changed, 66 insertions(+), 108 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index f2c5bb6fba..220dbab848 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -169,22 +169,7 @@ def __init__( else: raise ValueError(f"spatial unit '{spatial_unit}' not recognized") - # set x_axis - self.x_axis = x_axis - - if x_axis == "Discharge Capacity [A.h]": - # Use discharge capacity as x-axis - discharge_capacities = [ - solution["Discharge capacity [A.h]"].entries for solution in solutions - ] - self.dc_values = discharge_capacities - - self.min_dc = min(dc[0] for dc in discharge_capacities) - self.max_dc = max(dc[-1] for dc in discharge_capacities) - - self.dc_unit = "A.h" - - # Default to time + # Time parameters self.ts_seconds = [solution.t for solution in solutions] min_t = np.min([t[0] for t in self.ts_seconds]) max_t = np.max([t[-1] for t in self.ts_seconds]) @@ -231,6 +216,30 @@ def t_sample(sol): self.min_t = min_t / time_scaling_factor self.max_t = max_t / time_scaling_factor + # set x_axis + self.x_axis = x_axis + + if x_axis == "Discharge capacity": + # Use discharge capacity as x-axis + discharge_capacities = [ + solution["Discharge capacity [A.h]"].entries for solution in solutions + ] + self.x_values = discharge_capacities + + self.x_min = min(dc[0] for dc in discharge_capacities) + self.x_max = max(dc[-1] for dc in discharge_capacities) + self.x_scaling_factor = 1 + self.x_unit = "A.h" + + elif x_axis == "Time": + self.x_values = ts_seconds + + self.x_min = self.min_t + self.x_max = self.max_t + self.x_scaling_factor = self.time_scaling_factor + + self.x_unit = self.time_unit + # Prepare dictionary of variables # output_variables is a list of strings or lists, e.g. # ["var 1", ["variable 2", "var 3"]] @@ -430,12 +439,8 @@ def reset_axis(self): self.axis_limits = {} for key, variable_lists in self.variables.items(): if variable_lists[0][0].dimensions == 0: - if self.x_axis == "Discharge capacity [A.h]": - x_min = self.min_dc - x_max = self.max_dc - else: - x_min = self.min_t - x_max = self.max_t + x_min = self.x_min + x_max = self.x_max elif variable_lists[0][0].dimensions == 1: x_min = self.first_spatial_variable[key][0] x_max = self.first_spatial_variable[key][-1] @@ -459,36 +464,20 @@ def reset_axis(self): if self.variable_limits[key] == "fixed": # fixed variable limits: calculate "global" min and max spatial_vars = self.spatial_variable_dict[key] - if self.x_axis == "Discharge capacity [A.h]": - var_min = np.min( - [ - ax_min(var(self.dc_values[i], **spatial_vars)) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - var_max = np.max( - [ - ax_max(var(self.dc_values[i], **spatial_vars)) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - else: - var_min = np.min( - [ - ax_min(var(self.ts_seconds[i], **spatial_vars)) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - var_max = np.max( - [ - ax_max(var(self.ts_seconds[i], **spatial_vars)) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) + var_min = np.min( + [ + ax_min(var(self.ts_seconds[i], **spatial_vars)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + var_max = np.max( + [ + ax_max(var(self.ts_seconds[i], **spatial_vars)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) if np.isnan(var_min) or np.isnan(var_max): raise ValueError( "The variable limits are set to 'fixed' but the min and max " @@ -556,12 +545,8 @@ def plot(self, t, dynamic=False): variable_handles = [] # Set labels for the first subplot only (avoid repetition) if variable_lists[0][0].dimensions == 0: - if self.x_axis == "Time": - # 0D plot: plot as a function of time, indicating time t with a line - ax.set_xlabel(f"Time [{self.time_unit}]") - elif self.x_axis == "Discharge capacity [A.h]": - ax.set_xlabel(f"Discharge Capacity [{self.dc_unit}]") - + # 0D plot: plot as a function of time, indicating time t with a line + ax.set_xlabel(f"{self.x_axis} [{self.x_unit}]") for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): if len(variable_list) == 1: @@ -572,23 +557,13 @@ def plot(self, t, dynamic=False): # variables (color differentiates models) linestyle = self.linestyles[j] - if self.x_axis[:4] == "Time": - full_t = self.ts_seconds[i] - (self.plots[key][i][j],) = ax.plot( - full_t / self.time_scaling_factor, - variable(full_t), - color=self.colors[i], - linestyle=linestyle, - ) - elif self.x_axis == "Discharge capacity [A.h]": - full_dc = self.dc_values[i] - (self.plots[key][i][j],) = ax.plot( - full_dc, - variable(full_dc), - color=self.colors[i], - linestyle=linestyle, - ) - + full_val = self.x_values[i] + (self.plots[key][i][j],) = ax.plot( + full_val, + variable(full_val), + color=self.colors[i], + linestyle=linestyle, + ) variable_handles.append(self.plots[key][0][j]) solution_handles.append(self.plots[key][i][0]) y_min, y_max = ax.get_ylim() @@ -734,44 +709,27 @@ def dynamic_plot(self, show_plot=True, step=None): if pybamm.is_notebook(): # pragma: no cover import ipywidgets as widgets - # Determine step size based on x-axis - if self.x_axis == "Discharge capacity [A.h]": - step = step or (self.max_dc - self.min_dc) / 100 - widgets.interact( - lambda dc: self.plot(dc, dynamic=False), - dc=widgets.FloatSlider( - min=self.min_dc, - max=self.max_dc, - step=step, - value=self.min_dc, - ), - continuous_update=False, - ) - else: # Default to time - step = step or self.max_t / 100 - widgets.interact( - lambda t: self.plot(t, dynamic=False), - t=widgets.FloatSlider( - min=self.min_t, - max=self.max_t, - step=step, - value=self.min_t, - ), - continuous_update=False, - ) + step = step or self.max_t / 100 + widgets.interact( + lambda t: self.plot(t, dynamic=False), + t=widgets.FloatSlider( + min=self.min_t, + max=self.max_t, + step=step, + value=self.min_t, + ), + continuous_update=False, + ) else: plt = import_optional_dependency("matplotlib.pyplot") Slider = import_optional_dependency("matplotlib.widgets", "Slider") # Set initial x-axis values and slider - if self.x_axis == "Discharge capacity [A.h]": - self.plot(self.min_dc, dynamic=True) - ax_label = f"Discharge capacity [{self.time_unit}]" # Update time_unit to relevant unit - ax_min, ax_max, val_init = self.min_dc, self.max_dc, self.min_dc - else: # Default to time - self.plot(self.min_t, dynamic=True) - ax_label = f"Time [{self.time_unit}]" - ax_min, ax_max, val_init = self.min_t, self.max_t, self.min_t + self.plot(self.x_min, dynamic=True) + ax_label = ( + f"{self.x_axis}[{self.x_unit}]" # Update time_unit to relevant unit + ) + ax_min, ax_max, val_init = self.x_min, self.x_max, self.x_min axcolor = "lightgoldenrodyellow" ax_slider = plt.axes([0.315, 0.02, 0.37, 0.03], facecolor=axcolor) From 5d2b176cdc2468ddda5fa46a5979ab7c2511d720 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Sun, 9 Feb 2025 15:23:50 +0530 Subject: [PATCH 05/14] minor fix --- src/pybamm/plotting/quick_plot.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 220dbab848..bd66e229d3 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -226,16 +226,16 @@ def t_sample(sol): ] self.x_values = discharge_capacities - self.x_min = min(dc[0] for dc in discharge_capacities) - self.x_max = max(dc[-1] for dc in discharge_capacities) + self.x_axis_min = min(dc[0] for dc in discharge_capacities) + self.x_axis_max = max(dc[-1] for dc in discharge_capacities) self.x_scaling_factor = 1 self.x_unit = "A.h" elif x_axis == "Time": self.x_values = ts_seconds - self.x_min = self.min_t - self.x_max = self.max_t + self.x_axis_min = self.min_t + self.x_axis_max = self.max_t self.x_scaling_factor = self.time_scaling_factor self.x_unit = self.time_unit @@ -439,8 +439,8 @@ def reset_axis(self): self.axis_limits = {} for key, variable_lists in self.variables.items(): if variable_lists[0][0].dimensions == 0: - x_min = self.x_min - x_max = self.x_max + x_min = self.x_axis_min + x_max = self.x_axis_max elif variable_lists[0][0].dimensions == 1: x_min = self.first_spatial_variable[key][0] x_max = self.first_spatial_variable[key][-1] @@ -546,7 +546,7 @@ def plot(self, t, dynamic=False): # Set labels for the first subplot only (avoid repetition) if variable_lists[0][0].dimensions == 0: # 0D plot: plot as a function of time, indicating time t with a line - ax.set_xlabel(f"{self.x_axis} [{self.x_unit}]") + ax.set_xlabel(f"Time [{self.time_unit}]") for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): if len(variable_list) == 1: @@ -556,11 +556,10 @@ def plot(self, t, dynamic=False): # multiple variables -> use linestyle to differentiate # variables (color differentiates models) linestyle = self.linestyles[j] - - full_val = self.x_values[i] + full_t = self.ts_seconds[i] (self.plots[key][i][j],) = ax.plot( - full_val, - variable(full_val), + full_t / self.time_scaling_factor, + variable(full_t), color=self.colors[i], linestyle=linestyle, ) @@ -725,11 +724,11 @@ def dynamic_plot(self, show_plot=True, step=None): Slider = import_optional_dependency("matplotlib.widgets", "Slider") # Set initial x-axis values and slider - self.plot(self.x_min, dynamic=True) + self.plot(self.x_axis_min, dynamic=True) ax_label = ( f"{self.x_axis}[{self.x_unit}]" # Update time_unit to relevant unit ) - ax_min, ax_max, val_init = self.x_min, self.x_max, self.x_min + ax_min, ax_max, val_init = self.x_axis_min, self.x_axis_max, self.x_axis_min axcolor = "lightgoldenrodyellow" ax_slider = plt.axes([0.315, 0.02, 0.37, 0.03], facecolor=axcolor) From 871e1331c707fe5c55ea6c6a611386c09f5779be Mon Sep 17 00:00:00 2001 From: medha-14 Date: Tue, 11 Feb 2025 08:24:43 +0530 Subject: [PATCH 06/14] modified functionality of plot function --- src/pybamm/plotting/quick_plot.py | 34 ++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index bd66e229d3..b1c808f42d 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -219,7 +219,7 @@ def t_sample(sol): # set x_axis self.x_axis = x_axis - if x_axis == "Discharge capacity": + if x_axis == "Discharge capacity [A.h]": # Use discharge capacity as x-axis discharge_capacities = [ solution["Discharge capacity [A.h]"].entries for solution in solutions @@ -546,7 +546,11 @@ def plot(self, t, dynamic=False): # Set labels for the first subplot only (avoid repetition) if variable_lists[0][0].dimensions == 0: # 0D plot: plot as a function of time, indicating time t with a line - ax.set_xlabel(f"Time [{self.time_unit}]") + if self.x_axis == "Time": + ax.set_xlabel(f"Time [{self.time_unit}]") + if self.x_axis == "Discharge capacity [A.h]": + ax.set_xlabel("Discharge capacity [A.h]") + for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): if len(variable_list) == 1: @@ -556,10 +560,10 @@ def plot(self, t, dynamic=False): # multiple variables -> use linestyle to differentiate # variables (color differentiates models) linestyle = self.linestyles[j] - full_t = self.ts_seconds[i] + full_val = self.x_values[i] (self.plots[key][i][j],) = ax.plot( - full_t / self.time_scaling_factor, - variable(full_t), + full_val / self.x_scaling_factor, + variable(full_val), color=self.colors[i], linestyle=linestyle, ) @@ -708,14 +712,14 @@ def dynamic_plot(self, show_plot=True, step=None): if pybamm.is_notebook(): # pragma: no cover import ipywidgets as widgets - step = step or self.max_t / 100 + step = step or (self.x_axis_max - self.x_axis_min) / 100 widgets.interact( lambda t: self.plot(t, dynamic=False), t=widgets.FloatSlider( - min=self.min_t, - max=self.max_t, + min=self.x_axis_min, + max=self.x_axis_max, step=step, - value=self.min_t, + value=self.x_axis_min, ), continuous_update=False, ) @@ -725,9 +729,15 @@ def dynamic_plot(self, show_plot=True, step=None): # Set initial x-axis values and slider self.plot(self.x_axis_min, dynamic=True) - ax_label = ( - f"{self.x_axis}[{self.x_unit}]" # Update time_unit to relevant unit - ) + + # Set x-axis label correctly + if self.x_axis == "Time": + ax_label = f"Time [{self.time_unit}]" + elif self.x_axis == "Discharge capacity [A.h]": + ax_label = "Discharge capacity [A.h]" + else: + ax_label = self.x_axis # Use the string directly if unknown + ax_min, ax_max, val_init = self.x_axis_min, self.x_axis_max, self.x_axis_min axcolor = "lightgoldenrodyellow" From dc933cec52ef178b6d24e6717b14c1ac6b4aa600 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Tue, 11 Feb 2025 09:00:03 +0530 Subject: [PATCH 07/14] added tests --- tests/unit/test_plotting/test_quick_plot.py | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index 5fcafdd37b..5b79815cfb 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -266,6 +266,29 @@ def test_simple_ode_model(self, solver): pybamm.close_plots() + def test_plot_with_discharge_capacity(self): + """Test that the x-axis is correctly set to Discharge capacity [A.h]""" + model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model") + a = pybamm.Variable("a", domain=[]) + model.rhs = {a: pybamm.Scalar(0.2)} + model.initial_conditions = {a: pybamm.Scalar(0)} + model.variables = {"a": a, "Discharge capacity [A.h]": a * 2} + + t_eval = np.linspace(0, 2, 100) + solution = pybamm.CasadiSolver().solve(model, t_eval) + + quick_plot_capacity = pybamm.QuickPlot( + solution, + ["a"], + x_axis="Discharge capacity [A.h]", + ) + quick_plot_capacity.plot(0) + + np.testing.assert_allclose( + quick_plot_capacity.plots[("a",)][0][0].get_xdata(), + solution["Discharge capacity [A.h]"].data, + ) + def test_plot_with_different_models(self): model = pybamm.BaseModel() a = pybamm.Variable("a") From ceb0c6a419e920490e7810939f10db330531a9d8 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Wed, 12 Feb 2025 22:03:57 +0530 Subject: [PATCH 08/14] added tests --- src/pybamm/plotting/quick_plot.py | 5 ++++- tests/unit/test_plotting/test_quick_plot.py | 22 +++++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index b1c808f42d..20302050f7 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -239,7 +239,10 @@ def t_sample(sol): self.x_scaling_factor = self.time_scaling_factor self.x_unit = self.time_unit - + else: + msg = f"Invalid value for `x_axis`." + raise ValueError(msg) + # Prepare dictionary of variables # output_variables is a list of strings or lists, e.g. # ["var 1", ["variable 2", "var 3"]] diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index c343c42932..a1d0dc635e 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -281,8 +281,17 @@ def test_simple_ode_model(self, solver): pybamm.close_plots() + def test_invalid_x_axis(self): + model = pybamm.lithium_ion.SPM() + sim = pybamm.Simulation(model) + solution = sim.solve([0, 3600]) + + with pytest.raises(ValueError, match="Invalid value for `x_axis`."): + pybamm.QuickPlot([solution], x_axis="Invalid axis") + + def test_plot_with_discharge_capacity(self): - """Test that the x-axis is correctly set to Discharge capacity [A.h]""" + model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model") a = pybamm.Variable("a", domain=[]) model.rhs = {a: pybamm.Scalar(0.2)} @@ -292,18 +301,23 @@ def test_plot_with_discharge_capacity(self): t_eval = np.linspace(0, 2, 100) solution = pybamm.CasadiSolver().solve(model, t_eval) - quick_plot_capacity = pybamm.QuickPlot( + quick_plot = pybamm.QuickPlot( solution, ["a"], x_axis="Discharge capacity [A.h]", ) - quick_plot_capacity.plot(0) + quick_plot.plot(0) + # Test discharge capacity values np.testing.assert_allclose( - quick_plot_capacity.plots[("a",)][0][0].get_xdata(), + quick_plot.plots[("a",)][0][0].get_xdata(), solution["Discharge capacity [A.h]"].data, ) + # Test x-axis label + x_label = quick_plot.fig.axes[0].get_xlabel() + assert x_label == "Discharge capacity [A.h]", f"Expected 'Discharge capacity [A.h]', got '{x_label}'" + def test_plot_with_different_models(self): model = pybamm.BaseModel() a = pybamm.Variable("a") From a1eca74efff80b02fb32e5d97c0fabdca3a1480c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 16:35:58 +0000 Subject: [PATCH 09/14] style: pre-commit fixes --- src/pybamm/plotting/quick_plot.py | 4 ++-- tests/unit/test_plotting/test_quick_plot.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 20302050f7..192d507de5 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -240,9 +240,9 @@ def t_sample(sol): self.x_unit = self.time_unit else: - msg = f"Invalid value for `x_axis`." + msg = "Invalid value for `x_axis`." raise ValueError(msg) - + # Prepare dictionary of variables # output_variables is a list of strings or lists, e.g. # ["var 1", ["variable 2", "var 3"]] diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index a1d0dc635e..ffdb938989 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -285,13 +285,11 @@ def test_invalid_x_axis(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model) solution = sim.solve([0, 3600]) - + with pytest.raises(ValueError, match="Invalid value for `x_axis`."): pybamm.QuickPlot([solution], x_axis="Invalid axis") - def test_plot_with_discharge_capacity(self): - model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model") a = pybamm.Variable("a", domain=[]) model.rhs = {a: pybamm.Scalar(0.2)} @@ -314,9 +312,11 @@ def test_plot_with_discharge_capacity(self): solution["Discharge capacity [A.h]"].data, ) - # Test x-axis label + # Test x-axis label x_label = quick_plot.fig.axes[0].get_xlabel() - assert x_label == "Discharge capacity [A.h]", f"Expected 'Discharge capacity [A.h]', got '{x_label}'" + assert x_label == "Discharge capacity [A.h]", ( + f"Expected 'Discharge capacity [A.h]', got '{x_label}'" + ) def test_plot_with_different_models(self): model = pybamm.BaseModel() From 95930ef09a8e548a4858fa695fed4cd8e8af05b1 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Thu, 13 Feb 2025 12:27:37 +0530 Subject: [PATCH 10/14] simplified variables --- src/pybamm/plotting/quick_plot.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index c524567ec4..9b30cbe679 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -225,19 +225,13 @@ def t_sample(sol): ] self.x_values = discharge_capacities - self.x_axis_min = min(dc[0] for dc in discharge_capacities) - self.x_axis_max = max(dc[-1] for dc in discharge_capacities) self.x_scaling_factor = 1 - self.x_unit = "A.h" + self.x_label = "Discharge capacity [A.h]" elif x_axis == "Time": self.x_values = ts_seconds - - self.x_axis_min = self.min_t - self.x_axis_max = self.max_t self.x_scaling_factor = self.time_scaling_factor - self.x_unit = self.time_unit else: msg = "Invalid value for `x_axis`." raise ValueError(msg) @@ -441,8 +435,8 @@ def reset_axis(self): self.axis_limits = {} for key, variable_lists in self.variables.items(): if variable_lists[0][0].dimensions == 0: - x_min = self.x_axis_min - x_max = self.x_axis_max + x_min = self.min_t + x_max = self.max_t elif variable_lists[0][0].dimensions == 1: x_min = self.first_spatial_variable[key][0] x_max = self.first_spatial_variable[key][-1] @@ -551,7 +545,7 @@ def plot(self, t, dynamic=False): if self.x_axis == "Time": ax.set_xlabel(f"Time [{self.time_unit}]") if self.x_axis == "Discharge capacity [A.h]": - ax.set_xlabel("Discharge capacity [A.h]") + ax.set_xlabel(f"{self.x_label}") for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): @@ -714,14 +708,11 @@ def dynamic_plot(self, show_plot=True, step=None): if pybamm.is_notebook(): # pragma: no cover import ipywidgets as widgets - step = step or (self.x_axis_max - self.x_axis_min) / 100 + step = step or self.max_t / 100 widgets.interact( lambda t: self.plot(t, dynamic=False), t=widgets.FloatSlider( - min=self.x_axis_min, - max=self.x_axis_max, - step=step, - value=self.x_axis_min, + min=self.min_t, max=self.max_t, step=step, value=self.min_t ), continuous_update=False, ) @@ -730,7 +721,7 @@ def dynamic_plot(self, show_plot=True, step=None): Slider = import_optional_dependency("matplotlib.widgets", "Slider") # Set initial x-axis values and slider - self.plot(self.x_axis_min, dynamic=True) + self.plot(self.min_t, dynamic=True) # Set x-axis label correctly if self.x_axis == "Time": @@ -740,7 +731,7 @@ def dynamic_plot(self, show_plot=True, step=None): else: ax_label = self.x_axis # Use the string directly if unknown - ax_min, ax_max, val_init = self.x_axis_min, self.x_axis_max, self.x_axis_min + ax_min, ax_max, val_init = self.min_t, self.max_t, self.min_t axcolor = "lightgoldenrodyellow" ax_slider = plt.axes([0.315, 0.02, 0.37, 0.03], facecolor=axcolor) From e29b0198de7a08a08999a5a64658b7b7f286d2a9 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Thu, 13 Feb 2025 12:31:21 +0530 Subject: [PATCH 11/14] added changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 583f63f129..3539c11675 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- Added functionality to set `Discharge capacity` as the x-axis in `QuickPlot`.([#4775](https://github.com/pybamm-team/PyBaMM/pull/4775)) - Creates a 'calc_esoh' property in battery models ([#4825](https://github.com/pybamm-team/PyBaMM/pull/4825)) - Added 'get_summary_variables' to return dictionary of computed summary variables ([#4824](https://github.com/pybamm-team/PyBaMM/pull/4824)) - Added support for particle size distributions combined with particle mechanics. ([#4807](https://github.com/pybamm-team/PyBaMM/pull/4807)) From 7f40403a451e3a2e8a1f1693d6f8b251a362428b Mon Sep 17 00:00:00 2001 From: medha-14 Date: Thu, 13 Feb 2025 13:04:57 +0530 Subject: [PATCH 12/14] added docstring --- src/pybamm/plotting/quick_plot.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 9b30cbe679..5826fedffd 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -91,6 +91,10 @@ class QuickPlot: - "tight": make axes tight to plot at each time - dictionary: fine-grain control for each variable, can be either "fixed" or \ "tight" or a specific tuple (lower, upper). + x_axis : str, optional + The variable to use for the x-axis. Options are: + - "Time" (default): Use time as the x-axis. + - "Discharge capacity [A.h]": Use discharge capacity as the x-axis. """ From 48512323023be2b8e4537536e39b2f7008c6058b Mon Sep 17 00:00:00 2001 From: medha-14 Date: Wed, 19 Feb 2025 15:09:13 +0530 Subject: [PATCH 13/14] added tests --- src/pybamm/plotting/quick_plot.py | 2 -- tests/unit/test_plotting/test_quick_plot.py | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 7da0aecca9..926b91e688 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -730,8 +730,6 @@ def dynamic_plot(self, show_plot=True, step=None): ax_label = f"Time [{self.time_unit}]" elif self.x_axis == "Discharge capacity [A.h]": ax_label = "Discharge capacity [A.h]" - else: - ax_label = self.x_axis # Use the string directly if unknown ax_min, ax_max, val_init = self.min_t, self.max_t, self.min_t diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index ffdb938989..26e8a2f124 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -318,6 +318,10 @@ def test_plot_with_discharge_capacity(self): f"Expected 'Discharge capacity [A.h]', got '{x_label}'" ) + # check dynamic plot loads + quick_plot.dynamic_plot(show_plot=False) + quick_plot.slider_update(0.01) + def test_plot_with_different_models(self): model = pybamm.BaseModel() a = pybamm.Variable("a") From a43388c0fff7881a5f1d767a5669712b279da952 Mon Sep 17 00:00:00 2001 From: medha-14 Date: Sat, 1 Mar 2025 17:38:40 +0530 Subject: [PATCH 14/14] minor fix --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb9b0edd87..234c310528 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) -## Feature +## Features - Added functionality to set `Discharge capacity` as the x-axis in `QuickPlot`.([#4775](https://github.com/pybamm-team/PyBaMM/pull/4775)) - Revision of the hysteresis notebook to include the method implemented in the module `axen_ocp`. ([#4880](https://github.com/pybamm-team/PyBaMM/pull/4880))