From 0e13e48ec131b686c204ec2f40de214a8472d007 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:15:50 +0000 Subject: [PATCH 01/49] Increase title spacing and restore axes titles --- .pre-commit-config.yaml | 4 ++-- ice_station_zebra/visualisations/layout.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c8f6c3a..5a969794 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: # Run the linter - id: ruff-check args: ["--fix", "--show-fixes"] - pass_filenames: false + pass_filenames: true # Run the formatter - id: ruff-format - pass_filenames: false + pass_filenames: true diff --git a/ice_station_zebra/visualisations/layout.py b/ice_station_zebra/visualisations/layout.py index f2bacdd8..4d584707 100644 --- a/ice_station_zebra/visualisations/layout.py +++ b/ice_station_zebra/visualisations/layout.py @@ -53,9 +53,7 @@ DEFAULT_CBAR_WIDTH = ( 0.06 # Width allocated for colourbar slots (fraction of panel width) ) -DEFAULT_TITLE_SPACE = ( - 0.08 # Vertical space reserved for figure title (prevents overlap) -) +DEFAULT_TITLE_SPACE = 0.11 # Allow for the figure title, warning badge, and axes titles # Horizontal colourbar sizing (fractions of figure height) DEFAULT_CBAR_HEIGHT = ( @@ -458,7 +456,16 @@ def _set_titles(axs: list[Axes], plot_spec: PlotSpec) -> None: titles = [plot_spec.title_groundtruth, plot_spec.title_prediction, title_difference] for ax, title in zip(axs, titles, strict=False): if title is not None: - ax.set_title(title) + ax.set_title( + title, + fontfamily="monospace", + bbox={ + "facecolor": "white", + "edgecolor": "none", + "pad": 2.0, + "alpha": 1.0, + }, + ) def _style_axes(axs: Sequence[Axes]) -> None: From 0d7efe85f4df147fc3698b15a0bb4fe2c48c0e27 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:26:27 +0000 Subject: [PATCH 02/49] Extract meta info about training and data for the title --- ice_station_zebra/callbacks/metadata.py | 502 ++++++++++++++++++++++++ 1 file changed, 502 insertions(+) create mode 100644 ice_station_zebra/callbacks/metadata.py diff --git a/ice_station_zebra/callbacks/metadata.py b/ice_station_zebra/callbacks/metadata.py new file mode 100644 index 00000000..d539a134 --- /dev/null +++ b/ice_station_zebra/callbacks/metadata.py @@ -0,0 +1,502 @@ +"""Metadata extraction and formatting for plot titles. + +This module provides functions to extract training metadata from Hydra configs +and format them for display in plot titles. +""" + +import contextlib +import logging +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from datetime import datetime +from typing import Any + +from ice_station_zebra.data_loaders import CombinedDataset + +logger = logging.getLogger(__name__) + + +@dataclass +class Metadata: + """Structured metadata extracted from training configuration. + + Attributes: + model: Model name (if available). + epochs: Maximum number of training epochs (if available). + start: Training start date string (if available). + end: Training end date string (if available). + cadence: Training data cadence string (if available). + n_points: Number of training points calculated from date range and cadence. + vars_by_source: Dictionary mapping dataset source names to lists of variable names. 
+ + """ + + model: str | None = None + epochs: int | None = None + start: str | None = None + end: str | None = None + cadence: str | None = None + n_points: int | None = None + vars_by_source: dict[str, list[str]] | None = None + + +def extract_variables_by_source(config: dict[str, Any]) -> dict[str, list[str]]: # noqa: C901 + """Extract variable names grouped by dataset source (group_as). + + Args: + config: Configuration dictionary containing dataset definitions. + + Returns: + Dictionary mapping dataset group names to lists of their variable names. + Example: {"era5": ["10u", "10v", "2t"], "osisaf-south": ["sea ice"]} + + """ + vars_by_source: dict[str, list[str]] = {} + + def _extract_weather_params(ds: dict[str, Any]) -> list[str]: + """Collect 'param' fields from nested input/join/mars blocks.""" + input_cfg = ds.get("input") + if not isinstance(input_cfg, dict): + return [] + + join_items = input_cfg.get("join", []) + if not isinstance(join_items, list): + return [] + + params: list[str] = [] + for item in join_items: + if not isinstance(item, dict): + continue + # Accept both MARS and FORCINGS sources + for source_type in ("mars", "forcings"): + cfg = item.get(source_type) + if isinstance(cfg, dict): + param_list = cfg.get("param", []) + if isinstance(param_list, list): + params.extend(str(p) for p in param_list if p) + + return sorted(set(params)) + + try: + datasets = config.get("datasets", {}) + if not isinstance(datasets, dict): + return vars_by_source + + for ds in datasets.values(): + if not isinstance(ds, dict): + continue + + ds_name = str(ds.get("name", "")).lower() + group_name = ds.get("group_as") + if not isinstance(group_name, str): + continue + + # --- Infer variables based on dataset type --- + variables: list[str] = [] + if "sicnorth" in ds_name or "sicsouth" in ds_name: + variables = ["sea ice"] + elif "weather" in ds_name: + variables = _extract_weather_params(ds) + + # --- Add variables if any --- + if not variables: + continue + + group_vars = vars_by_source.setdefault(group_name, []) + for v in variables: + if v not in group_vars: + group_vars.append(v) + + except (AttributeError, TypeError, ValueError) as exc: + logger.debug("Failed to extract variables from config: %s", exc, exc_info=True) + + # Sort variables alphabetically per source + return {src: sorted(vs) for src, vs in vars_by_source.items()} + + +def calculate_training_points( + start_str: str | None, end_str: str | None, cadence_str: str | None +) -> int | None: + """Calculate number of training points from date range and cadence. + + Calculates the number of time points in an inclusive date range given a cadence. + For example, Jan 1 to Jan 10 with 1d cadence = 10 points (inclusive endpoints). + + Args: + start_str: Start date string (ISO format, with or without time). + Time components are stripped before calculation. + end_str: End date string (ISO format, with or without time). + Time components are stripped before calculation. + cadence_str: Cadence string (e.g., "1d", "3h", "daily", "24h"). + + Returns: + Number of points (at least 1) or None if calculation fails. + Returns None if any input is None/empty or if cadence format is unrecognized. 
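+
+    Example:
+        Illustrative values (the range is inclusive of both endpoints; these
+        mirror cases exercised in the accompanying tests):
+
+        >>> calculate_training_points("2020-01-01", "2020-01-10", "1d")
+        10
+        >>> calculate_training_points("2020-01-01", "2020-01-01", "3h")
+        8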
+ + """ + if not start_str or not end_str or not cadence_str: + return None + + try: + delta_days = _inclusive_days(start_str, end_str) + computed_points = _points_from_cadence(delta_days, cadence_str) + except (ValueError, TypeError) as exc: + logger.debug( + "Failed to calculate training points from dates/cadence: %s", + exc, + exc_info=True, + ) + return None + else: + return computed_points + + +def format_cadence_display(cadence_str: str | None) -> str | None: + """Format cadence string for display (converts 1d/1h to daily/hourly). + + Args: + cadence_str: Raw cadence string from config. + + Returns: + Formatted cadence string (daily/hourly or original if not 1d/1h). + + """ + if not cadence_str: + return None + norm = cadence_str.strip().lower() + if norm in {"1d", "1day", "1 day"}: + return "daily" + if norm in {"1h", "1hr", "1 hour"}: + return "hourly" + return cadence_str + + +def extract_cadence_from_config(config: dict[str, Any]) -> str | None: + """Extract cadence (frequency) from dataset config for the prediction target group. + + Args: + config: Configuration dictionary. + + Returns: + Cadence string (e.g., "1d", "3h") or None if not found. + + """ + try: + predict_group = config.get("predict", {}).get("dataset_group") + datasets_cfg = config.get("datasets", {}) + if isinstance(predict_group, str) and isinstance(datasets_cfg, dict): + for ds in datasets_cfg.values(): + if not isinstance(ds, dict): + continue + if ds.get("group_as") == predict_group: + dates_section = ds.get("dates") + if isinstance(dates_section, dict): + freq_candidate = dates_section.get("frequency") + if isinstance(freq_candidate, str) and freq_candidate: + return freq_candidate + except (AttributeError, TypeError) as exc: + logger.debug("Failed to extract cadence from config: %s", exc, exc_info=True) + return None + + +def extract_training_date_range( + config: dict[str, Any], +) -> tuple[str | None, str | None]: + """Extract training date range from split config. + + Args: + config: Configuration dictionary. + + Returns: + Tuple of (start_date_str, end_date_str) or (None, None) if not found. + + """ + start_str: str | None = None + end_str: str | None = None + try: + split_cfg = config.get("split", {}) + train_ranges = split_cfg.get("train") if isinstance(split_cfg, dict) else None + if isinstance(train_ranges, list) and train_ranges: + starts = [r.get("start") for r in train_ranges if isinstance(r, dict)] + ends = [r.get("end") for r in train_ranges if isinstance(r, dict)] + non_null_starts = [s for s in starts if isinstance(s, str) and s] + non_null_ends = [e for e in ends if isinstance(e, str) and e] + start_str = min(non_null_starts) if non_null_starts else None + end_str = max(non_null_ends) if non_null_ends else None + except (AttributeError, TypeError) as exc: + logger.debug( + "Failed to extract training date range from config: %s", exc, exc_info=True + ) + return start_str, end_str + + +def build_metadata( + config: dict[str, Any], + model_name: str | None = None, +) -> Metadata: + """Build structured metadata from configuration. + + Extracts training metadata from config and returns a Metadata dataclass. + All fields are optional and will be None if the corresponding information + is not available in the config. + + Args: + config: Configuration dictionary containing training and dataset info. + model_name: Optional model name (if not provided, will not be included). + + Returns: + Metadata dataclass instance with extracted information. 
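+
+    Example:
+        A minimal sketch of the config shape this function reads; any missing
+        key simply leaves the corresponding field as None:
+
+        >>> cfg = {
+        ...     "train": {"trainer": {"max_epochs": 10}},
+        ...     "split": {"train": [{"start": "2000-01-01", "end": "2010-12-31"}]},
+        ... }
+        >>> build_metadata(cfg, model_name="test_model").epochs
+        10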
+ + """ + # Extract training date range + start_str, end_str = extract_training_date_range(config) + + # Extract cadence + cadence_str = extract_cadence_from_config(config) + training_points = calculate_training_points(start_str, end_str, cadence_str) + + # Get epochs + trainer_cfg = config.get("train", {}).get("trainer", {}) + max_epochs = ( + trainer_cfg.get("max_epochs") if isinstance(trainer_cfg, dict) else None + ) + + # Get variables grouped by source + vars_by_source = extract_variables_by_source(config) + + return Metadata( + model=model_name if isinstance(model_name, str) and model_name else None, + epochs=max_epochs if isinstance(max_epochs, int) else None, + start=start_str, + end=end_str, + cadence=cadence_str, + n_points=training_points, + vars_by_source=vars_by_source if vars_by_source else None, + ) + + +def format_metadata_subtitle(metadata: Metadata) -> str | None: # noqa: C901 + """Format metadata dataclass as a compact multi-line subtitle for plot titles. + + Lines: + 1) Model: Epoch: Training Dates: () pts + 2) Training Data: () () + + Args: + metadata: Metadata dataclass instance to format. + + Returns: + Formatted metadata string with newlines, or None if no metadata available. + + """ + lines: list[str] = [] + + # Line 1: Model/Epoch/Dates + info_parts: list[str] = [] + if metadata.model: + info_parts.append(f"Model: {metadata.model}") + if metadata.epochs is not None: + info_parts.append(f"Epoch: {metadata.epochs}") + + if metadata.start or metadata.end: + s_clean = ( + metadata.start.split("T")[0] + if metadata.start and "T" in metadata.start + else (metadata.start if metadata.start else "?") + ) + e_clean = ( + metadata.end.split("T")[0] + if metadata.end and "T" in metadata.end + else (metadata.end if metadata.end else "?") + ) + cadence_display = format_cadence_display(metadata.cadence) + dates_part = f"Training Dates: {s_clean} — {e_clean}" + if cadence_display: + dates_part += f" ({cadence_display})" + if metadata.n_points is not None: + dates_part += f" {metadata.n_points} pts" + info_parts.append(dates_part) + + if info_parts: + lines.append(" ".join(info_parts)) + + # Line 2: Training data sources and variables + if metadata.vars_by_source: + source_parts = [] + for source in sorted(metadata.vars_by_source.keys()): + vars_list = metadata.vars_by_source[source] + if vars_list: + vars_str = ",".join(vars_list) + source_parts.append(f"{source} ({vars_str})") + else: + source_parts.append(source) + if source_parts: + lines.append(f"Training Data: {' '.join(source_parts)}") + + return "\n".join(lines) if lines else None + + +def build_metadata_subtitle( + config: dict[str, Any], + model_name: str | None = None, +) -> str | None: + """Build metadata subtitle for plot titles. + + Convenience function that combines build_metadata and format_metadata_subtitle. + Maintains backward compatibility with existing code. + + Args: + config: Configuration dictionary containing training and dataset info. + model_name: Optional model name (if not provided, will not be included). + + Returns: + Formatted metadata string with newlines, or None if no metadata available. + + """ + metadata = build_metadata(config, model_name=model_name) + return format_metadata_subtitle(metadata) + + +def infer_hemisphere(dataset: CombinedDataset) -> str | None: # noqa: C901, PLR0912 + """Infer hemisphere from dataset name or config as a fallback. 
+ + Priority: + 1) CombinedDataset.target.name containing "north"/"south" + 2) Any input dataset name containing "north"/"south" + 3) Dataset-level name or config strings containing the keywords + + Args: + dataset: CombinedDataset instance to infer hemisphere from. + + Returns: + "north" or "south" (lowercase) when detected, otherwise None. + + """ + candidate_names: list[str] = [] + + # 1) Target dataset name + target = getattr(dataset, "target", None) + target_name = getattr(target, "name", None) + if isinstance(target_name, str) and target_name: + candidate_names.append(target_name) + + # 2) Top-level dataset name + ds_name = getattr(dataset, "name", None) + if isinstance(ds_name, str) and ds_name: + candidate_names.append(ds_name) + + # 3) Inputs: may be a Sequence of objects, mappings or plain strings + inputs = getattr(dataset, "inputs", None) + if isinstance(inputs, Sequence) and not isinstance(inputs, (str, bytes)): + for item in inputs: + # If the item is a mapping-like object (dict), try key access + if isinstance(item, Mapping): + name = item.get("name") or item.get("dataset_name") or None + else: + # Otherwise try attribute access, then try if item itself is a string + name = ( + getattr(item, "name", None) if not isinstance(item, str) else item + ) + + if isinstance(name, str) and name: + candidate_names.append(name) + + # 4) Generic config-like hints: look for a config attribute (mapping) and make a string of a few keys + config_like = getattr(dataset, "config", None) or getattr( + dataset, "dataset_config", None + ) + if isinstance(config_like, Mapping): + # Check a few plausible keys + for key in ("name", "dataset", "dataset_name", "target"): + val = config_like.get(key) + if isinstance(val, str) and val: + candidate_names.append(val) + # As a last resort, make the mapping (small) into a string and use as a candidate + try: + maybe_str = str(config_like) + if maybe_str: + candidate_names.append(maybe_str) + except TypeError as exc: + logger.debug( + "Failed to extract config hint for hemisphere inference: %s", + exc, + exc_info=True, + ) + + # Normalise and search for hemisphere keywords. + for cand in candidate_names: + low = cand.lower() + if "north" in low: + logger.debug("Inferred hemisphere 'north' from dataset hint: %s", cand) + return "north" + if "south" in low: + logger.debug("Inferred hemisphere 'south' from dataset hint: %s", cand) + return "south" + + return None + + +# --- Internal helpers to reduce complexity/branching --- + + +def _clean_date_str(date_str: str) -> str: + """Return date-only portion of an ISO string (strip any time part).""" + return date_str.split("T")[0] if "T" in date_str else date_str + + +def _inclusive_days(start_str: str, end_str: str) -> int: + """Return inclusive number of days between two ISO date strings.""" + start_dt = datetime.fromisoformat(_clean_date_str(start_str)) + end_dt = datetime.fromisoformat(_clean_date_str(end_str)) + return (end_dt - start_dt).days + 1 + + +def _normalise_cadence(raw: str) -> str: + """Normalise common cadence synonyms to canonical forms like '1d' or '1h'.""" + cad = raw.strip().lower() + if cad in ("daily", "day"): + return "1d" + if cad in ("hourly", "hour"): + return "1h" + return cad + + +def _points_from_cadence(delta_days: int, cadence: str) -> int | None: + """Compute point count from inclusive day span and normalized cadence. + + Supports day- and hour-based cadences (e.g., '1d', '2day', '3h', '12hr', '24hour'). + Returns None for unrecognized formats or non-positive periods. 
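+
+    Example:
+        >>> _points_from_cadence(10, "1d")
+        10
+        >>> _points_from_cadence(2, "12h")
+        4
+        >>> _points_from_cadence(10, "xyz") is None
+        True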
+ """ + cad = _normalise_cadence(cadence) + + # Day cadence + if cad.endswith(("d", "day")): + cleaned = cad[:-3] if cad.endswith("day") else cad[:-1] + cleaned = cleaned.strip() + num_days = 1 + if cleaned: + with contextlib.suppress(ValueError): + num_days = int(cleaned) + if num_days <= 0: + return None + return max(1, delta_days // num_days) + + # Hour cadence + if cad.endswith(("hour", "hr", "h")): + if cad.endswith("hour"): + cleaned = cad[:-4] + elif cad.endswith("hr"): + cleaned = cad[:-2] + else: + cleaned = cad[:-1] + cleaned = cleaned.strip() + num_hours = 1 + if cleaned: + with contextlib.suppress(ValueError): + num_hours = int(cleaned) + if num_hours <= 0: + return None + delta_hours = delta_days * 24 + return max(1, delta_hours // num_hours) + + return None From f5a195a04d6ca89b267265983db19b8d4da78a79 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:27:11 +0000 Subject: [PATCH 03/49] Refactor to use metadata in plots --- .../callbacks/plotting_callback.py | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/ice_station_zebra/callbacks/plotting_callback.py b/ice_station_zebra/callbacks/plotting_callback.py index e76a4538..0212b899 100644 --- a/ice_station_zebra/callbacks/plotting_callback.py +++ b/ice_station_zebra/callbacks/plotting_callback.py @@ -2,7 +2,7 @@ import logging from collections.abc import Mapping, Sequence from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np from lightning import LightningModule, Trainer @@ -10,6 +10,10 @@ from lightning.pytorch.loggers import Logger as LightningLogger from torch import Tensor +from ice_station_zebra.callbacks.metadata import ( + build_metadata_subtitle, + infer_hemisphere, +) from ice_station_zebra.data_loaders import CombinedDataset from ice_station_zebra.exceptions import InvalidArrayError, VideoRenderError from ice_station_zebra.types import ModelTestOutput, TensorDimensions @@ -71,7 +75,6 @@ def __init__( # noqa: PLR0913 def _detect_land_mask_path( self, dataset: CombinedDataset, - trainer: Trainer, # noqa: ARG002 ) -> None: """Detect and set the land mask path based on the dataset configuration.""" if self._land_mask_path_detected: @@ -88,12 +91,28 @@ def _detect_land_mask_path( # Detect land mask path land_mask_path = detect_land_mask_path(base_path, dataset_name) - if land_mask_path: - # Update the plot_spec with the detected land mask path - # Use dataclasses.replace to create a new PlotSpec instance with updated land_mask_path + # Always try to infer hemisphere regardless of land mask presence + hemisphere: Literal["north", "south"] | None = None + if isinstance(dataset_name, str): + low = dataset_name.lower() + if "south" in low: + hemisphere = cast("Literal['north', 'south']", "south") + elif "north" in low: + hemisphere = cast("Literal['north', 'south']", "north") + if hemisphere is None: + hemi_candidate = infer_hemisphere(dataset) + if hemi_candidate in ("north", "south"): + hemisphere = cast("Literal['north', 'south']", hemi_candidate) + + # Update plot_spec pieces independently + if hemisphere is not None and self.plot_spec.hemisphere != hemisphere: + self.plot_spec = dataclasses.replace(self.plot_spec, hemisphere=hemisphere) + if land_mask_path: + # Set land mask path when found self.plot_spec = dataclasses.replace( - self.plot_spec, land_mask_path=land_mask_path + self.plot_spec, + land_mask_path=land_mask_path, ) logger.info("Auto-detected land mask: 
%s", land_mask_path) else: @@ -142,7 +161,21 @@ def on_test_batch_end( ] # Detect land mask path if not already done - self._detect_land_mask_path(dataset, trainer) + self._detect_land_mask_path(dataset) + + # Build readable metadata subtitle from config + try: + model_name = getattr(_module, "name", None) + combined_meta = build_metadata_subtitle(self.config, model_name=model_name) + if combined_meta: + self.plot_spec = dataclasses.replace( + self.plot_spec, metadata_subtitle=combined_meta + ) + except Exception: + # Don't fail plotting just because metadata gathering failed. + logger.exception( + "Failed to build metadata subtitle; continuing without it." + ) # Log static and video plots self.log_static_plots(outputs, dates, trainer.loggers) From 4800b8c48b9cbb1ba2602083e8fe3744895d9684 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:27:35 +0000 Subject: [PATCH 04/49] Include hemisphere type --- ice_station_zebra/types/simple_datatypes.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ice_station_zebra/types/simple_datatypes.py b/ice_station_zebra/types/simple_datatypes.py index 450e68e0..e5710d9a 100644 --- a/ice_station_zebra/types/simple_datatypes.py +++ b/ice_station_zebra/types/simple_datatypes.py @@ -113,3 +113,9 @@ class PlotSpec: # Land mask overlay land_mask_path: str | None = None + + # Optional metadata for titling + # hemisphere: "north" | "south" when known (used in titles) + hemisphere: Literal["north", "south"] | None = None + # metadata_subtitle: free-form text (e.g., "epochs=50; train=2010-2018") + metadata_subtitle: str | None = None From e413ba53d998c7d30c954ae2c5080405d3ef3a8d Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:29:58 +0000 Subject: [PATCH 05/49] Refactor for complex plot title with metadata --- .../visualisations/plotting_maps.py | 175 +++++++++++++----- 1 file changed, 131 insertions(+), 44 deletions(-) diff --git a/ice_station_zebra/visualisations/plotting_maps.py b/ice_station_zebra/visualisations/plotting_maps.py index e53ad25a..558c1bae 100644 --- a/ice_station_zebra/visualisations/plotting_maps.py +++ b/ice_station_zebra/visualisations/plotting_maps.py @@ -10,6 +10,7 @@ import contextlib import io +import logging from collections.abc import Sequence from datetime import date, datetime from typing import Literal @@ -26,7 +27,7 @@ from ice_station_zebra.types import DiffColourmapSpec, PlotSpec from . 
import convert -from .layout import _add_colourbars, _build_layout, _set_axes_limits +from .layout import _add_colourbars, _build_layout, _set_axes_limits, _set_titles from .plotting_core import ( compute_difference, compute_display_ranges, @@ -40,6 +41,9 @@ ) from .range_check import compute_range_check_report +logger = logging.getLogger(__name__) + + # Keep strong references to animation objects during save to avoid GC-related warnings _ANIM_CACHE: list[animation.FuncAnimation] = [] @@ -118,6 +122,9 @@ def plot_maps( land_mask=land_mask, ) + # Restore axis titles after drawing (they were cleared in _draw_frame) + _set_titles(axs, plot_spec) + # Colourbars and title _add_colourbars( axs, @@ -129,7 +136,11 @@ def plot_maps( ) _set_axes_limits(axs, width=width, height=height) - title_text = _set_suptitle_with_box(fig, _format_date_to_string(date)) + try: + title_text = _set_suptitle_with_box(fig, _build_title_static(plot_spec, date)) + except Exception: + logger.exception("Failed to draw suptitle; continuing without title.") + title_text = None # Include range_check report (groundtruth_min, groundtruth_max), (prediction_min, prediction_max) = ( @@ -155,9 +166,10 @@ def plot_maps( # Place the warning just below the title if title_text is not None: _, title_y = title_text.get_position() - warning_y = max(title_y - 0.03, 0.0) + n_lines = title_text.get_text().count("\n") + 1 + warning_y = max(title_y - (0.08 + 0.03 * (n_lines - 1)), 0.0) else: - warning_y = 0.93 + warning_y = 0.90 _draw_badge_with_box(fig, 0.5, warning_y, badge) try: @@ -269,6 +281,8 @@ def video_maps( display_ranges_override=display_ranges, land_mask=land_mask, ) + # Restore axis titles after drawing (they were cleared in _draw_frame) + _set_titles(axs, plot_spec) # Colourbars and title _add_colourbars( axs, @@ -280,7 +294,13 @@ def video_maps( cbar_axes=cbar_axes, ) _set_axes_limits(axs, width=width, height=height) - title_text = _set_suptitle_with_box(fig, _format_date_to_string(dates[0])) + try: + title_text = _set_suptitle_with_box( + fig, _build_title_video(plot_spec, dates, 0) + ) + except Exception: + logger.exception("Failed to draw suptitle; continuing without title.") + title_text = None # Animation function def animate(tt: int) -> tuple[()]: @@ -300,8 +320,11 @@ def animate(tt: int) -> tuple[()]: display_ranges_override=display_ranges, land_mask=land_mask, ) + # Restore axis titles after drawing (they were cleared in _draw_frame) + _set_titles(axs, plot_spec) - title_text.set_text(_format_date_to_string(dates[tt])) + if title_text is not None: + title_text.set_text(_build_title_video(plot_spec, dates, tt)) return () # Create the animation object @@ -617,31 +640,6 @@ def _overlay_nans(ax: Axes, arr: np.ndarray) -> None: return image_groundtruth, image_prediction, image_difference, diff_colour_scale -def _format_date_to_string(date: date | datetime) -> str: - """Format a date or datetime object to a standardised string representation for plot titles. - - Args: - date: Date or datetime object to format. - - Returns: - Formatted string in "YYYY-MM-DD HH:MM" format for datetime objects, - or "YYYY-MM-DD" format for date objects. 
- - Example: - >>> from datetime import date, datetime - >>> _format_date_to_string(date(2023, 12, 25)) - '2023-12-25' - >>> _format_date_to_string(datetime(2023, 12, 25, 14, 30)) - '2023-12-25 14:30' - - """ - return ( - date.strftime(r"%Y-%m-%d %H:%M") - if isinstance(date, datetime) - else date.strftime(r"%Y-%m-%d") - ) - - def _clear_plot(ax: Axes) -> None: """Remove titles, labels, and contour collections from an axes to prevent overlaps. @@ -660,30 +658,119 @@ def _set_suptitle_with_box(fig: plt.Figure, text: str) -> Text: """Draw a fixed-position title with a white box that doesn't influence layout. Returns the Text artist so callers can update with set_text during animation. + This version avoids kwargs that are unsupported on older Matplotlib. """ - return fig.text( - 0.5, - 0.98, - text, + bbox = {"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0} + t = fig.text( + x=0.5, + y=0.98, + s=text, ha="center", va="top", - fontsize=plt.rcParams.get("axes.titlesize", 12), + fontsize=12, fontfamily="monospace", transform=fig.transFigure, - in_layout=False, - bbox={"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0}, + bbox=bbox, ) + with contextlib.suppress(Exception): + t.set_zorder(1000) + return t def _draw_badge_with_box(fig: plt.Figure, x: float, y: float, text: str) -> Text: """Draw a warning/info badge with white background box at figure coords.""" - return fig.text( - x, - y, - text, - fontsize=9, + bbox = {"facecolor": "white", "edgecolor": "none", "pad": 1.5, "alpha": 1.0} + t = fig.text( + x=x, + y=y, + s=text, + fontsize=11, + fontfamily="monospace", color="firebrick", ha="center", va="top", - bbox={"facecolor": "white", "edgecolor": "none", "pad": 1.5, "alpha": 1.0}, + bbox=bbox, ) + with contextlib.suppress(Exception): + t.set_zorder(1000) + return t + + +# --- Title helpers --- +def _formatted_variable_name(variable: str) -> str: + """Return a human-friendly variable name for titles. + + Example: "sea_ice_concentration" -> "Sea ice concentration". + """ + pretty = variable.replace("_", " ").strip() + return pretty.title() if pretty else "" + + +def _format_date_for_title(dt: date | datetime) -> str: + """Format a date/datetime to ISO date string (YYYY-MM-DD) for plot titles. + + Args: + dt: Date or datetime object to format. + + Returns: + ISO format date string (YYYY-MM-DD). Time components are stripped + from datetime objects. + + Example: + >>> from datetime import date, datetime + >>> _format_date_for_title(date(2023, 12, 25)) + '2023-12-25' + >>> _format_date_for_title(datetime(2023, 12, 25, 14, 30)) + '2023-12-25' + + """ + if isinstance(dt, datetime): + return dt.date().isoformat() + return dt.isoformat() + + +def _build_title_static(plot_spec: PlotSpec, when: date | datetime) -> str: + """Compose a readable suptitle for static plots. + + Lines: + 1) " () Shown: YYYY-MM-DD" + 2) "Model: Epoch: Training Dates: () pts" (optional) + 3) "Training Data: () ()" (optional) + """ + lines: list[str] = [] + metric = _formatted_variable_name(plot_spec.variable) + hemi = f" ({plot_spec.hemisphere.capitalize()})" if plot_spec.hemisphere else "" + lines.append(f"{metric}{hemi} Prediction Shown: {_format_date_for_title(when)}") + if plot_spec.metadata_subtitle: + lines.append(plot_spec.metadata_subtitle) + return "\n".join(lines) + + +def _build_title_video( + plot_spec: PlotSpec, + dates: Sequence[date | datetime], + current_index: int, +) -> str: + """Compose a readable multi-line suptitle for video plots. 
+ + Lines: + 1) " () Frame: YYYY-MM-DD" + 2) "Animating from to " (shown when a date range is available) + 3) "Model: Epoch: Training Dates: () pts" (optional) + 4) "Training Data: () ()" (optional) + """ + lines: list[str] = [] + metric = _formatted_variable_name(plot_spec.variable) + hemi = f" ({plot_spec.hemisphere.capitalize()})" if plot_spec.hemisphere else "" + if dates: + lines.append( + f"{metric}{hemi} Prediction Frame: {_format_date_for_title(dates[current_index])}" + ) + start_s = _format_date_for_title(dates[0]) + end_s = _format_date_for_title(dates[-1]) + lines.append(f"Animating from {start_s} to {end_s}") + else: + lines.append(f"{metric}{hemi}") + if plot_spec.metadata_subtitle: + lines.append(plot_spec.metadata_subtitle) + return "\n".join(lines) From d1f33b12ec1d148c8ff163662fa82bea640c4c8e Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:33:22 +0000 Subject: [PATCH 06/49] Tests for metadata titles --- tests/plotting/test_metadata.py | 419 ++++++++++++++++++++++++++++++++ 1 file changed, 419 insertions(+) create mode 100644 tests/plotting/test_metadata.py diff --git a/tests/plotting/test_metadata.py b/tests/plotting/test_metadata.py new file mode 100644 index 00000000..769264ab --- /dev/null +++ b/tests/plotting/test_metadata.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # TC003: only used for typing + from collections.abc import Mapping, Sequence + +import pytest + +from ice_station_zebra.callbacks.metadata import ( + Metadata, + build_metadata, + build_metadata_subtitle, + calculate_training_points, + extract_variables_by_source, + format_cadence_display, + format_metadata_subtitle, + infer_hemisphere, +) +from ice_station_zebra.callbacks.plotting_callback import PlottingCallback + + +def test_build_metadata_subtitle_full_config_contains_epochs_and_range() -> None: + """build_metadata_subtitle should include epochs and train range when available.""" + config = { + "train": {"trainer": {"max_epochs": 10}}, + "split": { + "train": [ + {"start": "2000-01-01", "end": "2010-12-31"}, + {"start": "2011-01-01", "end": "2020-12-31"}, + ] + }, + } + + subtitle = build_metadata_subtitle(config) + + # Be tolerant of formatting (spaces, separators) — just assert presence of the key pieces. 
+ assert subtitle is not None, ( + "Expected a subtitle string when config contains metadata" + ) + assert "Epoch" in subtitle + assert "10" in subtitle + assert "Training Dates:" in subtitle + assert "2000-01-01" in subtitle + assert "2020-12-31" in subtitle + + +def test_build_metadata_subtitle_missing_fields_returns_none() -> None: + """When no relevant metadata is present, build_metadata_subtitle should return None.""" + config: dict[str, Any] = {} + assert build_metadata_subtitle(config) is None + + # Config present but missing expected keys + config2: dict[str, Any] = {"some_other_section": {"foo": "bar"}} + assert build_metadata_subtitle(config2) is None + + +@pytest.mark.parametrize( + ("start", "end", "freq", "expected"), + [ + # Daily cadence tests + ("2020-01-01", "2020-01-10", "1d", 10), # 10 days inclusive + ("2020-01-01", "2020-01-10", "2d", 5), # 10 days / 2 = 5 points + ("2020-01-01", "2020-01-01", "1d", 1), # Single day + ("2020-01-01", "2020-01-02", "1d", 2), # Two days inclusive + ( + "2020-01-01", + "2020-01-05", + "3d", + 1, + ), # 5 days / 3 = 1 point (Jan 1), Jan 4 is beyond range + # Hourly cadence tests + ("2020-01-01", "2020-01-01", "1h", 24), # Single day = 24 hours + ("2020-01-01", "2020-01-01", "24h", 1), # 24h = daily + ("2020-01-01T00:00:00", "2020-01-01T23:00:00", "1h", 24), # With time component + ("2020-01-01", "2020-01-02", "12h", 4), # 2 days * 24h / 12h = 4 points + ("2020-01-01", "2020-01-01", "3h", 8), # 24h / 3h = 8 points + # Format variations + ("2020-01-01", "2020-01-10", "daily", 10), # Word format + ("2020-01-01", "2020-01-01", "hourly", 24), + ], +) +def test_calculate_training_points_parametric( + start: str, end: str, freq: str, expected: int +) -> None: + """Test calculate_training_points with various date ranges and cadences.""" + result = calculate_training_points(start, end, freq) + assert result == expected, ( + f"Expected {expected} points for {start} to {end} with {freq} cadence" + ) + + +def test_calculate_training_points_invalid_returns_none() -> None: + """Test that invalid inputs return None.""" + assert calculate_training_points(None, "2020-01-01", "1d") is None + assert calculate_training_points("2020-01-01", None, "1d") is None + assert calculate_training_points("2020-01-01", "2020-01-10", None) is None + assert calculate_training_points("", "2020-01-01", "1d") is None + # Note: "invalid" might parse as "1d" due to substring matching - this is acceptable + # The function is permissive with formats, so truly invalid formats that fail + # completely should return None + assert calculate_training_points("2020-01-01", "2020-01-10", "0d") is None + assert calculate_training_points("2020-01-01", "2020-01-10", "-1h") is None + # These should return None due to unrecognized format + assert calculate_training_points("2020-01-01", "2020-01-10", "xyz") is None + assert calculate_training_points("2020-01-01", "2020-01-10", "123") is None + + +def test_format_cadence_display() -> None: + """Test cadence display formatting.""" + assert format_cadence_display("1d") == "daily" + assert format_cadence_display("1day") == "daily" + assert format_cadence_display("1 day") == "daily" + assert format_cadence_display("1h") == "hourly" + assert format_cadence_display("1hr") == "hourly" + assert format_cadence_display("3d") == "3d" # Not 1d, so unchanged + assert format_cadence_display("24h") == "24h" # Not 1h, so unchanged + assert format_cadence_display(None) is None + + +def test_extract_variables_by_source_sic() -> None: + """Test extraction of sea 
ice variables from dataset config.""" + config = { + "datasets": { + "sic1": { + "name": "osisaf-sicsouth", + "group_as": "osisaf-south", + }, + "sic2": { + "name": "osisaf-sicnorth", + "group_as": "osisaf-north", + }, + } + } + result = extract_variables_by_source(config) + assert result == { + "osisaf-south": ["sea ice"], + "osisaf-north": ["sea ice"], + } + + +def test_extract_variables_by_source_weather() -> None: + """Test extraction of weather variables from dataset config.""" + config = { + "datasets": { + "era5_1": { + "name": "era5-weather", + "group_as": "era5", + "input": { + "join": [ + {"mars": {"param": ["2t", "sp"]}}, + {"mars": {"param": ["10u", "10v"]}}, + ] + }, + }, + } + } + result = extract_variables_by_source(config) + # Should extract and sort params + assert result["era5"] == ["10u", "10v", "2t", "sp"] + + +def test_extract_variables_by_source_weather_fallback() -> None: + """Test weather dataset with no params returns empty (no fallback).""" + config = { + "datasets": { + "era5_1": { + "name": "era5-weather", + "group_as": "era5", + "input": {}, + }, + } + } + result = extract_variables_by_source(config) + # When no params are found, the dataset is skipped (no fallback to "weather") + assert "era5" not in result or result["era5"] == [] + + +def test_extract_variables_by_source_empty_config() -> None: + """Test with empty or invalid config.""" + assert extract_variables_by_source({}) == {} + assert extract_variables_by_source({"datasets": {}}) == {} + assert extract_variables_by_source({"datasets": None}) == {} + + +class MockDataset: + """Mock dataset for hemisphere inference tests.""" + + def __init__( + self, + target_name: str | None = None, + inputs: Sequence[Any] | None = None, + name: str | None = None, + config: Mapping[str, Any] | None = None, + ) -> None: + """Initialize mock dataset with optional attributes for tests.""" + if target_name: + self.target = MockTarget(target_name) + else: + self.target = None # type: ignore[assignment] + self.inputs = inputs or [] + self.name = name + self.config = config + self.dataset_config = config + + +class MockTarget: + """Mock target dataset component.""" + + def __init__(self, name: str) -> None: + """Initialize mock target with a name.""" + self.name = name + + +def test_infer_hemisphere_from_target() -> None: + """Test hemisphere inference from target dataset name.""" + dataset = MockDataset(target_name="osisaf-sicsouth") + assert infer_hemisphere(dataset) == "south" # type: ignore[arg-type] + + dataset = MockDataset(target_name="osisaf-sicnorth") + assert infer_hemisphere(dataset) == "north" # type: ignore[arg-type] + + dataset = MockDataset(target_name="some-other-name") + assert infer_hemisphere(dataset) is None # type: ignore[arg-type] + + +def test_infer_hemisphere_from_top_level_name() -> None: + """Test hemisphere inference from top-level dataset name.""" + dataset = MockDataset(name="combined-south") + assert infer_hemisphere(dataset) == "south" # type: ignore[arg-type] + + dataset = MockDataset(name="combined-north") + assert infer_hemisphere(dataset) == "north" # type: ignore[arg-type] + + +def test_infer_hemisphere_from_inputs() -> None: + """Test hemisphere inference from input dataset names.""" + # Inputs as dict-like mappings + dataset = MockDataset( + inputs=[ + {"name": "input-south-dataset"}, + {"dataset_name": "other-dataset"}, + ] + ) + assert infer_hemisphere(dataset) == "south" # type: ignore[arg-type] + + # Inputs as objects with .name attribute + dataset = MockDataset( + inputs=[ + 
MockTarget("input-north-dataset"), + ] + ) + assert infer_hemisphere(dataset) == "north" # type: ignore[arg-type] + + +def test_infer_hemisphere_from_config() -> None: + """Test hemisphere inference from config dict.""" + dataset = MockDataset(config={"name": "config-south-name"}) + assert infer_hemisphere(dataset) == "south" # type: ignore[arg-type] + + dataset = MockDataset(config={"dataset": "config-north-dataset"}) + assert infer_hemisphere(dataset) == "north" # type: ignore[arg-type] + + +def test_infer_hemisphere_no_match() -> None: + """Test hemisphere inference returns None when no matches.""" + dataset = MockDataset() + assert infer_hemisphere(dataset) is None # type: ignore[arg-type] + + dataset = MockDataset(name="no-hemisphere-indicator") + assert infer_hemisphere(dataset) is None # type: ignore[arg-type] + + +def test_build_metadata_returns_dataclass() -> None: + """Test that build_metadata returns a Metadata dataclass with extracted fields.""" + config = { + "train": {"trainer": {"max_epochs": 10}}, + "split": { + "train": [ + {"start": "2000-01-01", "end": "2010-12-31"}, + ] + }, + "datasets": { + "sic1": { + "name": "osisaf-sicsouth", + "group_as": "osisaf-south", + }, + }, + "predict": {"dataset_group": "osisaf-south"}, + } + + metadata = build_metadata(config, model_name="test_model") + + assert isinstance(metadata, Metadata) + assert metadata.model == "test_model" + assert metadata.epochs == 10 + assert metadata.start == "2000-01-01" + assert metadata.end == "2010-12-31" + assert metadata.vars_by_source == {"osisaf-south": ["sea ice"]} + + +def test_build_metadata_empty_config() -> None: + """Test build_metadata with empty config returns Metadata with None fields.""" + metadata = build_metadata({}) + + assert isinstance(metadata, Metadata) + assert metadata.model is None + assert metadata.epochs is None + assert metadata.start is None + assert metadata.end is None + assert metadata.cadence is None + assert metadata.n_points is None + assert metadata.vars_by_source is None + + +def test_format_metadata_subtitle() -> None: + """Test format_metadata_subtitle formats Metadata dataclass correctly.""" + metadata = Metadata( + model="test_model", + epochs=5, + start="2020-01-01", + end="2020-01-10", + cadence="1d", + n_points=10, + vars_by_source={"era5": ["2t", "sp"]}, + ) + + subtitle = format_metadata_subtitle(metadata) + + assert subtitle is not None + assert "Model: test_model" in subtitle + assert "Epoch: 5" in subtitle + assert "Training Data:" in subtitle + assert "2020-01-01" in subtitle + assert "2020-01-10" in subtitle + assert "10 pts" in subtitle + + +def test_format_metadata_subtitle_minimal() -> None: + """Test format_metadata_subtitle with minimal metadata.""" + metadata = Metadata() # All None + + subtitle = format_metadata_subtitle(metadata) + + assert subtitle is None + + +def test_build_metadata_subtitle_backward_compatible() -> None: + """Test that build_metadata_subtitle still works and returns same format.""" + config = { + "train": {"trainer": {"max_epochs": 10}}, + "split": { + "train": [ + {"start": "2000-01-01", "end": "2010-12-31"}, + ] + }, + } + + # Should work the same way as before + subtitle = build_metadata_subtitle(config) + + assert subtitle is not None + assert "Epoch" in subtitle + assert "10" in subtitle + assert "2000-01-01" in subtitle + assert "2010-12-31" in subtitle + + +def test_plotting_callback_detect_land_mask_sets_hemisphere() -> None: + """Test that PlottingCallback._detect_land_mask_path sets hemisphere from dataset.""" + 
callback = PlottingCallback(config={}) + # Should start with no hemisphere + assert callback.plot_spec.hemisphere is None + + # Create a mock dataset with south in the name + dataset = MockDataset(target_name="osisaf-sicsouth") + callback._detect_land_mask_path(dataset) # type: ignore[arg-type] + + # Hemisphere should be set + assert callback.plot_spec.hemisphere == "south" + + +def test_plotting_callback_metadata_subtitle_from_config() -> None: + """Test that PlottingCallback sets metadata_subtitle when config is provided.""" + config = { + "train": {"trainer": {"max_epochs": 5}}, + "split": { + "train": [ + {"start": "2020-01-01", "end": "2020-01-10"}, + ] + }, + "datasets": { + "sic1": { + "name": "osisaf-sicsouth", + "group_as": "osisaf-south", + }, + }, + "predict": {"dataset_group": "osisaf-south"}, + } + callback = PlottingCallback(config=config) + # Should start with no metadata subtitle + assert ( + callback.plot_spec.metadata_subtitle is None + or callback.plot_spec.metadata_subtitle == "" + ) + + # Manually build metadata (simulating what happens in on_test_batch_end) + metadata = build_metadata_subtitle(config, model_name="test_model") + if metadata: + callback.plot_spec = dataclasses.replace( + callback.plot_spec, metadata_subtitle=metadata + ) + + # Metadata should now be set + assert callback.plot_spec.metadata_subtitle is not None + assert "Epoch" in callback.plot_spec.metadata_subtitle + assert "5" in callback.plot_spec.metadata_subtitle + assert "2020-01-01" in callback.plot_spec.metadata_subtitle From 1749e46604ce1a710694474e6c0f3e4aa5c03e5e Mon Sep 17 00:00:00 2001 From: Lydia France Date: Wed, 29 Oct 2025 18:41:53 +0000 Subject: [PATCH 07/49] Restore the ruff settings --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a969794..0c8f6c3a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: # Run the linter - id: ruff-check args: ["--fix", "--show-fixes"] - pass_filenames: true + pass_filenames: false # Run the formatter - id: ruff-format - pass_filenames: true + pass_filenames: false From 22284edcbcbfb77116ebe074c8ee85bbbbe2907c Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 13:23:55 +0000 Subject: [PATCH 08/49] Change spacing layout of figure --- .pre-commit-config.yaml | 4 ++-- ice_station_zebra/visualisations/layout.py | 21 ++++++++++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c8f6c3a..5a969794 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: # Run the linter - id: ruff-check args: ["--fix", "--show-fixes"] - pass_filenames: false + pass_filenames: true # Run the formatter - id: ruff-format - pass_filenames: false + pass_filenames: true diff --git a/ice_station_zebra/visualisations/layout.py b/ice_station_zebra/visualisations/layout.py index 4d584707..cde46fe0 100644 --- a/ice_station_zebra/visualisations/layout.py +++ b/ice_station_zebra/visualisations/layout.py @@ -53,7 +53,8 @@ DEFAULT_CBAR_WIDTH = ( 0.06 # Width allocated for colourbar slots (fraction of panel width) ) -DEFAULT_TITLE_SPACE = 0.11 # Allow for the figure title, warning badge, and axes titles +DEFAULT_TITLE_SPACE = 0.07 # Reduced: simple title + warning badge + axes titles +DEFAULT_FOOTER_SPACE = 0.08 # Increased space reserved at bottom for metadata footer # Horizontal colourbar sizing (fractions of figure 
height) DEFAULT_CBAR_HEIGHT = ( @@ -82,6 +83,7 @@ def _build_layout( # noqa: PLR0913 gutter: float | None = None, cbar_width: float = DEFAULT_CBAR_WIDTH, title_space: float = DEFAULT_TITLE_SPACE, + footer_space: float = DEFAULT_FOOTER_SPACE, cbar_height: float = DEFAULT_CBAR_HEIGHT, cbar_pad: float = DEFAULT_CBAR_PAD, ) -> tuple[Figure, list[Axes], dict[str, Axes | None]]: @@ -111,6 +113,8 @@ def _build_layout( # noqa: PLR0913 cbar_width: Fraction of panel width allocated for each colourbar slot. title_space: Fraction of figure height reserved at the top for figure title (prevents title from overlapping with plot content). + footer_space: Fraction of figure height reserved at the bottom for the metadata + footer so it does not overlap colourbars. cbar_height: Height fraction for the horizontal colourbar row (row 2 when orientation is 'horizontal'). Controls the bar thickness. cbar_pad: Vertical gap fraction between the plot row and the colourbar row in @@ -143,6 +147,8 @@ def _build_layout( # noqa: PLR0913 # Calculate top boundary: ensure title space does not consume too much of the figure. # At least 60% of the figure height is reserved for the plotting area. top_val = max(0.6, 1.0 - (outer_margin + title_space)) + # Calculate bottom boundary, reserving footer space for metadata + bottom_val = outer_margin + footer_space # Calculate figure size based on data aspect ratio or use defaults if height and width and height > 0: @@ -189,6 +195,7 @@ def _build_layout( # noqa: PLR0913 gutter=gutter, cbar_width=cbar_width, top_val=top_val, + bottom_val=bottom_val, ) else: # Delegate to the horizontal builder which organises rows for plots and colourbars @@ -201,6 +208,7 @@ def _build_layout( # noqa: PLR0913 cbar_height=cbar_height, cbar_pad=cbar_pad, top_val=top_val, + bottom_val=bottom_val, ) _set_titles(axs, plot_spec) @@ -212,6 +220,7 @@ def _build_layout( # noqa: PLR0913 "gutter": gutter, "cbar_width": cbar_width, "title_space": title_space, + "footer_space": footer_space, "cbar_height": cbar_height, "cbar_pad": cbar_pad, } @@ -227,6 +236,7 @@ def _build_grid_vertical( # noqa: PLR0913, C901, PLR0912 gutter: float, cbar_width: float, top_val: float, + bottom_val: float, ) -> tuple[list[Axes], dict[str, Axes | None]]: """Construct a one-row GridSpec with vertical colourbars. @@ -242,6 +252,8 @@ def _build_grid_vertical( # noqa: PLR0913, C901, PLR0912 gutter: Fractional spacing between panel groups. cbar_width: Fractional width allocated to colourbar slots. top_val: The top boundary of the usable plotting area (accounts for title space). + bottom_val: The bottom boundary of the usable plotting area (accounts for + reserved footer space). Returns: A tuple of (axes, colourbar_axes) where axes are the main plot axes in order @@ -278,7 +290,7 @@ def _build_grid_vertical( # noqa: PLR0913, C901, PLR0912 left=outer_margin, right=1 - outer_margin, top=top_val, - bottom=outer_margin, + bottom=bottom_val, wspace=0.0, ) @@ -327,6 +339,7 @@ def _build_grid_horizontal( # noqa: PLR0913, PLR0912 cbar_height: float, cbar_pad: float, top_val: float, + bottom_val: float, ) -> tuple[list[Axes], dict[str, Axes | None]]: """Construct a three-row GridSpec with horizontal colourbars. @@ -346,6 +359,8 @@ def _build_grid_horizontal( # noqa: PLR0913, PLR0912 cbar_height: Fractional height allocated to the colourbar row. cbar_pad: Fractional padding between the plot row and colourbar row. top_val: The top boundary of the usable plotting area (accounts for title space). 
+ bottom_val: The bottom boundary of the usable plotting area (accounts for + reserved footer space). Returns: A tuple of (axes, colourbar_axes) where axes are the main plot axes in order @@ -366,7 +381,7 @@ def _build_grid_horizontal( # noqa: PLR0913, PLR0912 left=outer_margin, right=1 - outer_margin, top=top_val, - bottom=outer_margin, + bottom=bottom_val, wspace=0.0, hspace=0.0, ) From a08fb5701f7cc674eb5d75021ec215c4ca1b6f91 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 13:26:56 +0000 Subject: [PATCH 09/49] Include optional footer with metadata not title --- .../visualisations/plotting_maps.py | 139 +++++++++++++----- 1 file changed, 101 insertions(+), 38 deletions(-) diff --git a/ice_station_zebra/visualisations/plotting_maps.py b/ice_station_zebra/visualisations/plotting_maps.py index 558c1bae..b0cf07aa 100644 --- a/ice_station_zebra/visualisations/plotting_maps.py +++ b/ice_station_zebra/visualisations/plotting_maps.py @@ -100,8 +100,37 @@ def plot_maps( # Load land mask if specified land_mask = load_land_mask(plot_spec.land_mask_path, (height, width)) - # Initialise the figure and axes - fig, axs, cbar_axes = _build_layout(plot_spec=plot_spec, height=height, width=width) + # Pre-compute range check to decide top spacing (warning badge may need extra room) + (gt_min, gt_max), (pred_min, pred_max) = compute_display_ranges( + ground_truth, prediction, plot_spec + ) + range_check_report = compute_range_check_report( + ground_truth, + prediction, + vmin=gt_min, + vmax=gt_max, + outside_warn=getattr(plot_spec, "outside_warn", 0.05), + severe_outside=getattr(plot_spec, "severe_outside", 0.20), + include_shared_range_mismatch_check=getattr( + plot_spec, "include_shared_range_mismatch_check", True + ), + ) + + # Increase title space if warnings are present to avoid overlap with axes titles + title_space_override = 0.10 if range_check_report.warnings else None + + # Initialise the figure and axes with dynamic top spacing if needed + if title_space_override is not None: + fig, axs, cbar_axes = _build_layout( + plot_spec=plot_spec, + height=height, + width=width, + title_space=title_space_override, + ) + else: + fig, axs, cbar_axes = _build_layout( + plot_spec=plot_spec, height=height, width=width + ) levels = levels_from_spec(plot_spec) # Prepare difference rendering parameters if needed @@ -142,21 +171,7 @@ def plot_maps( logger.exception("Failed to draw suptitle; continuing without title.") title_text = None - # Include range_check report - (groundtruth_min, groundtruth_max), (prediction_min, prediction_max) = ( - compute_display_ranges(ground_truth, prediction, plot_spec) - ) - range_check_report = compute_range_check_report( - ground_truth, - prediction, - vmin=groundtruth_min, - vmax=groundtruth_max, - outside_warn=getattr(plot_spec, "outside_warn", 0.05), - severe_outside=getattr(plot_spec, "severe_outside", 0.20), - include_shared_range_mismatch_check=getattr( - plot_spec, "include_shared_range_mismatch_check", True - ), - ) + # Include range_check report (already computed above) badge = ( "" if not range_check_report.warnings @@ -167,11 +182,21 @@ def plot_maps( if title_text is not None: _, title_y = title_text.get_position() n_lines = title_text.get_text().count("\n") + 1 - warning_y = max(title_y - (0.08 + 0.03 * (n_lines - 1)), 0.0) + # Reduced gap to avoid overlapping axes titles; keep badge close to title + warning_y = max(title_y - (0.05 + 0.02 * (n_lines - 1)), 0.0) else: warning_y = 0.90 _draw_badge_with_box(fig, 0.5, warning_y, badge) + # Footer metadata 
at the bottom + if getattr(plot_spec, "include_footer_metadata", True): + try: + footer_text = _build_footer_static(plot_spec) + if footer_text: + _set_footer_with_box(fig, footer_text) + except Exception: + logger.exception("Failed to draw footer; continuing without footer.") + try: return {"sea-ice_concentration-static-maps": [convert._image_from_figure(fig)]} finally: @@ -240,8 +265,10 @@ def video_maps( ground_truth_stream = np.where(land_mask, np.nan, ground_truth_stream) prediction_stream = np.where(land_mask, np.nan, prediction_stream) - # Initialise the figure and axes - fig, axs, cbar_axes = _build_layout(plot_spec=plot_spec, height=height, width=width) + # Initialise the figure and axes with a larger footer space for videos + fig, axs, cbar_axes = _build_layout( + plot_spec=plot_spec, height=height, width=width, footer_space=0.11 + ) levels = levels_from_spec(plot_spec) # Stable ranges for the whole animation @@ -302,6 +329,15 @@ def video_maps( logger.exception("Failed to draw suptitle; continuing without title.") title_text = None + # Footer metadata at the bottom (static across frames) + if getattr(plot_spec, "include_footer_metadata", True): + try: + footer_text = _build_footer_video(plot_spec, dates) + if footer_text: + _set_footer_with_box(fig, footer_text) + except Exception: + logger.exception("Failed to draw footer; continuing without footer.") + # Animation function def animate(tt: int) -> tuple[()]: precomputed_diff_tt = ( @@ -677,6 +713,28 @@ def _set_suptitle_with_box(fig: plt.Figure, text: str) -> Text: return t +def _set_footer_with_box(fig: plt.Figure, text: str) -> Text: + """Draw a fixed-position footer with a white box at bottom center. + + Footer is intended for metadata and secondary information. + """ + bbox = {"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0} + t = fig.text( + x=0.5, + y=0.03, + s=text, + ha="center", + va="bottom", + fontsize=11, + fontfamily="monospace", + transform=fig.transFigure, + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(1000) + return t + + def _draw_badge_with_box(fig: plt.Figure, x: float, y: float, text: str) -> Text: """Draw a warning/info badge with white background box at figure coords.""" bbox = {"facecolor": "white", "edgecolor": "none", "pad": 1.5, "alpha": 1.0} @@ -730,20 +788,15 @@ def _format_date_for_title(dt: date | datetime) -> str: def _build_title_static(plot_spec: PlotSpec, when: date | datetime) -> str: - """Compose a readable suptitle for static plots. + """Compose a simple suptitle for static plots. Lines: 1) " () Shown: YYYY-MM-DD" - 2) "Model: Epoch: Training Dates: () pts" (optional) - 3) "Training Data: () ()" (optional) + (Footer contains any metadata such as model/epoch/training data if present) """ - lines: list[str] = [] metric = _formatted_variable_name(plot_spec.variable) hemi = f" ({plot_spec.hemisphere.capitalize()})" if plot_spec.hemisphere else "" - lines.append(f"{metric}{hemi} Prediction Shown: {_format_date_for_title(when)}") - if plot_spec.metadata_subtitle: - lines.append(plot_spec.metadata_subtitle) - return "\n".join(lines) + return f"{metric}{hemi} Prediction Shown: {_format_date_for_title(when)}" def _build_title_video( @@ -751,26 +804,36 @@ def _build_title_video( dates: Sequence[date | datetime], current_index: int, ) -> str: - """Compose a readable multi-line suptitle for video plots. + """Compose a simple suptitle for video plots (date changes per frame). 
Lines: 1) " () Frame: YYYY-MM-DD" - 2) "Animating from to " (shown when a date range is available) - 3) "Model: Epoch: Training Dates: () pts" (optional) - 4) "Training Data: () ()" (optional) + 2) Footer: "Animating from to " + 3) Footer: "Model: Epoch: Training Dates: () pts" (optional) + 4) Footer: "Training Data: () ()" (optional) """ - lines: list[str] = [] metric = _formatted_variable_name(plot_spec.variable) hemi = f" ({plot_spec.hemisphere.capitalize()})" if plot_spec.hemisphere else "" if dates: - lines.append( - f"{metric}{hemi} Prediction Frame: {_format_date_for_title(dates[current_index])}" - ) + return f"{metric}{hemi} Prediction Frame: {_format_date_for_title(dates[current_index])}" + return f"{metric}{hemi} Prediction" + + +def _build_footer_static(plot_spec: PlotSpec) -> str: + """Build footer text for static plots using metadata that used to be in title.""" + lines: list[str] = [] + if plot_spec.metadata_subtitle: + lines.append(plot_spec.metadata_subtitle) + return "\n".join(lines) + + +def _build_footer_video(plot_spec: PlotSpec, dates: Sequence[date | datetime]) -> str: + """Build footer text for video plots: animation range and metadata.""" + lines: list[str] = [] + if dates: start_s = _format_date_for_title(dates[0]) end_s = _format_date_for_title(dates[-1]) lines.append(f"Animating from {start_s} to {end_s}") - else: - lines.append(f"{metric}{hemi}") if plot_spec.metadata_subtitle: lines.append(plot_spec.metadata_subtitle) return "\n".join(lines) From 18e14ecb47731441b447d7de676db3f762272c18 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 13:27:26 +0000 Subject: [PATCH 10/49] Include option for footer --- ice_station_zebra/config/evaluate/callbacks/plotting.yaml | 1 + ice_station_zebra/types/simple_datatypes.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/ice_station_zebra/config/evaluate/callbacks/plotting.yaml b/ice_station_zebra/config/evaluate/callbacks/plotting.yaml index 0b57044e..fe2b0942 100644 --- a/ice_station_zebra/config/evaluate/callbacks/plotting.yaml +++ b/ice_station_zebra/config/evaluate/callbacks/plotting.yaml @@ -23,3 +23,4 @@ plotting: vmax: 1.0 colourbar_location: "horizontal" colourbar_strategy: "separate" # shared | separate (prediction gets its own colourbar) + include_footer_metadata: false diff --git a/ice_station_zebra/types/simple_datatypes.py b/ice_station_zebra/types/simple_datatypes.py index e5710d9a..4eba0fc4 100644 --- a/ice_station_zebra/types/simple_datatypes.py +++ b/ice_station_zebra/types/simple_datatypes.py @@ -119,3 +119,6 @@ class PlotSpec: hemisphere: Literal["north", "south"] | None = None # metadata_subtitle: free-form text (e.g., "epochs=50; train=2010-2018") metadata_subtitle: str | None = None + + # Footer control + include_footer_metadata: bool = True From 075e1229a6bb45f83461897f2f4182938881e9af Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 13:35:44 +0000 Subject: [PATCH 11/49] Extra tests to check the textboxes don't overlap --- tests/plotting/test_layout.py | 62 ++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/plotting/test_layout.py b/tests/plotting/test_layout.py index fd96fa90..ac4ac598 100644 --- a/tests/plotting/test_layout.py +++ b/tests/plotting/test_layout.py @@ -10,7 +10,12 @@ import pytest from ice_station_zebra.visualisations.layout import _build_layout, _set_axes_limits -from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC +from ice_station_zebra.visualisations.plotting_maps 
import ( + DEFAULT_SIC_SPEC, + _draw_badge_with_box, + _set_footer_with_box, + _set_suptitle_with_box, +) from .test_helper_plot_layout import axis_rectangle, rectangles_overlap @@ -157,3 +162,58 @@ def test_y_axis_orientation_for_geographical_data() -> None: assert x_max == width, f"X-axis maximum should be {width}, got {x_max}" plt.close(fig) + + +def _text_rectangle( + fig: plt.Figure, text_artist: plt.Text +) -> tuple[float, float, float, float]: + """Return text bounding box in figure-normalised coords [0, 1].""" + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + bbox = text_artist.get_window_extent(renderer=renderer) + # Transform display to figure coords + (x0, y0), (x1, y1) = fig.transFigure.inverted().transform( + [(bbox.x0, bbox.y0), (bbox.x1, bbox.y1)] + ) + return (float(x0), float(y0), float(x1), float(y1)) + + +@pytest.mark.parametrize("colourbar_location", ["horizontal", "vertical"]) +@pytest.mark.parametrize("include_difference", [False, True]) +def test_figure_text_boxes_do_not_overlap( + sic_pair_2d: tuple[np.ndarray, np.ndarray, date], + *, + colourbar_location: str, + include_difference: bool, +) -> None: + """Ensure figure title, warning badge, and footer do not overlap panels or colourbars.""" + ground_truth, _, _ = sic_pair_2d + + spec = replace( + DEFAULT_SIC_SPEC, + colourbar_location=colourbar_location, # type: ignore[arg-type] + include_difference=include_difference, + ) + + fig, axes, caxes = _build_layout( + plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] + ) + + # Add figure-level title, warning badge (synthetic), and footer + title = _set_suptitle_with_box(fig, "Title") + ty = title.get_position()[1] + badge = _draw_badge_with_box(fig, 0.5, max(ty - 0.05, 0.0), "Warnings: example") + footer = _set_footer_with_box(fig, "Footer metadata") + + # Collect rectangles: panels, colourbar axes, and figure texts + rectangles = [axis_rectangle(ax) for ax in axes] + rectangles.extend(axis_rectangle(ax) for ax in caxes.values() if ax is not None) + rectangles.append(_text_rectangle(fig, title)) + rectangles.append(_text_rectangle(fig, badge)) + rectangles.append(_text_rectangle(fig, footer)) + + # No overlaps among any of these elements + for rect_a, rect_b in combinations(rectangles, 2): + assert not rectangles_overlap(rect_a, rect_b), ( + f"Found overlap between rectangles {rect_a} and {rect_b}" + ) From 13f842f32afbe8b1730f2725242691193658603c Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 13:43:05 +0000 Subject: [PATCH 12/49] Fix for mypy and restore ruff settings --- .pre-commit-config.yaml | 4 ++-- tests/plotting/test_layout.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a969794..0c8f6c3a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: # Run the linter - id: ruff-check args: ["--fix", "--show-fixes"] - pass_filenames: true + pass_filenames: false # Run the formatter - id: ruff-format - pass_filenames: true + pass_filenames: false diff --git a/tests/plotting/test_layout.py b/tests/plotting/test_layout.py index ac4ac598..2ebc823f 100644 --- a/tests/plotting/test_layout.py +++ b/tests/plotting/test_layout.py @@ -3,7 +3,7 @@ from dataclasses import replace from itertools import combinations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import matplotlib as mpl import matplotlib.pyplot as plt @@ -169,7 +169,13 @@ def _text_rectangle( ) -> tuple[float, 
float, float, float]: """Return text bounding box in figure-normalised coords [0, 1].""" fig.canvas.draw() - renderer = fig.canvas.get_renderer() + # Obtain a renderer in a backend-agnostic, mypy-friendly way + canvas: Any = fig.canvas + get_renderer = getattr(canvas, "get_renderer", None) + if callable(get_renderer): + renderer = get_renderer() + else: + renderer = getattr(canvas, "renderer", None) bbox = text_artist.get_window_extent(renderer=renderer) # Transform display to figure coords (x0, y0), (x1, y1) = fig.transFigure.inverted().transform( From 4faf6d506b39fa347656ccba3401fc34de7a072a Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 16:30:16 +0000 Subject: [PATCH 13/49] Make sure footer is in the config --- ice_station_zebra/config/evaluate/callbacks/plotting.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ice_station_zebra/config/evaluate/callbacks/plotting.yaml b/ice_station_zebra/config/evaluate/callbacks/plotting.yaml index fe2b0942..d01bb3ed 100644 --- a/ice_station_zebra/config/evaluate/callbacks/plotting.yaml +++ b/ice_station_zebra/config/evaluate/callbacks/plotting.yaml @@ -23,4 +23,4 @@ plotting: vmax: 1.0 colourbar_location: "horizontal" colourbar_strategy: "separate" # shared | separate (prediction gets its own colourbar) - include_footer_metadata: false + include_footer_metadata: true From b14e779f8c9f3b47b10693217e7526907ca44c82 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 30 Oct 2025 16:32:36 +0000 Subject: [PATCH 14/49] Expose n_history_steps for plots --- ice_station_zebra/callbacks/metadata.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/ice_station_zebra/callbacks/metadata.py b/ice_station_zebra/callbacks/metadata.py index d539a134..40045204 100644 --- a/ice_station_zebra/callbacks/metadata.py +++ b/ice_station_zebra/callbacks/metadata.py @@ -28,6 +28,7 @@ class Metadata: cadence: Training data cadence string (if available). n_points: Number of training points calculated from date range and cadence. vars_by_source: Dictionary mapping dataset source names to lists of variable names. + n_history_steps: Number of history steps used as model input window (days). """ @@ -37,6 +38,7 @@ class Metadata: end: str | None = None cadence: str | None = None n_points: int | None = None + n_history_steps: int | None = None vars_by_source: dict[str, list[str]] | None = None @@ -264,6 +266,14 @@ def build_metadata( # Get variables grouped by source vars_by_source = extract_variables_by_source(config) + # Extract history window length (days) if available + n_history_steps: int | None = None + predict_cfg = config.get("predict") if isinstance(config, dict) else None + if isinstance(predict_cfg, dict): + nh = predict_cfg.get("n_history_steps") + if isinstance(nh, int): + n_history_steps = nh + return Metadata( model=model_name if isinstance(model_name, str) and model_name else None, epochs=max_epochs if isinstance(max_epochs, int) else None, @@ -271,11 +281,12 @@ def build_metadata( end=end_str, cadence=cadence_str, n_points=training_points, + n_history_steps=n_history_steps, vars_by_source=vars_by_source if vars_by_source else None, ) -def format_metadata_subtitle(metadata: Metadata) -> str | None: # noqa: C901 +def format_metadata_subtitle(metadata: Metadata) -> str | None: # noqa: C901, PLR0912 """Format metadata dataclass as a compact multi-line subtitle for plot titles. 
Lines: @@ -312,7 +323,13 @@ def format_metadata_subtitle(metadata: Metadata) -> str | None: # noqa: C901 cadence_display = format_cadence_display(metadata.cadence) dates_part = f"Training Dates: {s_clean} — {e_clean}" if cadence_display: - dates_part += f" ({cadence_display})" + # Append cadence, and if available, the history window length + if metadata.n_history_steps is not None and metadata.n_history_steps > 0: + dates_part += ( + f" ({cadence_display}, {metadata.n_history_steps} day history)" + ) + else: + dates_part += f" ({cadence_display})" if metadata.n_points is not None: dates_part += f" {metadata.n_points} pts" info_parts.append(dates_part) From 023a2ef5acc2caba43cd7eedf22d73f91fcb45de Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 13:49:45 +0000 Subject: [PATCH 15/49] Ruff settings --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c8f6c3a..5a969794 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: # Run the linter - id: ruff-check args: ["--fix", "--show-fixes"] - pass_filenames: false + pass_filenames: true # Run the formatter - id: ruff-format - pass_filenames: false + pass_filenames: true From 2c00b133e01c69929fac483333e5c45e8bc871bc Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 13:52:48 +0000 Subject: [PATCH 16/49] return variable names from input data --- .../data_loaders/combined_dataset.py | 17 +++++++++++++++++ ice_station_zebra/data_loaders/zebra_dataset.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/ice_station_zebra/data_loaders/combined_dataset.py b/ice_station_zebra/data_loaders/combined_dataset.py index 2e605af7..bd0265b8 100644 --- a/ice_station_zebra/data_loaders/combined_dataset.py +++ b/ice_station_zebra/data_loaders/combined_dataset.py @@ -117,3 +117,20 @@ def start_date(self) -> np.datetime64: msg = f"Datasets have {len(start_date)} different start dates" raise ValueError(msg) return start_date.pop() + + @property + def input_variable_names(self) -> list[str]: + """Return all input variable names across all input datasets. + + Variable names are prefixed with the dataset name for disambiguation. + Format: "{dataset_name}:{variable_name}" + + Returns: + List of variable names in the order they appear in the combined input channels. + + """ + return [ + f"{ds.name}:{var_name}" + for ds in self.inputs + for var_name in ds.variable_names + ] diff --git a/ice_station_zebra/data_loaders/zebra_dataset.py b/ice_station_zebra/data_loaders/zebra_dataset.py index d6a682c0..46ae5b4a 100644 --- a/ice_station_zebra/data_loaders/zebra_dataset.py +++ b/ice_station_zebra/data_loaders/zebra_dataset.py @@ -94,6 +94,20 @@ def start_date(self) -> np.datetime64: """Return the start date of the dataset.""" return self.dates[0] + @cached_property + def variable_names(self) -> list[str]: + """Return the variable names for this dataset. + + The variable names are extracted from the underlying Anemoi dataset. + All datasets must have the same variables. 
+ """ + # Check all datasets have the same variables + per_ds_variables = [tuple(ds.variables) for ds in self.datasets] + if len(set(per_ds_variables)) != 1: + msg = f"All datasets must have the same variables, found {len(set(per_ds_variables))} different sets" + raise ValueError(msg) + return list(self.datasets[0].variables) + def __len__(self) -> int: """Return the total length of the dataset.""" if self._len is None: From 3dcded7bf907fa1847380e425f3338b889688f21 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 13:58:11 +0000 Subject: [PATCH 17/49] Preventing memory leaks from plotting --- ice_station_zebra/evaluation/evaluator.py | 5 +++++ ice_station_zebra/visualisations/__init__.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/ice_station_zebra/evaluation/evaluator.py b/ice_station_zebra/evaluation/evaluator.py index 11489388..5abc93d6 100644 --- a/ice_station_zebra/evaluation/evaluator.py +++ b/ice_station_zebra/evaluation/evaluator.py @@ -2,6 +2,11 @@ from pathlib import Path from typing import TYPE_CHECKING +# Set matplotlib backend BEFORE any imports that might use it +import matplotlib as mpl + +mpl.use("Agg") + import hydra from lightning.fabric.utilities import suggested_max_num_workers from omegaconf import DictConfig, OmegaConf diff --git a/ice_station_zebra/visualisations/__init__.py b/ice_station_zebra/visualisations/__init__.py index e60635f0..ea88bf9a 100644 --- a/ice_station_zebra/visualisations/__init__.py +++ b/ice_station_zebra/visualisations/__init__.py @@ -1,3 +1,9 @@ +# Set non-interactive matplotlib backend early to prevent hangs on macOS +# make sure the import is not reordered by the formatter +import matplotlib as mpl + +mpl.use("Agg") + from ice_station_zebra.types import PlotSpec from .plotting_core import detect_land_mask_path From 56ecffe387054e5ddea3ef2f58b998d74b9f6fcd Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:00:49 +0000 Subject: [PATCH 18/49] prevent memory leak and log more errors --- ice_station_zebra/callbacks/plotting_callback.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ice_station_zebra/callbacks/plotting_callback.py b/ice_station_zebra/callbacks/plotting_callback.py index 0212b899..22d834c3 100644 --- a/ice_station_zebra/callbacks/plotting_callback.py +++ b/ice_station_zebra/callbacks/plotting_callback.py @@ -1,3 +1,8 @@ +# Set matplotlib backend before any plotting imports +import matplotlib as mpl + +mpl.use("Agg") + import dataclasses import logging from collections.abc import Mapping, Sequence @@ -207,7 +212,7 @@ def log_static_plots( ) except InvalidArrayError as err: logger.warning("Static plotting skipped due to invalid arrays: %s", err) - except (ValueError, MemoryError, OSError): + except Exception: logger.exception("Static plotting failed") def log_video_plots( From 6c83ddf20d01fb4e1bba55a2c7306f189f5517ac Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:04:07 +0000 Subject: [PATCH 19/49] Raw input callback file --- .../callbacks/raw_inputs_callback.py | 521 ++++++++++++++++++ 1 file changed, 521 insertions(+) create mode 100644 ice_station_zebra/callbacks/raw_inputs_callback.py diff --git a/ice_station_zebra/callbacks/raw_inputs_callback.py b/ice_station_zebra/callbacks/raw_inputs_callback.py new file mode 100644 index 00000000..35b470c8 --- /dev/null +++ b/ice_station_zebra/callbacks/raw_inputs_callback.py @@ -0,0 +1,521 @@ +"""Callback for plotting raw input variables during evaluation.""" + +# Set matplotlib backend before 
any plotting imports +import matplotlib as mpl + +mpl.use("Agg") + +import dataclasses +import gc +import logging +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, cast + +import numpy as np +from lightning import LightningModule, Trainer +from lightning.pytorch import Callback +from lightning.pytorch.loggers import Logger as LightningLogger + +from ice_station_zebra.callbacks.metadata import infer_hemisphere +from ice_station_zebra.data_loaders import CombinedDataset +from ice_station_zebra.exceptions import InvalidArrayError, VideoRenderError +from ice_station_zebra.types import PlotSpec +from ice_station_zebra.visualisations import DEFAULT_SIC_SPEC, detect_land_mask_path +from ice_station_zebra.visualisations.plotting_core import safe_filename +from ice_station_zebra.visualisations.plotting_raw_inputs import ( + plot_raw_inputs_for_timestep, + video_raw_inputs_for_timesteps, +) + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + +# Constants +EXPECTED_INPUT_NDIM = 5 # Expected input data shape: [B, T, C, H, W] + + +class RawInputsCallback(Callback): + """A callback to plot raw input variables during evaluation.""" + + def __init__( # noqa: PLR0913 + self, + *, + frequency: int | None = None, + save_dir: str | Path | None = None, + plot_spec: PlotSpec | None = None, + config: dict | None = None, + timestep_index: int = 0, + variable_styles: dict[str, dict[str, Any]] | None = None, + make_video_plots: bool = False, + video_fps: int = 2, + video_format: Literal["mp4", "gif"] = "gif", + video_save_dir: str | Path | None = None, + max_animation_frames: int | None = None, + log_to_wandb: bool = True, + ) -> None: + """Create raw input plots and/or animations during evaluation. + + Args: + frequency: Create plots every `frequency` batches; `None` plots once per run. + save_dir: Directory to save static plots to. If None and log_to_wandb=False, no plots saved. + plot_spec: Plotting specification (colourmap, hemisphere, etc.). + config: Configuration dictionary for land mask detection. + timestep_index: Which history timestep to plot (0 = most recent). + variable_styles: Per-variable styling overrides (cmap, vmin/vmax, units, etc.). + make_video_plots: Whether to create temporal animations of raw inputs. + video_fps: Frames per second for animations. + video_format: Video format ("mp4" or "gif"). + video_save_dir: Directory to save animations. If None and log_to_wandb=False, no videos saved. + max_animation_frames: Maximum number of frames to include in animations (None = unlimited). + Limits temporal accumulation to control memory and file size. + log_to_wandb: Whether to log plots and animations to WandB (default: True). 
+ + """ + super().__init__() + if frequency is None: + self.frequency = None + else: + self.frequency = int(max(1, frequency)) + self.save_dir = Path(save_dir) if save_dir else None + self.timestep_index = timestep_index + self.variable_styles = variable_styles or {} + self._has_plotted = False + + # Animation settings + self.make_video_plots = make_video_plots + self.video_fps = video_fps + self.video_format = video_format + self.video_save_dir = Path(video_save_dir) if video_save_dir else self.save_dir + self.max_animation_frames = max_animation_frames + + # WandB logging control + self.log_to_wandb = log_to_wandb + + # Ensure plot_spec is a PlotSpec instance + if plot_spec is None: + self.plot_spec = DEFAULT_SIC_SPEC + else: + self.plot_spec = plot_spec + + self.config = config or {} + self._land_mask_path_detected = False + self._land_mask_array: np.ndarray | None = None + + # Temporal data accumulation for animations + self._temporal_data: dict[ + str, list[np.ndarray] + ] = {} # var_name -> list of [H,W] arrays + self._temporal_dates: list[Any] = [] + self._dataset_ref: CombinedDataset | None = None + + def on_test_start(self, _trainer: Trainer, _module: LightningModule) -> None: + """Called when the test loop starts.""" + logger.info("RawInputsCallback: Test loop started") + + def _detect_land_mask_path( # noqa: C901 + self, + dataset: CombinedDataset, + ) -> None: + """Detect and set the land mask path based on the dataset configuration.""" + if self._land_mask_path_detected: + return + + # Get base path from callback config or use default + base_path = self.config.get("base_path", "../ice-station-zebra/data") + + # Try to get dataset name from the target dataset + dataset_name = None + if hasattr(dataset, "target") and hasattr(dataset.target, "name"): + dataset_name = dataset.target.name + + # Detect land mask path + land_mask_path = detect_land_mask_path(base_path, dataset_name) + + # Always try to infer hemisphere regardless of land mask presence + hemisphere: Literal["north", "south"] | None = None + if isinstance(dataset_name, str): + low = dataset_name.lower() + if "south" in low: + hemisphere = cast("Literal['north', 'south']", "south") + elif "north" in low: + hemisphere = cast("Literal['north', 'south']", "north") + if hemisphere is None: + hemi_candidate = infer_hemisphere(dataset) + if hemi_candidate in ("north", "south"): + hemisphere = cast("Literal['north', 'south']", hemi_candidate) + + # Update plot_spec pieces independently + if hemisphere is not None and self.plot_spec.hemisphere != hemisphere: + self.plot_spec = dataclasses.replace(self.plot_spec, hemisphere=hemisphere) + + if land_mask_path: + # Set land mask path when found + self.plot_spec = dataclasses.replace( + self.plot_spec, + land_mask_path=land_mask_path, + ) + logger.info("Auto-detected land mask: %s", land_mask_path) + + # Load the land mask array + try: + self._land_mask_array = np.load(land_mask_path) + logger.debug( + "Loaded land mask array with shape: %s", self._land_mask_array.shape + ) + except Exception: + logger.exception("Failed to load land mask from %s", land_mask_path) + self._land_mask_array = None + else: + logger.debug("No land mask found for dataset: %s", dataset_name) + + self._land_mask_path_detected = True + + def on_test_batch_end( # noqa: C901, PLR0912 + self, + trainer: Trainer, + _module: LightningModule, + _outputs: Any, # noqa: ANN401 + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Called when the test batch ends.""" + logger.debug( 
+ "RawInputsCallback.on_test_batch_end called for batch %d", batch_idx + ) + + # Get dataset and date + dl: DataLoader | list[DataLoader] | None = trainer.test_dataloaders + if dl is None: + logger.warning("No test dataloaders found, skipping raw inputs plotting.") + return + + dataset = (dl[dataloader_idx] if isinstance(dl, Sequence) else dl).dataset + if not isinstance(dataset, CombinedDataset): + logger.warning( + "Dataset is of type %s, skipping raw inputs plotting.", type(dataset) + ) + return + + # Detect land mask path if not already done + self._detect_land_mask_path(dataset) + + # Store dataset reference for later use + if self._dataset_ref is None: + self._dataset_ref = dataset + + # Get the date for this batch's first sample + date = dataset.date_from_index(batch_idx) + + # Determine if we should plot static plots this batch + should_plot_static = False + if self.frequency is None: + should_plot_static = not self._has_plotted + else: + should_plot_static = batch_idx % self.frequency == 0 + + # Always accumulate temporal data if animations are enabled + # (regardless of static plot frequency) + if self.make_video_plots: + # Check if we've reached the frame limit + if ( + self.max_animation_frames is not None + and len(self._temporal_dates) >= self.max_animation_frames + ): + logger.debug( + "Reached max animation frames (%d), skipping further accumulation", + self.max_animation_frames, + ) + else: + try: + # Extract data without plotting + channel_arrays, channel_names = self._extract_channel_data( + batch, dataset + ) + if channel_arrays and channel_names: + self._accumulate_temporal_data( + channel_arrays, channel_names, date + ) + except (ValueError, RuntimeError, IndexError, KeyError) as e: + logger.warning( + "Failed to accumulate temporal data for batch %d: %s", + batch_idx, + e, + ) + + # Create static plots according to frequency + if should_plot_static: + logger.debug("Starting raw inputs plotting for batch %d", batch_idx) + try: + self.log_raw_inputs(batch, dataset, date, trainer.loggers, batch_idx) + logger.info( + "Successfully completed raw inputs plotting for batch %d", batch_idx + ) + except Exception: + logger.exception("Raw inputs plotting failed for batch %d", batch_idx) + else: + if self.frequency is None: + self._has_plotted = True + else: + logger.debug( + "Skipping static plots for batch %d (frequency=%s)", + batch_idx, + self.frequency, + ) + + def _extract_channel_data( + self, + batch: Mapping[str, Any], + dataset: CombinedDataset, + ) -> tuple[list[np.ndarray], list[str]]: + """Extract channel data from batch without plotting. + + Returns: + Tuple of (channel_arrays, channel_names). 
+ + """ + # Collect all input channel arrays + channel_arrays = [] + for ds in dataset.inputs: + if ds.name not in batch: + logger.warning("Dataset %s not found in batch, skipping", ds.name) + continue + + input_data = batch[ds.name] # Shape: [B, T, C, H, W] + + # Take first batch and specified timestep + if input_data.ndim != EXPECTED_INPUT_NDIM: + logger.warning( + "Expected 5D input data [B,T,C,H,W], got shape %s for %s", + input_data.shape, + ds.name, + ) + continue + + timestep_data = input_data[0, self.timestep_index] # Shape: [C, H, W] + + # Add each channel as a 2D array + for c in range(timestep_data.shape[0]): + channel_arr = timestep_data[c].detach().cpu().numpy() + channel_arrays.append(channel_arr) + + if not channel_arrays: + logger.warning("No input channels found in batch") + return [], [] + + # Get variable names from dataset + channel_names = dataset.input_variable_names + + if len(channel_arrays) != len(channel_names): + logger.warning( + "Mismatch: %d channel arrays but %d channel names. Using generic names.", + len(channel_arrays), + len(channel_names), + ) + channel_names = [f"channel_{i}" for i in range(len(channel_arrays))] + + return channel_arrays, channel_names + + def log_raw_inputs( + self, + batch: Mapping[str, Any], + dataset: CombinedDataset, + date: Any, # noqa: ANN401 + lightning_loggers: list[LightningLogger], + _batch_idx: int, + ) -> None: + """Extract and log raw input plots.""" + # Early return if nothing will be saved + if not self.log_to_wandb and self.save_dir is None: + logger.debug( + "Skipping raw inputs plotting: log_to_wandb=False and save_dir=None" + ) + return + + try: + # Extract data + channel_arrays, channel_names = self._extract_channel_data(batch, dataset) + + if not channel_arrays: + logger.warning( + "No input channels found in batch, skipping raw inputs plotting" + ) + return + + # Plot the raw inputs + results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=date, + plot_spec_base=self.plot_spec, + land_mask=self._land_mask_array, + styles=self.variable_styles, + save_dir=self.save_dir, + ) + + # Log to WandB if enabled + if self.log_to_wandb: + for lightning_logger in lightning_loggers: + if hasattr(lightning_logger, "log_image"): + # Group images by their name for logging + for var_name, pil_image, _saved_path in results: + safe_name = safe_filename(var_name.replace(":", "__")) + lightning_logger.log_image( + key=f"raw_inputs/{safe_name}", + images=[pil_image], + ) + else: + logger.debug( + "Logger %s does not support images.", + lightning_logger.name + if lightning_logger.name + else "unknown", + ) + + logger.debug( + "Plotted %d raw input variables (saved to disk: %s, logged to WandB: %s)", + len(results), + self.save_dir is not None, + self.log_to_wandb, + ) + + except Exception: + logger.exception("Failed to log raw inputs") + + def _accumulate_temporal_data( + self, + channel_arrays: list[np.ndarray], + channel_names: list[str], + date: Any, # noqa: ANN401 + ) -> None: + """Accumulate temporal data for creating animations later. + + Args: + channel_arrays: List of 2D arrays [H, W] for this timestep. + channel_names: Variable names for each channel. + date: Date/datetime for this timestep. 
+ + """ + # Initialise temporal data storage for each variable on first call + if not self._temporal_data: + for name in channel_names: + self._temporal_data[name] = [] + + # Append this timestep's data for each variable + for arr, name in zip(channel_arrays, channel_names, strict=True): + if name in self._temporal_data: + self._temporal_data[name].append(arr) + + # Append date + self._temporal_dates.append(date) + + logger.debug( + "Accumulated temporal data: %d timesteps for %d variables", + len(self._temporal_dates), + len(self._temporal_data), + ) + + def on_test_end(self, trainer: Trainer, _module: LightningModule) -> None: + """Called when the test loop ends. Create and log animations.""" + if not self.make_video_plots: + logger.debug("Video plots disabled, skipping animation creation") + return + + if not self._temporal_data or not self._temporal_dates: + logger.warning("No temporal data collected for animations") + return + + logger.debug("Creating animations for %d variables", len(self._temporal_data)) + try: + self.log_video_plots(trainer.loggers) + finally: + # Clear temporal data to free memory + self._temporal_data.clear() + self._temporal_dates.clear() + # Force garbage collection to clean up animation resources + gc.collect() + + def log_video_plots(self, lightning_loggers: list[LightningLogger]) -> None: # noqa: C901 + """Create and log temporal animations of raw input variables.""" + if not self.make_video_plots: + return + + # Early return if nothing will be saved + if not self.log_to_wandb and self.video_save_dir is None: + logger.debug( + "Skipping video plotting: log_to_wandb=False and video_save_dir=None" + ) + return + + try: + # Convert accumulated data to 3D arrays [T, H, W] + channel_arrays_stream = [] + channel_names = [] + + for var_name, frames in self._temporal_data.items(): + if not frames: + continue + # Stack frames into [T, H, W] + data_stream = np.stack(frames, axis=0) + channel_arrays_stream.append(data_stream) + channel_names.append(var_name) + + if not channel_arrays_stream: + logger.warning("No data to create animations") + return + + logger.info( + "Creating animations: %d variables x %d timesteps", + len(channel_names), + len(self._temporal_dates), + ) + + # Create animations for all variables + results = video_raw_inputs_for_timesteps( + channel_arrays_stream=channel_arrays_stream, + channel_names=channel_names, + dates=self._temporal_dates, + plot_spec_base=self.plot_spec, + land_mask=self._land_mask_array, + styles=self.variable_styles, + fps=self.video_fps, + video_format=self.video_format, + save_dir=self.video_save_dir, + ) + + # Log to WandB if enabled + if self.log_to_wandb: + for lightning_logger in lightning_loggers: + if hasattr(lightning_logger, "log_video"): + for var_name, video_buffer, _saved_path in results: + # Ensure buffer is at start + video_buffer.seek(0) + safe_name = safe_filename(var_name.replace(":", "__")) + lightning_logger.log_video( + key=f"raw_inputs_video/{safe_name}", + videos=[video_buffer], + format=[self.video_format], + ) + logger.debug("Logged video for %s to WandB", var_name) + else: + logger.debug( + "Logger %s does not support videos.", + lightning_logger.name + if lightning_logger.name + else "unknown", + ) + + logger.info( + "Successfully created %d animations (saved to disk: %s, logged to WandB: %s)", + len(results), + self.video_save_dir is not None, + self.log_to_wandb, + ) + + except (InvalidArrayError, VideoRenderError) as err: + logger.warning("Video plotting skipped: %s", err) + except (ValueError, 
MemoryError, OSError): + logger.exception("Video plotting failed") From 71d448ceea448375734b2438f46cdc16477bbc13 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:04:43 +0000 Subject: [PATCH 20/49] Config for raw input plotting --- .../config/evaluate/callbacks/raw_inputs.yaml | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml diff --git a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml new file mode 100644 index 00000000..cd6eee33 --- /dev/null +++ b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml @@ -0,0 +1,52 @@ +raw_inputs: + # Callback to plot raw input variables (ERA5 data, etc.) + _target_: ice_station_zebra.callbacks.RawInputsCallback + frequency: null # Plot once; set N to repeat every N batches + timestep_index: 0 # Which history timestep to plot (0 = most recent) + save_dir: ./raw_input_plots # Directory to save plots (null to skip saving to disk) + log_to_wandb: true # Whether to log plots to WandB (if false and save_dir null, nothing saved) + + # Animation configuration (following PlottingCallback pattern) + make_video_plots: false # Set to true to create temporal animations + video_fps: 2 # Frames per second for animations + video_format: gif # mp4 or gif + video_save_dir: ./raw_input_animations # Directory to save animations (null to skip disk save) + max_animation_frames: null # Limit frames (null = unlimited; 30 ≈ 1 month daily data) + + # Plot specification (colourmap, hemisphere, etc.) + plot_spec: + _target_: ice_station_zebra.types.PlotSpec + variable: "raw_inputs" + colourmap: "viridis" + colourbar_location: "vertical" + # hemisphere will be auto-detected from dataset name + + # Per-variable styling overrides (optional) + # Uncomment and customize as needed: + variable_styles: + # Diverging (signed) + "era5:10u": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # 10m u-wind component + "era5:10v": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # 10m v-wind component + "era5:sin_julian_day": { cmap: "PuOr", two_slope_centre: 0.0 } # sin of Julian day + "era5:cos_julian_day": { cmap: "PuOr", two_slope_centre: 0.0 } # cos of Julian day + + # Sequential (physical) + "era5:2t": { cmap: "RdBu_r", two_slope_centre: 273.15, units: "K" } # 2m temperature + "era5:t_*": { cmap: "RdBu_r", two_slope_centre: 273.15, units: "K" } # temperature at various levels + "era5:msl": { cmap: "RdYlBu_r", units: "Pa" } # mean sea level pressure + "era5:sp": { cmap: "RdYlBu_r" } # surface pressure + + "era5:q_*": { cmap: "viridis", decimals: 4 } # specific humidity at various levels (small values need more decimals) + "era5:z_*": { cmap: "plasma", units: "m" } # geopotential at various levels + "era5:u_*": { cmap: "RdBu_r", two_slope_centre: 0.0 } # u-wind component at various levels + "era5:v_*": { cmap: "RdBu_r", two_slope_centre: 0.0 } # v-wind component at various levels + + # Sea ice concentration + "osisaf-south:ice_conc": { cmap: "Blues_r" } + + # Default fallback + "_default": { cmap: "viridis" } + + # Wildcard fallback for ERA5 channels + "era5:*": + origin: "upper" From 699bf189d7ef7924fd5faeda4fd9f6584ab6f71a Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:06:06 +0000 Subject: [PATCH 21/49] update init and default yaml for raw input config --- ice_station_zebra/callbacks/__init__.py | 2 ++ ice_station_zebra/config/evaluate/default.yaml | 1 + 2 files changed, 
3 insertions(+) diff --git a/ice_station_zebra/callbacks/__init__.py b/ice_station_zebra/callbacks/__init__.py index fe73afc5..6c784959 100644 --- a/ice_station_zebra/callbacks/__init__.py +++ b/ice_station_zebra/callbacks/__init__.py @@ -1,11 +1,13 @@ from .ema_weight_averaging_callback import EMAWeightAveragingCallback from .metric_summary_callback import MetricSummaryCallback from .plotting_callback import PlottingCallback +from .raw_inputs_callback import RawInputsCallback from .unconditional_checkpoint import UnconditionalCheckpoint __all__ = [ "EMAWeightAveragingCallback", "MetricSummaryCallback", "PlottingCallback", + "RawInputsCallback", "UnconditionalCheckpoint", ] diff --git a/ice_station_zebra/config/evaluate/default.yaml b/ice_station_zebra/config/evaluate/default.yaml index aa312e3a..999d28f7 100644 --- a/ice_station_zebra/config/evaluate/default.yaml +++ b/ice_station_zebra/config/evaluate/default.yaml @@ -2,4 +2,5 @@ defaults: - callbacks: - metric_summary - plotting + - raw_inputs - _self_ From 8629239de976d9457cc6b69059fd0cf7c1a2ae3c Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:14:58 +0000 Subject: [PATCH 22/49] Memory and garbage collection for animation --- .../visualisations/animation_helper.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 ice_station_zebra/visualisations/animation_helper.py diff --git a/ice_station_zebra/visualisations/animation_helper.py b/ice_station_zebra/visualisations/animation_helper.py new file mode 100644 index 00000000..2a84a423 --- /dev/null +++ b/ice_station_zebra/visualisations/animation_helper.py @@ -0,0 +1,40 @@ +"""Animation helper module for managing matplotlib animation lifecycle. + +Provides centralised cache management to prevent garbage collection of animation +objects during save operations, avoiding warnings and ensuring proper cleanup. +""" + +import contextlib + +from matplotlib import animation + +# Centralized cache to hold strong references to animation objects +_ANIM_CACHE: list[animation.FuncAnimation] = [] + + +def hold_anim(anim: animation.FuncAnimation) -> None: + """Add an animation object to the cache to prevent garbage collection. + + Call this immediately after creating a FuncAnimation object to ensure + it remains in memory during save operations. + + Args: + anim: The FuncAnimation object to cache. + + """ + _ANIM_CACHE.append(anim) + + +def release_anim(anim: animation.FuncAnimation) -> None: + """Remove an animation object from the cache after saving is complete. + + Call this in a finally block after saving the animation to allow + proper cleanup and garbage collection. + + Args: + anim: The FuncAnimation object to remove from cache. 
+ + """ + with contextlib.suppress(ValueError): + # Animation already removed or never added - ignore + _ANIM_CACHE.remove(anim) From 170ecb9304cb26c177f57b7aa2b0ef3029f2b0cb Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:20:11 +0000 Subject: [PATCH 23/49] Improve saving --- ice_station_zebra/visualisations/convert.py | 107 +++++++++++++------- 1 file changed, 68 insertions(+), 39 deletions(-) diff --git a/ice_station_zebra/visualisations/convert.py b/ice_station_zebra/visualisations/convert.py index 9b093276..a0b69504 100644 --- a/ice_station_zebra/visualisations/convert.py +++ b/ice_station_zebra/visualisations/convert.py @@ -1,5 +1,9 @@ +import contextlib +import gc import io +import logging import tempfile +from collections.abc import Iterator from pathlib import Path from typing import Literal @@ -15,65 +19,90 @@ DEFAULT_DPI = 200 -def _image_from_figure(fig: Figure) -> ImageFile: - """Convert a matplotlib figure to a PIL image file.""" +@contextlib.contextmanager +def _suppress_mpl_animation_logs() -> Iterator[None]: + """Temporarily suppress matplotlib animation INFO log messages.""" + mpl_logger = logging.getLogger("matplotlib.animation") + original_level = mpl_logger.level + try: + mpl_logger.setLevel(logging.WARNING) + yield + finally: + mpl_logger.setLevel(original_level) + + +def image_from_figure(fig: Figure) -> ImageFile: + """Convert a matplotlib figure to a PIL image file. + + Uses the same save parameters as save_figure for consistency: + - dpi=300 for matching resolution + - bbox_inches="tight" to crop to content (matching disk saves) + """ buf = io.BytesIO() - fig.savefig(buf, format="png") + fig.savefig(buf, format="png", dpi=300, bbox_inches="tight") buf.seek(0) return Image.open(buf) -def _image_from_array(a: np.ndarray) -> ImageFile: +def image_from_array(a: np.ndarray) -> ImageFile: """Convert a numpy array to a PIL image file.""" fig, ax = plt.subplots(figsize=(6, 6)) ax.imshow(a) ax.axis("off") try: - return _image_from_figure(fig) + return image_from_figure(fig) finally: plt.close(fig) -def _save_animation( +def save_animation( anim: animation.FuncAnimation, *, - fps: int | None = None, + fps: int = 2, video_format: Literal["mp4", "gif"] = "gif", - _fps: int | None = None, - _video_format: Literal["mp4", "gif"] | None = None, ) -> io.BytesIO: """Save an animation to a temporary file and return BytesIO (with cleanup).""" - # Accept both standard and underscored names for test compatibility - fps_value: int = int(fps if fps is not None else (_fps if _fps is not None else 2)) - if _video_format is not None: - video_format = _video_format # prefer underscored override if provided + fps_value: int = int(fps) suffix = ".gif" if video_format.lower() == "gif" else ".mp4" - with tempfile.NamedTemporaryFile(suffix=suffix, delete=True) as tmp: - try: - # Save video to tempfile - writer: animation.AbstractMovieWriter = ( - animation.PillowWriter(fps=fps_value) - if suffix == ".gif" - else animation.FFMpegWriter( - fps=fps_value, - codec="libx264", - bitrate=1800, - # Ensure dimensions are compatible with yuv420p (even width/height) - # by applying a scale filter that truncates to the nearest even integers. 
- extra_args=[ - "-pix_fmt", - "yuv420p", - "-vf", - "scale=trunc(iw/2)*2:trunc(ih/2)*2", - ], + + writer: animation.AbstractMovieWriter | None = None + try: + with tempfile.NamedTemporaryFile(suffix=suffix, delete=True) as tmp: + try: + # Save video to tempfile + writer = ( + animation.PillowWriter(fps=fps_value) + if suffix == ".gif" + else animation.FFMpegWriter( + fps=fps_value, + codec="libx264", + bitrate=1800, + # Ensure dimensions are compatible with yuv420p (even width/height) + # by applying a scale filter that truncates to the nearest even integers. + extra_args=[ + "-pix_fmt", + "yuv420p", + "-vf", + "scale=trunc(iw/2)*2:trunc(ih/2)*2", + ], + ) ) - ) - anim.save(tmp.name, writer=writer, dpi=DEFAULT_DPI) - # Load tempfile into a BytesIO buffer - with Path(tmp.name).open("rb") as fh: - buffer = io.BytesIO(fh.read()) - except (OSError, MemoryError) as err: - msg = f"Video encoding failed: {err!s}" - raise VideoRenderError(msg) from err + # Suppress matplotlib's INFO log message about writer selection + with _suppress_mpl_animation_logs(): + anim.save(tmp.name, writer=writer, dpi=DEFAULT_DPI) + # Load tempfile into a BytesIO buffer + with Path(tmp.name).open("rb") as fh: + buffer = io.BytesIO(fh.read()) + except (OSError, MemoryError) as err: + msg = f"Video encoding failed: {err!s}" + raise VideoRenderError(msg) from err + finally: + # Explicitly cleanup writer to prevent semaphore leaks + if writer is not None and hasattr(writer, "cleanup"): + with contextlib.suppress(OSError, RuntimeError, AttributeError): + writer.cleanup() + # Force garbage collection to clean up any remaining resources + gc.collect() + buffer.seek(0) return buffer From 196587262a7501e4fc917e05c7efaeddbad8a1f3 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:22:27 +0000 Subject: [PATCH 24/49] Add colour map helper --- ice_station_zebra/types/simple_datatypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ice_station_zebra/types/simple_datatypes.py b/ice_station_zebra/types/simple_datatypes.py index 4eba0fc4..5bdd77c7 100644 --- a/ice_station_zebra/types/simple_datatypes.py +++ b/ice_station_zebra/types/simple_datatypes.py @@ -49,7 +49,7 @@ class DiffColourmapSpec(NamedTuple): norm: Normalisation for mapping values to colours (e.g. TwoSlopeNorm for signed diffs). vmin: Lower bound if no norm is provided. vmax: Upper bound if no norm is provided. - cmap: Matplotlib colormap name. + cmap: Matplotlib colourmap name. """ @@ -69,7 +69,7 @@ class PlotSpec: title_prediction: Title above the prediction panel. title_difference: Title above the difference panel. n_contour_levels: Number of contour levels per panel. - colourmap: Colormap used for GT/prediction panels. + colourmap: colourmap used for GT/prediction panels. include_difference: Whether to draw a difference panel. diff_mode: Difference definition (e.g. "signed", "absolute", "smape"). diff_strategy: Strategy for animations (precompute, two-pass, per-frame). 
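For reference, a minimal usage sketch of the helpers introduced in the two patches above (hold_anim/release_anim from animation_helper.py and the renamed save_animation from convert.py). This is illustrative only and not part of any patch; the toy FuncAnimation is an assumed stand-in for whatever animation the callbacks build. The point is the lifecycle: hold a strong reference before saving, release it in a finally block so the object can be garbage-collected afterwards.

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from ice_station_zebra.visualisations.animation_helper import hold_anim, release_anim
from ice_station_zebra.visualisations.convert import save_animation

fig, ax = plt.subplots()
(line,) = ax.plot([0, 1], [0, 1])

def _update(frame: int) -> tuple:
    # Trivial per-frame update for the illustrative animation
    line.set_ydata([0, frame / 10])
    return (line,)

anim = FuncAnimation(fig, _update, frames=10)
hold_anim(anim)  # keep a strong reference so the animation survives the save
try:
    # Returns an io.BytesIO buffer ready to be logged or written to disk
    buffer = save_animation(anim, fps=2, video_format="gif")
finally:
    release_anim(anim)  # drop the cached reference so it can be collected
    plt.close(fig)
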
From 57e888e3b231f32682585c721936e7bcf5022ff7 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:25:24 +0000 Subject: [PATCH 25/49] Generalise further, plot styles, allow easier local save --- .../visualisations/plotting_core.py | 394 ++++++++++++++++-- 1 file changed, 351 insertions(+), 43 deletions(-) diff --git a/ice_station_zebra/visualisations/plotting_core.py b/ice_station_zebra/visualisations/plotting_core.py index 75a9054f..df522959 100644 --- a/ice_station_zebra/visualisations/plotting_core.py +++ b/ice_station_zebra/visualisations/plotting_core.py @@ -1,15 +1,199 @@ +import logging +from collections.abc import Mapping +from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal +import matplotlib as mpl import numpy as np -from matplotlib.colors import TwoSlopeNorm +from matplotlib.colors import Normalize, TwoSlopeNorm from ice_station_zebra.exceptions import InvalidArrayError from ice_station_zebra.types import DiffColourmapSpec, DiffMode, DiffStrategy, PlotSpec +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + # Constants for land mask validation EXPECTED_LAND_MASK_DIMENSIONS = 2 +# --- Variable Styling --- + + +@dataclass +class VariableStyle: + """Styling configuration for individual variables. + + Attributes: + cmap: Matplotlib colourmap name (e.g., "viridis", "RdBu_r"). + colourbar_strategy: "shared" or "separate" (kept for compatibility). + vmin: Minimum value for colour scale. + vmax: Maximum value for colour scale. + two_slope_centre: Centre value for diverging colourmap (TwoSlopeNorm). + units: Display units for the variable (e.g., "K", "m/s"). + origin: Imshow origin override ("upper" keeps north-up, "lower" keeps south-up). + decimals: Number of decimal places for colourbar tick labels (default: 2). + + """ + + cmap: str | None = None + colourbar_strategy: str | None = None + vmin: float | None = None + vmax: float | None = None + two_slope_centre: float | None = None + units: str | None = None + origin: Literal["upper", "lower"] | None = None + decimals: int | None = None + + +def colourmap_with_bad( + cmap_name: str | None, bad_color: str = "#dcdcdc" +) -> mpl.colors.Colormap: + """Create a colourmap copy with a specified color for bad (NaN) values. + + This function copies the specified colourmap and sets the 'bad' color to handle + NaN values consistently, preventing white artifacts in visualisations. + + Args: + cmap_name: Name of the matplotlib colourmap (e.g., "viridis", "RdBu_r"). + If None, defaults to "viridis". + bad_color: Color to use for NaN/bad values. Default is light grey (#dcdcdc). + + Returns: + A copy of the colourmap with set_bad() configured. + + """ + if cmap_name is None: + cmap = mpl.colormaps.get_cmap("viridis") + else: + cmap = mpl.colormaps.get_cmap(cmap_name) + + try: + cmap = cmap.copy() + except (AttributeError, TypeError): + # Some matplotlib versions return non-copyable colourmap; create new + cmap = mpl.colormaps.get_cmap(cmap.name) + + cmap.set_bad(bad_color) + return cmap + + +def safe_nanmin(arr: np.ndarray, default: float = 0.0) -> float: + """Safely compute nanmin with fallback for empty or all-NaN arrays. + + Args: + arr: Array to compute minimum from. + default: Default value if array is empty or all NaN. + + Returns: + Minimum value or default. 
+ + """ + if np.isfinite(arr).any(): + result = np.nanmin(arr) + return float(result) if np.isfinite(result) else default + return default + + +def safe_nanmax(arr: np.ndarray, default: float = 1.0) -> float: + """Safely compute nanmax with fallback for empty or all-NaN arrays. + + Args: + arr: Array to compute maximum from. + default: Default value if array is empty or all NaN. + + Returns: + Maximum value or default. + + """ + if np.isfinite(arr).any(): + result = np.nanmax(arr) + return float(result) if np.isfinite(result) else default + return default + + +def style_for_variable( # noqa: C901, PLR0911 + var_name: str, styles: dict[str, dict[str, Any]] | None +) -> VariableStyle: + """Return best matching style for a variable from config styles dict. + + Matching priority: + 1) exact key + 2) wildcard prefix key ending with '*' + 3) _default + 4) empty style + Accepts any Mapping (so OmegaConf DictConfig works). + """ + + def _normalise_name(name: str) -> str: + # Convert double-underscore to colon (this maps 'era5__2t' -> 'era5:2t') + name = name.replace("__", ":") + # Treat hyphens as separators too: 'era5-2t' -> 'era5:2t' + name = name.replace("-", ":") + # Collapse accidental repeated '::' to single ':' + while "::" in name: + name = name.replace("::", ":") + # Keep single underscores (they are meaningful in some variable names) + return name + + if not styles: + return VariableStyle() + + # Accept Mapping-like configs (Dict, DictConfig, etc.) + if not isinstance(styles, Mapping): + logger.info("style_for_variable: styles is not a Mapping; ignoring styles") + return VariableStyle() + + # Quick exact match first (try raw var_name) + spec = styles.get(var_name) + if isinstance(spec, Mapping): + return VariableStyle(**{k: spec.get(k) for k in VariableStyle.__annotations__}) + + # Try normalised exact match + norm_var = _normalise_name(var_name) + if norm_var != var_name: + spec = styles.get(norm_var) + if isinstance(spec, Mapping): + return VariableStyle( + **{k: spec.get(k) for k in VariableStyle.__annotations__} + ) + + # Wildcard prefix match: scan keys ending with '*' (normalise the key before comparing) + # We iterate keys so keep original order (OmegaConf preserves insertion order). + for key in styles: + if isinstance(key, str) and key.endswith("*"): + prefix = key[:-1] + prefix_norm = _normalise_name(prefix) + # If prefix_norm is empty (user wrote '*' only) skip it + if not prefix_norm: + continue + # Compare against both raw and normalised var names + if var_name.startswith(prefix) or norm_var.startswith(prefix_norm): + spec = styles.get(key) + if isinstance(spec, Mapping): + return VariableStyle( + **{k: spec.get(k) for k in VariableStyle.__annotations__} + ) + logger.info( + "style_for_variable: wildcard candidate %r not a dict (type=%s)", + key, + type(spec), + ) + + # Fallback to _default + spec = styles.get("_default") + if isinstance(spec, Mapping): + return VariableStyle(**{k: spec.get(k) for k in VariableStyle.__annotations__}) + + return VariableStyle() + + +# --- Colour Scale Generation --- + + def levels_from_spec(spec: PlotSpec) -> np.ndarray: """Generate contour levels from a plotting specification. 
@@ -30,6 +214,63 @@ def levels_from_spec(spec: PlotSpec) -> np.ndarray: return np.linspace(vmin, vmax, spec.n_contour_levels) +def create_normalisation( + data: np.ndarray, + *, + vmin: float | None = None, + vmax: float | None = None, + centre: float | None = None, +) -> tuple[Normalize | TwoSlopeNorm, float, float]: + """Create appropriate normalisation for data with optional centring. + + This function creates either a linear Normalize or a diverging TwoSlopeNorm + based on whether a centre value is provided. When a centre is specified, + the normalisation will be symmetric around that value. + + Args: + data: 2D array of data to normalise. + vmin: Minimum value for colour scale. If None, inferred from data. + vmax: Maximum value for colour scale. If None, inferred from data. + centre: Centre value for diverging colourmap. If provided, creates + a symmetric TwoSlopeNorm around this value. + + Returns: + Tuple of (normalisation, vmin, vmax) where: + - normalisation: Normalize or TwoSlopeNorm object + - vmin: Computed minimum value + - vmax: Computed maximum value + + """ + # Compute data range with robust handling of NaN/inf + data_min = float(np.nanmin(data)) if np.isfinite(data).any() else 0.0 + data_max = float(np.nanmax(data)) if np.isfinite(data).any() else 1.0 + + if centre is not None: + # Diverging colourmap centred at specified value + low = vmin if vmin is not None else data_min + high = vmax if vmax is not None else data_max + + # Make symmetric around the centre where possible + span_low = abs(centre - low) + span_high = abs(high - centre) + span = max(span_low, span_high, 1e-6) + + final_vmin = float(centre - span) + final_vmax = float(centre + span) + + norm: Normalize | TwoSlopeNorm = TwoSlopeNorm( + vmin=final_vmin, vcenter=float(centre), vmax=final_vmax + ) + return norm, final_vmin, final_vmax + + # Linear colourmap + final_vmin = float(vmin if vmin is not None else data_min) + final_vmax = float(vmax if vmax is not None else data_max) + + norm_linear: Normalize | TwoSlopeNorm = Normalize(vmin=final_vmin, vmax=final_vmax) + return norm_linear, final_vmin, final_vmax + + def compute_difference( ground_truth: np.ndarray, prediction: np.ndarray, diff_mode: DiffMode ) -> np.ndarray: @@ -87,13 +328,9 @@ def make_diff_colourmap( max_abs = max(1.0, float(abs(sample))) vmin, vmax = -max_abs, max_abs else: - # Find the min and max values of the sample array - vmin_data = float( - np.nanmin(sample) if np.nanmin(sample) is not None else -1.0 - ) - vmax_data = float( - np.nanmax(sample) if np.nanmax(sample) is not None else 1.0 - ) + # Find the min and max values of the sample array using safe helpers + vmin_data = safe_nanmin(sample, default=-1.0) + vmax_data = safe_nanmax(sample, default=1.0) # Find the maximum absolute value of the sample array max_abs = max(1.0, abs(vmin_data), abs(vmax_data)) vmin, vmax = -max_abs, max_abs @@ -110,7 +347,7 @@ def make_diff_colourmap( if isinstance(sample, (float, int)): vmax = max(1e-6, float(sample)) else: - vmax = max(1e-6, float(np.nanmax(sample) or 0.0)) + vmax = max(1e-6, safe_nanmax(sample, default=0.0)) return DiffColourmapSpec( norm=None, @@ -330,7 +567,7 @@ def compute_display_ranges_stream( return (groundtruth_min, groundtruth_max), (prediction_min, prediction_max) -def detect_land_mask_path( +def detect_land_mask_path( # noqa: C901 base_path: str | Path, dataset_name: str | None = None, hemisphere: str | None = None, @@ -338,12 +575,14 @@ def detect_land_mask_path( """Automatically detect the land mask path based on dataset 
configuration. This function looks for land mask files in the expected locations based on - the dataset name and hemisphere. It follows the pattern: + the dataset name and hemisphere. It accepts either the repository root + (containing a ``data`` directory) or the ``data`` directory itself as + ``base_path``. It follows the pattern: - {base_path}/data/preprocessing/{dataset_name}/IceNetSIC/data/masks/{hemisphere}/masks/land_mask.npy - {base_path}/data/preprocessing/IceNetSIC/data/masks/{hemisphere}/masks/land_mask.npy Args: - base_path: Base path to the data directory. + base_path: Base path to the project root **or** the ``data`` directory. dataset_name: Name of the dataset (e.g., 'samp-sicsouth-osisaf-25k-2017-2019-24h-v1'). hemisphere: Hemisphere ('north' or 'south'). @@ -363,37 +602,54 @@ def detect_land_mask_path( if hemisphere is None: return None - # Try dataset-specific path first - if dataset_name is not None: - dataset_specific_path = ( - base_path - / "data" - / "preprocessing" - / dataset_name - / "IceNetSIC" - / "data" - / "masks" - / hemisphere - / "masks" - / "land_mask.npy" - ) - if dataset_specific_path.exists(): - return str(dataset_specific_path) - - # Try general IceNetSIC path - general_path = ( - base_path - / "data" - / "preprocessing" - / "IceNetSIC" - / "data" - / "masks" - / hemisphere - / "masks" - / "land_mask.npy" - ) - if general_path.exists(): - return str(general_path) + # Support repo-root, data-directory, and nested data/data layouts + candidate_data_roots: list[Path] = [] + if base_path.name != "data": + candidate_data_roots.append(base_path / "data") + candidate_data_roots.append(base_path) + + seen_roots: set[Path] = set() + for data_root in candidate_data_roots: + resolved_root = data_root.resolve() + if resolved_root in seen_roots: + continue + seen_roots.add(resolved_root) + + # Some deployments keep datasets in data/data/, + # so we probe both the root and an extra nested data/ layer. + preprocessing_roots = { + resolved_root / "preprocessing", + resolved_root / "data" / "preprocessing", + } + + for preproc_root in preprocessing_roots: + # Try dataset-specific path first + if dataset_name is not None: + dataset_specific_path = ( + preproc_root + / dataset_name + / "IceNetSIC" + / "data" + / "masks" + / hemisphere + / "masks" + / "land_mask.npy" + ) + if dataset_specific_path.exists(): + return str(dataset_specific_path) + + # Try general IceNetSIC path + general_path = ( + preproc_root + / "IceNetSIC" + / "data" + / "masks" + / hemisphere + / "masks" + / "land_mask.npy" + ) + if general_path.exists(): + return str(general_path) return None @@ -437,3 +693,55 @@ def load_land_mask( land_mask = land_mask.astype(bool) return land_mask + + +# --- File Utilities --- + + +def safe_filename(name: str) -> str: + """Sanitise a string for use as a filename. + + Replaces non-alphanumeric characters (except hyphens and underscores) + with hyphens and strips leading/trailing hyphens. + + Args: + name: Input string to sanitise. + + Returns: + Sanitised filename string. + + Examples: + >>> safe_filename("era5:2t") + 'era5-2t' + >>> safe_filename("My Variable Name!") + 'My-Variable-Name' + + """ + keep = [c if c.isalnum() or c in ("-", "_") else "-" for c in name.strip()] + return "".join(keep).strip("-") or "var" + + +def save_figure( + fig: "plt.Figure", save_dir: Path | None, base_name: str +) -> Path | None: + """Save a matplotlib figure to disk as PNG. + + Creates the save directory if it doesn't exist. Sanitises the filename + to avoid filesystem issues. 
+ + Args: + fig: Matplotlib Figure object to save. + save_dir: Directory to save the figure in. If None, does not save. + base_name: Base name for the file (will be sanitised). + + Returns: + Path to the saved file, or None if save_dir was None. + + """ + if save_dir is None: + return None + + save_dir.mkdir(parents=True, exist_ok=True) + path = save_dir / f"{safe_filename(base_name)}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + return path From 52f1b04981f62b7d875884e8499f4f5999cccf09 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:27:46 +0000 Subject: [PATCH 26/49] Refactor layout for single panel plots --- ice_station_zebra/visualisations/layout.py | 582 ++++++++++++++++++--- 1 file changed, 497 insertions(+), 85 deletions(-) diff --git a/ice_station_zebra/visualisations/layout.py b/ice_station_zebra/visualisations/layout.py index cde46fe0..613bca1e 100644 --- a/ice_station_zebra/visualisations/layout.py +++ b/ice_station_zebra/visualisations/layout.py @@ -15,7 +15,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import contextlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal import matplotlib.pyplot as plt import numpy as np @@ -33,59 +35,317 @@ from .plotting_core import DiffColourmapSpec, PlotSpec -# Layout constants +import logging + +logger = logging.getLogger(__name__) + +# Panel index constants (used for layout logic) PREDICTION_PANEL_INDEX = 1 DIFFERENCE_PANEL_INDEX = 2 MIN_PANEL_COUNT_FOR_PREDICTION = 2 MIN_PANEL_COUNT_FOR_DIFFERENCE = 3 - -# -# --- LAYOUT CONFIGURATION CONSTANTS --- -# These constants define the default proportional spacing for the GridSpec layout. -# All values are expressed as fractions of the total figure dimensions to ensure -# consistent scaling across different figure sizes. - -# Spacing and margin constants (as fractions of figure dimensions) -DEFAULT_OUTER_MARGIN = 0.05 # Outer margin around entire figure (prevents clipping) -DEFAULT_GUTTER_HORIZONTAL = 0.03 # Smaller gaps when colourbar is below -DEFAULT_GUTTER_VERTICAL = 0.03 # Default for side-by-side with vertical bars -DEFAULT_CBAR_WIDTH = ( - 0.06 # Width allocated for colourbar slots (fraction of panel width) -) -DEFAULT_TITLE_SPACE = 0.07 # Reduced: simple title + warning badge + axes titles -DEFAULT_FOOTER_SPACE = 0.08 # Increased space reserved at bottom for metadata footer - -# Horizontal colourbar sizing (fractions of figure height) -DEFAULT_CBAR_HEIGHT = ( - 0.07 # Thickness of horizontal colourbar row (relative to plot height) -) -DEFAULT_CBAR_PAD = 0.03 # Vertical gap between plots and the colourbar row - -# Horizontal colourbar width within its slot (inset fraction within parent slot) +# Horizontal colourbar inset constants HCBAR_WIDTH_FRAC = 0.75 # centred 75% width HCBAR_LEFT_FRAC = (1.0 - HCBAR_WIDTH_FRAC) / 2.0 -# Default figure sizes for different panel configurations -DEFAULT_FIGSIZE_TWO_PANELS = (12, 6) # Ground truth + Prediction -DEFAULT_FIGSIZE_THREE_PANELS = (18, 6) # Ground truth + Prediction + Difference +# Small epsilon for numerical stability +EPSILON_SMALL = 1e-6 + +# --- Layout Configuration Dataclasses --- + + +@dataclass(frozen=True) +class ColourbarConfig: + """Colourbar sizing and clamp behaviour (fractions of figure). 
+ + Example: + # Wider colourbar with higher max fraction + cbar_config = ColourbarConfig(default_width_frac=0.08, max_fraction_of_fig=0.15) + + """ + + default_width_frac: float = ( + 0.06 # Width allocated for colourbar slots (fraction of panel width) + ) + min_width_frac: float = 0.01 # Minimum colourbar width clamp + max_width_frac: float = 0.5 # Maximum colourbar width clamp + desired_physical_width_in: float = ( + 0.45 # Desired physical colourbar width in inches + ) + max_fraction_of_fig: float = 0.12 # Maximum colourbar width as fraction of figure + default_height_frac: float = 0.07 # Thickness of horizontal colourbar row + default_pad_frac: float = 0.03 # Vertical gap between plots and colourbar row + + def clamp_frac(self, frac: float, fig_width_in: float) -> float: + """Return a sane clamped fraction for the cbar slot based on figure width. + + Computes a physical-limit fraction derived from desired physical width, + then clamps the input fraction between min/max bounds. + """ + # Compute a physical-limit fraction derived from desired physical width + phys_frac = min( + self.max_fraction_of_fig, + self.desired_physical_width_in / max(fig_width_in, EPSILON_SMALL), + ) + return float( + max(self.min_width_frac, min(frac, phys_frac, self.max_width_frac)) + ) + + +@dataclass(frozen=True) +class TitleFooterConfig: + """Title and footer spacing, positioning, and styling.""" + title_space: float = 0.07 # Fraction of figure height reserved for title + footer_space: float = 0.08 # Fraction of figure height reserved for footer + title_fontsize: int = 12 # Font size for title + footer_fontsize: int = 11 # Font size for footer and badge + title_y: float = 0.98 # Y position for title (near top, in figure coordinates) + footer_y: float = 0.03 # Y position for footer (near bottom, in figure coordinates) + bbox_pad_title: float = 2.0 # Padding for title bbox + bbox_pad_badge: float = 1.5 # Padding for badge bbox + zorder_high: int = 1000 # High z-order for text overlays + + +@dataclass(frozen=True) +class GapConfig: + """Aspect-aware physical gap settings for single-panel layouts.""" + + base: float = 0.15 # Preferred gap for aspect≈1 panels + min_val: float = 0.10 # Hard minimum + max_val: float = 0.22 # Hard maximum + wide_limit: float = 4.0 # Aspect ratio at which we reach the "wide" gap + tall_limit: float = 0.5 # Aspect ratio at which we reach the "tall" gap + wide_gap: float = 0.11 # Gap to use for very wide panels + tall_gap: float = 0.19 # Gap to use for very tall panels + + +@dataclass(frozen=True) +class SinglePanelSpacing: + """Encapsulate spacing controls for standalone panels.""" + + gap: GapConfig = field(default_factory=GapConfig) + outer_buffer_in: float = 0.16 # Padding outside map/cbar (both sides) + edge_guard_in: float = 0.25 # Minimum blank space beyond colourbar for ticks + right_margin_scale: float = 0.5 # Scale factor for right margin adjustment + right_margin_offset: float = 0.02 # Offset for right margin adjustment + + +@dataclass(frozen=True) +class FormattingConfig: + """Tick formatting and fallback value constants.""" + + num_ticks_linear: int = 5 # Number of ticks for linear colourbars + midpoint_factor: float = ( + 0.5 # Factor for calculating midpoint values in symmetric ticks + ) + default_vmin_fallback: float = 0.0 # Fallback minimum value + default_vmax_fallback: float = 1.0 # Fallback maximum value + default_vmin_diff_fallback: float = -1.0 # Fallback minimum for difference plots + default_vmax_diff_fallback: float = 1.0 # Fallback maximum for difference 
plots + + +@dataclass(frozen=True) +class LayoutConfig: + """Top-level layout tuning surface - passed into build functions. + + Groups all layout-related configuration into a single, discoverable object. + This makes it easy to override defaults for testing or custom layouts. + + Example: + # Wider colourbar and less title space + my_layout = LayoutConfig( + colourbar=ColourbarConfig(default_width_frac=0.08), + title_footer=TitleFooterConfig(title_space=0.04) + ) + fig, axs, cax = build_single_panel_figure( + height=200, width=300, layout_config=my_layout, colourbar_location="vertical" + ) + + """ + + base_height_in: float = 6.0 # Standard figure height in inches + outer_margin: float = 0.05 # Outer margin around entire figure (prevents clipping) + gutter_vertical: float = 0.03 # Default for side-by-side with vertical bars + gutter_horizontal: float = 0.03 # Smaller gaps when colourbar is below + colourbar: ColourbarConfig = field(default_factory=ColourbarConfig) + title_footer: TitleFooterConfig = field(default_factory=TitleFooterConfig) + single_panel_spacing: SinglePanelSpacing = field(default_factory=SinglePanelSpacing) + formatting: FormattingConfig = field(default_factory=FormattingConfig) + min_plot_fraction: float = ( + 0.2 # Minimum plot width/height as fraction of available space + ) + min_usable_height_fraction: float = ( + 0.6 # Minimum fraction of figure height for plotting area + ) + default_figsizes: dict[int, tuple[float, float]] = field( + default_factory=lambda: { + 1: (8, 6), # Single panel + 2: (12, 6), # Ground truth + Prediction + 3: (18, 6), # Ground truth + Prediction + Difference + } + ) + + +# Default layout configuration instance (used when layout_config is None) +_DEFAULT_LAYOUT_CONFIG = LayoutConfig() # --- Main Layout Functions --- -def _build_layout( # noqa: PLR0913 +def build_single_panel_figure( # noqa: PLR0913, PLR0915 + *, + height: int, + width: int, + layout_config: LayoutConfig | None = None, + colourbar_location: Literal["vertical", "horizontal"] = "vertical", + cbar_width: float | None = None, + cbar_height: float | None = None, + cbar_pad: float | None = None, +) -> tuple[Figure, Axes, Axes]: + """Create a single-panel figure with layout consistent with multi-panel maps. + + Args: + height: Data height in pixels. + width: Data width in pixels. + layout_config: Optional LayoutConfig instance to override defaults. + If None, uses default configuration. + colourbar_location: Orientation of colourbar ("vertical" or "horizontal"). + cbar_width: Optional override for colourbar width fraction. + cbar_height: Optional override for colourbar height fraction (horizontal only). + cbar_pad: Optional override for colourbar padding. + + """ + if height <= 0 or width <= 0: + msg = "height and width must be positive for single panel layout." 
+ raise ValueError(msg) + + layout = layout_config or _DEFAULT_LAYOUT_CONFIG + outer_margin = layout.outer_margin + title_space = layout.title_footer.title_space + footer_space = layout.title_footer.footer_space + + base_h = layout.base_height_in + aspect = width / max(1, height) + fig_w = base_h * aspect + fig = plt.figure( + figsize=(fig_w, base_h), constrained_layout=False, facecolor="none" + ) + + top_val = max(layout.min_usable_height_fraction, 1.0 - (outer_margin + title_space)) + bottom_val = outer_margin + footer_space + usable_height = top_val - bottom_val + + if colourbar_location == "vertical": + # --- Interpret cbar_width as "fraction of panel width" --- + cw_panel = float( + layout.colourbar.default_width_frac if cbar_width is None else cbar_width + ) + fig_w_in = float(fig.get_size_inches()[0]) + cw_panel = layout.colourbar.clamp_frac(cw_panel, fig_w_in) + + spacing = layout.single_panel_spacing + + extra_left_frac = spacing.outer_buffer_in / max(fig_w_in, EPSILON_SMALL) + extra_right_frac = spacing.outer_buffer_in / max(fig_w_in, EPSILON_SMALL) + + if cbar_pad is None: + cbar_pad_inches = _default_vertical_gap_inches(aspect, spacing.gap) + logger.debug( + "single-panel vertical gap auto: w=%d h=%d aspect=%.3f gap=%.4fin", + width, + height, + aspect, + cbar_pad_inches, + ) + else: + base_pad_inches = float(cbar_pad) * fig_w_in + cbar_pad_inches = np.clip( + base_pad_inches, spacing.gap.min_val, spacing.gap.max_val + ) + logger.debug( + "single-panel vertical gap override: w=%d h=%d aspect=%.3f raw=%.4fin clipped=%.4fin", + width, + height, + aspect, + base_pad_inches, + cbar_pad_inches, + ) + cbar_pad = cbar_pad_inches / max(fig_w_in, EPSILON_SMALL) + + # --- Right margin (slightly smaller than left) with physical safeguard --- + right_margin = max( + outer_margin * spacing.right_margin_scale, + outer_margin - spacing.right_margin_offset, + ) + right_margin_inches = right_margin * fig_w_in + if right_margin_inches < spacing.edge_guard_in: + right_margin = spacing.edge_guard_in / max(fig_w_in, EPSILON_SMALL) + + # === 2) Compute plot width from panel-fraction rule === + available = 1.0 - ( + outer_margin + extra_left_frac + cbar_pad + right_margin + extra_right_frac + ) + plot_width_candidate = max( + layout.min_plot_fraction, available / (1.0 + cw_panel) + ) + aspect_width_target = usable_height + plot_width = min(plot_width_candidate, aspect_width_target) + + # initial colourbar width (fraction of figure) + cax_width = cw_panel * plot_width + + # === 3) CAP COLOURBAR PHYSICAL WIDTH === + # Compute maximum allowed fraction based on desired physical width + phys_frac = layout.colourbar.desired_physical_width_in / max( + fig_w_in, EPSILON_SMALL + ) + max_cax_frac = min(layout.colourbar.max_fraction_of_fig, phys_frac) + + if cax_width > max_cax_frac: + cax_width = max_cax_frac + plot_width = min( + aspect_width_target, + max( + layout.min_plot_fraction, + 1.0 - (outer_margin + cbar_pad + cax_width + right_margin), + ), + ) + + # === 4) Place axes === + plot_left = outer_margin + extra_left_frac + ax = fig.add_axes((plot_left, bottom_val, plot_width, usable_height)) + + cax_left = plot_left + plot_width + cbar_pad + cax = fig.add_axes((cax_left, bottom_val, cax_width, usable_height)) + + else: + cbar_height = ( + layout.colourbar.default_height_frac if cbar_height is None else cbar_height + ) + cbar_pad = layout.colourbar.default_pad_frac if cbar_pad is None else cbar_pad + + plot_left = outer_margin + plot_width = 1.0 - 2 * outer_margin + plot_height = usable_height - 
(cbar_height + cbar_pad) + plot_height = max(plot_height, layout.min_plot_fraction) + + ax_bottom = bottom_val + cbar_height + cbar_pad + ax = fig.add_axes((plot_left, ax_bottom, plot_width, plot_height)) + cax = fig.add_axes((plot_left, bottom_val, plot_width, cbar_height)) + + _style_axes([ax]) + _set_axes_limits([ax], width=width, height=height) + return fig, ax, cax + + +def build_layout( *, plot_spec: PlotSpec, height: int | None = None, width: int | None = None, - outer_margin: float = DEFAULT_OUTER_MARGIN, - gutter: float | None = None, - cbar_width: float = DEFAULT_CBAR_WIDTH, - title_space: float = DEFAULT_TITLE_SPACE, - footer_space: float = DEFAULT_FOOTER_SPACE, - cbar_height: float = DEFAULT_CBAR_HEIGHT, - cbar_pad: float = DEFAULT_CBAR_PAD, + layout_config: LayoutConfig | None = None, ) -> tuple[Figure, list[Axes], dict[str, Axes | None]]: """Create a GridSpec layout for multi-panel plots. @@ -106,19 +366,8 @@ def _build_layout( # noqa: PLR0913 height: Optional data height for aspect-ratio-aware figure sizing. If provided with width, the figure dimensions will be calculated to maintain proper data aspect ratios. width: Optional data width for aspect-ratio-aware figure sizing. - outer_margin: Fraction of figure dimensions reserved for outer margins (prevents - plot elements from being clipped at figure edges). - gutter: Fraction of panel width used as horizontal spacing between panel groups. - Only applied between prediction+colourbar and difference panel. - cbar_width: Fraction of panel width allocated for each colourbar slot. - title_space: Fraction of figure height reserved at the top for figure title - (prevents title from overlapping with plot content). - footer_space: Fraction of figure height reserved at the bottom for the metadata - footer so it does not overlap colourbars. - cbar_height: Height fraction for the horizontal colourbar row (row 2 when - orientation is 'horizontal'). Controls the bar thickness. - cbar_pad: Vertical gap fraction between the plot row and the colourbar row in - horizontal layouts. Set to 0.0 for flush rows. + layout_config: Optional LayoutConfig instance to override default layout parameters. + If None, uses default configuration. Returns: tuple containing: @@ -132,28 +381,35 @@ def _build_layout( # noqa: PLR0913 InvalidArrayError: If the arrays are not 2D or have different shapes. """ + layout = layout_config or _DEFAULT_LAYOUT_CONFIG + outer_margin = layout.outer_margin + title_space = layout.title_footer.title_space + footer_space = layout.title_footer.footer_space + cbar_width = layout.colourbar.default_width_frac + cbar_height = layout.colourbar.default_height_frac + cbar_pad = layout.colourbar.default_pad_frac + # Decide how many main panels are required and which orientation the colourbars use n_panels = 3 if plot_spec.include_difference else 2 orientation = plot_spec.colourbar_location - # Choose gutter default per orientation if not provided - if gutter is None: - gutter = ( - DEFAULT_GUTTER_HORIZONTAL - if orientation == "horizontal" - else DEFAULT_GUTTER_VERTICAL - ) + # Choose gutter default per orientation + gutter = ( + layout.gutter_horizontal + if orientation == "horizontal" + else layout.gutter_vertical + ) # Calculate top boundary: ensure title space does not consume too much of the figure. # At least 60% of the figure height is reserved for the plotting area. 
- top_val = max(0.6, 1.0 - (outer_margin + title_space)) + top_val = max(layout.min_usable_height_fraction, 1.0 - (outer_margin + title_space)) # Calculate bottom boundary, reserving footer space for metadata bottom_val = outer_margin + footer_space # Calculate figure size based on data aspect ratio or use defaults if height and width and height > 0: # Calculate panel width maintaining data aspect ratio - base_h = 6.0 # Standard height in inches + base_h = layout.base_height_in # Standard height in inches aspect = width / height if orientation == "vertical": @@ -177,11 +433,7 @@ def _build_layout( # noqa: PLR0913 fig_size = (fig_w, base_h) else: # Use predefined sizes when data dimensions are unknown - fig_size = ( - DEFAULT_FIGSIZE_THREE_PANELS - if plot_spec.include_difference - else DEFAULT_FIGSIZE_TWO_PANELS - ) + fig_size = layout.default_figsizes[n_panels] fig = plt.figure(figsize=fig_size, constrained_layout=False, facecolor="none") @@ -477,7 +729,7 @@ def _set_titles(axs: list[Axes], plot_spec: PlotSpec) -> None: bbox={ "facecolor": "white", "edgecolor": "none", - "pad": 2.0, + "pad": _DEFAULT_LAYOUT_CONFIG.title_footer.bbox_pad_title, "alpha": 1.0, }, ) @@ -525,7 +777,7 @@ def _set_axes_limits(axs: list[Axes], *, width: int, height: int) -> None: # --- Colourbar Functions --- -def _get_cbar_limits_from_mappable(cbar: Colorbar) -> tuple[float, float]: +def get_cbar_limits_from_mappable(cbar: Colorbar) -> tuple[float, float]: """Return (vmin, vmax) for a colourbar's mappable with robust fallbacks.""" vmin = vmax = None try: # Works for many matplotlib mappables @@ -535,7 +787,10 @@ def _get_cbar_limits_from_mappable(cbar: Colorbar) -> tuple[float, float]: vmin = getattr(norm, "vmin", None) vmax = getattr(norm, "vmax", None) if vmin is None or vmax is None: - vmin, vmax = 0.0, 1.0 + vmin, vmax = ( + _DEFAULT_LAYOUT_CONFIG.formatting.default_vmin_fallback, + _DEFAULT_LAYOUT_CONFIG.formatting.default_vmax_fallback, + ) return float(vmin), float(vmax) @@ -555,7 +810,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 1. Shared colourbar for ground truth and prediction (same data scale) 2. Separate colourbar for difference data (different scale, often symmetric) - The function works with the layout from _build_layout, using dedicated + The function works with the layout from build_layout, using dedicated colourbar axes when available, or falling back to automatic matplotlib placement. 
Colourbar Design: @@ -594,7 +849,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 colourbar_groundtruth = plt.colorbar( image_groundtruth, cax=cbar_axes["groundtruth"], orientation=orientation ) - _format_linear_ticks( + format_linear_ticks( colourbar_groundtruth, vmin=float(plot_spec.vmin) if plot_spec.vmin is not None else None, vmax=float(plot_spec.vmax) if plot_spec.vmax is not None else None, @@ -607,7 +862,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 colourbar_prediction = plt.colorbar( image_prediction, cax=cbar_axes["prediction"], orientation=orientation ) - _format_linear_ticks( + format_linear_ticks( colourbar_prediction, decimals=1, is_vertical=is_vertical ) else: @@ -624,7 +879,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 ) # Tick formatting - _format_linear_ticks( + format_linear_ticks( colourbar_truth, vmin=float(plot_spec.vmin) if plot_spec.vmin is not None else None, vmax=float(plot_spec.vmax) if plot_spec.vmax is not None else None, @@ -666,23 +921,30 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 # Tick formatting: symmetric for TwoSlopeNorm, otherwise linear if isinstance(image_difference.norm, TwoSlopeNorm): - vmin = float(image_difference.norm.vmin or -1.0) - vmax = float(image_difference.norm.vmax or 1.0) - _format_symmetric_ticks( + vmin = float( + image_difference.norm.vmin + or _DEFAULT_LAYOUT_CONFIG.formatting.default_vmin_diff_fallback + ) + vmax = float( + image_difference.norm.vmax + or _DEFAULT_LAYOUT_CONFIG.formatting.default_vmax_diff_fallback + ) + format_symmetric_ticks( colourbar_diff, vmin=vmin, vmax=vmax, decimals=2, is_vertical=is_vertical, + centre=image_difference.norm.vcenter, ) else: - _format_linear_ticks(colourbar_diff, decimals=2, is_vertical=is_vertical) + format_linear_ticks(colourbar_diff, decimals=2, is_vertical=is_vertical) # --- Tick Formatting Functions --- -def _format_linear_ticks( +def format_linear_ticks( colourbar: Colorbar, *, vmin: float | None = None, @@ -697,40 +959,60 @@ def _format_linear_ticks( axis = colourbar.ax.yaxis if is_vertical else colourbar.ax.xaxis if vmin is None or vmax is None: - mvmin, mvmax = _get_cbar_limits_from_mappable(colourbar) + mvmin, mvmax = get_cbar_limits_from_mappable(colourbar) vmin = mvmin if vmin is None else vmin vmax = mvmax if vmax is None else vmax - ticks = np.linspace(float(vmin), float(vmax), 5) + ticks = np.linspace( + float(vmin), float(vmax), _DEFAULT_LAYOUT_CONFIG.formatting.num_ticks_linear + ) colourbar.set_ticks([float(t) for t in ticks]) axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) if not is_vertical: colourbar.ax.xaxis.set_tick_params(pad=1) - _apply_monospace_to_cbar_text(colourbar) + apply_monospace_to_cbar_text(colourbar) -def _format_symmetric_ticks( +def format_symmetric_ticks( colourbar: Colorbar, *, vmin: float, vmax: float, decimals: int = 2, is_vertical: bool, + centre: float | None = None, ) -> None: - """Format symmetric diverging ticks with a 0-centered midpoint. + """Format symmetric diverging ticks with a centred midpoint. + + Places five ticks: [vmin, midpoint to centre, centre, centre to midpoint, vmax]. + + Args: + colourbar: Colorbar to format. + vmin: Minimum value. + vmax: Maximum value. + decimals: Number of decimal places for tick labels. + is_vertical: Whether the colorbar is vertical. + centre: centre value for diverging colourmap (default: 0.0). - Places five ticks: [vmin, midpoint to 0, 0, 0 to midpoint, vmax]. 
""" axis = colourbar.ax.yaxis if is_vertical else colourbar.ax.xaxis - ticks = [vmin, 0.5 * (vmin + 0.0), 0.0, 0.5 * (0.0 + vmax), vmax] + centre_val = centre if centre is not None else 0.0 + midpoint_factor = _DEFAULT_LAYOUT_CONFIG.formatting.midpoint_factor + ticks = [ + vmin, + midpoint_factor * (vmin + centre_val), + centre_val, + midpoint_factor * (centre_val + vmax), + vmax, + ] colourbar.set_ticks([float(t) for t in ticks]) axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) if not is_vertical: colourbar.ax.xaxis.set_tick_params(pad=1) - _apply_monospace_to_cbar_text(colourbar) + apply_monospace_to_cbar_text(colourbar) -def _apply_monospace_to_cbar_text(colourbar: Colorbar) -> None: +def apply_monospace_to_cbar_text(colourbar: Colorbar) -> None: """Set tick labels and axis labels on a colourbar to monospace family.""" ax = colourbar.ax for label in list(ax.get_xticklabels()) + list(ax.get_yticklabels()): @@ -738,3 +1020,133 @@ def _apply_monospace_to_cbar_text(colourbar: Colorbar) -> None: # Ensure axis labels also use monospace if present ax.xaxis.label.set_fontfamily("monospace") ax.yaxis.label.set_fontfamily("monospace") + + +def _default_vertical_gap_inches(aspect: float, cfg: GapConfig) -> float: + """Return an aspect-aware physical gap between panel and vertical colourbar.""" + aspect = max(aspect, EPSILON_SMALL) + if aspect >= 1.0: + wide = min(aspect, cfg.wide_limit) + if np.isclose(cfg.wide_limit, 1.0): + return cfg.base + t = (wide - 1.0) / (cfg.wide_limit - 1.0) + gap = cfg.base - t * (cfg.base - cfg.wide_gap) + else: + tall_limit = max(cfg.tall_limit, EPSILON_SMALL) + tall = min(1.0 / aspect, 1.0 / tall_limit) + denom = (1.0 / tall_limit) - 1.0 + t = 0.0 if np.isclose(denom, 0.0) else (tall - 1.0) / denom + gap = cfg.base + t * (cfg.tall_gap - cfg.base) + return float(np.clip(gap, cfg.min_val, cfg.max_val)) + + +# --- Text and Box Annotation Functions --- + + +def set_suptitle_with_box(fig: Figure, text: str) -> plt.Text: + """Draw a fixed-position title with a white box that doesn't influence layout. + + Returns the Text artist so callers can update with set_text during animation. + This version avoids kwargs that are unsupported on older Matplotlib. + + Args: + fig: Matplotlib Figure object. + text: Title text to display. + + Returns: + Text artist for the title. + + """ + config = _DEFAULT_LAYOUT_CONFIG.title_footer + bbox = { + "facecolor": "white", + "edgecolor": "none", + "pad": config.bbox_pad_title, + "alpha": 1.0, + } + t = fig.text( + x=0.5, + y=config.title_y, + s=text, + ha="center", + va="top", + fontsize=config.title_fontsize, + fontfamily="monospace", + transform=fig.transFigure, + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(config.zorder_high) + return t + + +def set_footer_with_box(fig: Figure, text: str) -> plt.Text: + """Draw a fixed-position footer with a white box at bottom centre. + + Footer is intended for metadata and secondary information. + + Args: + fig: Matplotlib Figure object. + text: Footer text to display. + + Returns: + Text artist for the footer. 
+ + """ + config = _DEFAULT_LAYOUT_CONFIG.title_footer + bbox = { + "facecolor": "white", + "edgecolor": "none", + "pad": config.bbox_pad_title, + "alpha": 1.0, + } + t = fig.text( + x=0.5, + y=config.footer_y, + s=text, + ha="center", + va="bottom", + fontsize=config.footer_fontsize, + fontfamily="monospace", + transform=fig.transFigure, + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(config.zorder_high) + return t + + +def draw_badge_with_box(fig: Figure, x: float, y: float, text: str) -> plt.Text: + """Draw a warning/info badge with white background box at figure coords. + + Args: + fig: Matplotlib Figure object. + x: X position in figure coordinates (0-1). + y: Y position in figure coordinates (0-1). + text: Badge text to display. + + Returns: + Text artist for the badge. + + """ + config = _DEFAULT_LAYOUT_CONFIG.title_footer + bbox = { + "facecolor": "white", + "edgecolor": "none", + "pad": config.bbox_pad_badge, + "alpha": 1.0, + } + t = fig.text( + x=x, + y=y, + s=text, + fontsize=config.footer_fontsize, + fontfamily="monospace", + color="firebrick", + ha="center", + va="top", + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(config.zorder_high) + return t From 25506de37b8b14025b7bc4275c2a65b0c3269e27 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:28:35 +0000 Subject: [PATCH 27/49] work with style and layout object --- .../visualisations/plotting_maps.py | 185 +++++++----------- 1 file changed, 75 insertions(+), 110 deletions(-) diff --git a/ice_station_zebra/visualisations/plotting_maps.py b/ice_station_zebra/visualisations/plotting_maps.py index b0cf07aa..a55ddc15 100644 --- a/ice_station_zebra/visualisations/plotting_maps.py +++ b/ice_station_zebra/visualisations/plotting_maps.py @@ -8,7 +8,6 @@ """ -import contextlib import io import logging from collections.abc import Sequence @@ -20,15 +19,26 @@ from matplotlib import animation from matplotlib.axes import Axes from matplotlib.colors import ListedColormap -from matplotlib.text import Text from PIL.ImageFile import ImageFile from ice_station_zebra.exceptions import InvalidArrayError from ice_station_zebra.types import DiffColourmapSpec, PlotSpec from . 
import convert -from .layout import _add_colourbars, _build_layout, _set_axes_limits, _set_titles +from .animation_helper import hold_anim, release_anim +from .layout import ( + LayoutConfig, + TitleFooterConfig, + _add_colourbars, + _set_axes_limits, + _set_titles, + build_layout, + draw_badge_with_box, + set_footer_with_box, + set_suptitle_with_box, +) from .plotting_core import ( + colourmap_with_bad, compute_difference, compute_display_ranges, compute_display_ranges_stream, @@ -43,10 +53,6 @@ logger = logging.getLogger(__name__) - -# Keep strong references to animation objects during save to avoid GC-related warnings -_ANIM_CACHE: list[animation.FuncAnimation] = [] - #: Default plotting specification for sea ice concentration visualisation DEFAULT_SIC_SPEC = PlotSpec( variable="sea_ice_concentration", @@ -65,6 +71,23 @@ ) +def _safe_linspace(vmin: float, vmax: float, n: int) -> np.ndarray: + """Return an increasing linspace even if vmin==vmax or the inputs are swapped.""" + # ensure float + vmin = float(vmin) + vmax = float(vmax) + if not np.isfinite(vmin) or not np.isfinite(vmax): + # fallback to a stable interval + vmin, vmax = 0.0, 1.0 + if vmax < vmin: + vmin, vmax = vmax, vmin + if vmax == vmin: + # create a tiny range so contourf accepts the levels + eps = max(1e-6, abs(vmin) * 1e-6) + vmax = vmin + eps + return np.linspace(vmin, vmax, n) + + # --- Static Map Plot --- def plot_maps( plot_spec: PlotSpec, @@ -117,20 +140,17 @@ def plot_maps( ) # Increase title space if warnings are present to avoid overlap with axes titles - title_space_override = 0.10 if range_check_report.warnings else None + layout_config = None + if range_check_report.warnings: + layout_config = LayoutConfig(title_footer=TitleFooterConfig(title_space=0.10)) # Initialise the figure and axes with dynamic top spacing if needed - if title_space_override is not None: - fig, axs, cbar_axes = _build_layout( - plot_spec=plot_spec, - height=height, - width=width, - title_space=title_space_override, - ) - else: - fig, axs, cbar_axes = _build_layout( - plot_spec=plot_spec, height=height, width=width - ) + fig, axs, cbar_axes = build_layout( + plot_spec=plot_spec, + height=height, + width=width, + layout_config=layout_config, + ) levels = levels_from_spec(plot_spec) # Prepare difference rendering parameters if needed @@ -166,7 +186,7 @@ def plot_maps( _set_axes_limits(axs, width=width, height=height) try: - title_text = _set_suptitle_with_box(fig, _build_title_static(plot_spec, date)) + title_text = set_suptitle_with_box(fig, _build_title_static(plot_spec, date)) except Exception: logger.exception("Failed to draw suptitle; continuing without title.") title_text = None @@ -186,19 +206,19 @@ def plot_maps( warning_y = max(title_y - (0.05 + 0.02 * (n_lines - 1)), 0.0) else: warning_y = 0.90 - _draw_badge_with_box(fig, 0.5, warning_y, badge) + draw_badge_with_box(fig, 0.5, warning_y, badge) # Footer metadata at the bottom if getattr(plot_spec, "include_footer_metadata", True): try: footer_text = _build_footer_static(plot_spec) if footer_text: - _set_footer_with_box(fig, footer_text) + set_footer_with_box(fig, footer_text) except Exception: logger.exception("Failed to draw footer; continuing without footer.") try: - return {"sea-ice_concentration-static-maps": [convert._image_from_figure(fig)]} + return {"sea-ice_concentration-static-maps": [convert.image_from_figure(fig)]} finally: plt.close(fig) @@ -266,8 +286,12 @@ def video_maps( prediction_stream = np.where(land_mask, np.nan, prediction_stream) # Initialise the figure and 
axes with a larger footer space for videos - fig, axs, cbar_axes = _build_layout( - plot_spec=plot_spec, height=height, width=width, footer_space=0.11 + layout_config = LayoutConfig(title_footer=TitleFooterConfig(footer_space=0.11)) + fig, axs, cbar_axes = build_layout( + plot_spec=plot_spec, + height=height, + width=width, + layout_config=layout_config, ) levels = levels_from_spec(plot_spec) @@ -322,9 +346,7 @@ def video_maps( ) _set_axes_limits(axs, width=width, height=height) try: - title_text = _set_suptitle_with_box( - fig, _build_title_video(plot_spec, dates, 0) - ) + title_text = set_suptitle_with_box(fig, _build_title_video(plot_spec, dates, 0)) except Exception: logger.exception("Failed to draw suptitle; continuing without title.") title_text = None @@ -334,7 +356,7 @@ def video_maps( try: footer_text = _build_footer_video(plot_spec, dates) if footer_text: - _set_footer_with_box(fig, footer_text) + set_footer_with_box(fig, footer_text) except Exception: logger.exception("Failed to draw footer; continuing without footer.") @@ -372,19 +394,18 @@ def animate(tt: int) -> tuple[()]: blit=False, repeat=True, ) - # Keep a strong reference without touching figure attributes (ruff-fix) - _ANIM_CACHE.append(animation_object) + # Keep a strong reference to prevent garbage collection during save + hold_anim(animation_object) # Save -> BytesIO and clean up temp file try: - video_buffer = convert._save_animation( - animation_object, _fps=fps, _video_format=video_format + video_buffer = convert.save_animation( + animation_object, fps=fps, video_format=video_format ) return {"sea-ice_concentration-video-maps": video_buffer} finally: # Drop strong reference now that saving is done - with contextlib.suppress(ValueError): - _ANIM_CACHE.remove(animation_object) + release_anim(animation_object) plt.close(fig) @@ -473,19 +494,22 @@ def _draw_main_panels( # noqa: PLR0913 # Use PlotSpec levels for ground_truth and prediction unless overridden levels = levels_from_spec(plot_spec) if levels_override is None else levels_override + # Create colourmap with bad color handling for NaN values + cmap = colourmap_with_bad(plot_spec.colourmap, bad_color="lightgrey") + if plot_spec.colourbar_strategy == "separate": # For separate strategy, use explicit levels to prevent breathing - groundtruth_levels = np.linspace( + groundtruth_levels = _safe_linspace( groundtruth_vmin, groundtruth_vmax, plot_spec.n_contour_levels ) - prediction_levels = np.linspace( + prediction_levels = _safe_linspace( prediction_vmin, prediction_vmax, plot_spec.n_contour_levels ) image_groundtruth = axs[0].contourf( ground_truth, levels=groundtruth_levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=groundtruth_vmin, vmax=groundtruth_vmax, origin="lower", @@ -493,7 +517,7 @@ def _draw_main_panels( # noqa: PLR0913 image_prediction = axs[1].contourf( prediction, levels=prediction_levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=prediction_vmin, vmax=prediction_vmax, origin="lower", @@ -503,7 +527,7 @@ def _draw_main_panels( # noqa: PLR0913 image_groundtruth = axs[0].contourf( ground_truth, levels=levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=groundtruth_vmin, vmax=groundtruth_vmax, origin="lower", @@ -511,7 +535,7 @@ def _draw_main_panels( # noqa: PLR0913 image_prediction = axs[1].contourf( prediction, levels=levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=prediction_vmin, vmax=prediction_vmax, origin="lower", @@ -603,16 +627,21 @@ def _draw_frame( # noqa: PLR0913 error_msg = "diff_colour_scale must be provided when including 
difference" raise InvalidArrayError(error_msg) + # Create colourmap with bad color handling for NaN values + diff_cmap = colourmap_with_bad(diff_colour_scale.cmap, bad_color="lightgrey") + if diff_colour_scale.norm is not None: # Signed differences with TwoSlopeNorm - use explicit levels to ensure consistency diff_vmin = diff_colour_scale.norm.vmin or 0.0 diff_vmax = diff_colour_scale.norm.vmax or 1.0 - diff_levels = np.linspace(diff_vmin, diff_vmax, plot_spec.n_contour_levels) + diff_levels = _safe_linspace( + diff_vmin, diff_vmax, plot_spec.n_contour_levels + ) image_difference = axs[2].contourf( difference, levels=diff_levels, - cmap=diff_colour_scale.cmap, + cmap=diff_cmap, vmin=diff_vmin, vmax=diff_vmax, origin="lower", @@ -621,7 +650,7 @@ def _draw_frame( # noqa: PLR0913 # Non-negative differences with vmin/vmax vmin = diff_colour_scale.vmin or 0.0 vmax = diff_colour_scale.vmax or 1.0 - diff_levels = np.linspace( + diff_levels = _safe_linspace( vmin, vmax, plot_spec.n_contour_levels, @@ -630,7 +659,7 @@ def _draw_frame( # noqa: PLR0913 image_difference = axs[2].contourf( difference, levels=diff_levels, - cmap=diff_colour_scale.cmap, + cmap=diff_cmap, vmin=vmin, vmax=vmax, origin="lower", @@ -652,7 +681,7 @@ def _overlay_nans(ax: Axes, arr: np.ndarray) -> None: if np.isnan(arr).any(): # Create overlay for NaN areas (land mask) nan_mask = np.isnan(arr).astype(float) - # Create a custom colormap: 0=transparent, 1=land color + # Create a custom colourmap: 0=transparent, 1=land color # Land colour options: 'white' for white land, 'grey' for grey land colors = ["white", "white"] # 0=white (transparent), 1=land color @@ -690,70 +719,6 @@ def _clear_plot(ax: Axes) -> None: ax.set_title("") -def _set_suptitle_with_box(fig: plt.Figure, text: str) -> Text: - """Draw a fixed-position title with a white box that doesn't influence layout. - - Returns the Text artist so callers can update with set_text during animation. - This version avoids kwargs that are unsupported on older Matplotlib. - """ - bbox = {"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0} - t = fig.text( - x=0.5, - y=0.98, - s=text, - ha="center", - va="top", - fontsize=12, - fontfamily="monospace", - transform=fig.transFigure, - bbox=bbox, - ) - with contextlib.suppress(Exception): - t.set_zorder(1000) - return t - - -def _set_footer_with_box(fig: plt.Figure, text: str) -> Text: - """Draw a fixed-position footer with a white box at bottom center. - - Footer is intended for metadata and secondary information. - """ - bbox = {"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0} - t = fig.text( - x=0.5, - y=0.03, - s=text, - ha="center", - va="bottom", - fontsize=11, - fontfamily="monospace", - transform=fig.transFigure, - bbox=bbox, - ) - with contextlib.suppress(Exception): - t.set_zorder(1000) - return t - - -def _draw_badge_with_box(fig: plt.Figure, x: float, y: float, text: str) -> Text: - """Draw a warning/info badge with white background box at figure coords.""" - bbox = {"facecolor": "white", "edgecolor": "none", "pad": 1.5, "alpha": 1.0} - t = fig.text( - x=x, - y=y, - s=text, - fontsize=11, - fontfamily="monospace", - color="firebrick", - ha="center", - va="top", - bbox=bbox, - ) - with contextlib.suppress(Exception): - t.set_zorder(1000) - return t - - # --- Title helpers --- def _formatted_variable_name(variable: str) -> str: """Return a human-friendly variable name for titles. 
From ce014554976a12d5eff8d80771a704062ea21500 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:38:43 +0000 Subject: [PATCH 28/49] file for raw plots --- .../visualisations/plotting_raw_inputs.py | 488 ++++++++++++++++++ 1 file changed, 488 insertions(+) create mode 100644 ice_station_zebra/visualisations/plotting_raw_inputs.py diff --git a/ice_station_zebra/visualisations/plotting_raw_inputs.py b/ice_station_zebra/visualisations/plotting_raw_inputs.py new file mode 100644 index 00000000..4577e7b6 --- /dev/null +++ b/ice_station_zebra/visualisations/plotting_raw_inputs.py @@ -0,0 +1,488 @@ +"""Raw input variable visualisation for a single timestep or animation over time. + +Creates one static map per input channel with optional per-variable style overrides, +or animations showing temporal evolution of individual variables. + +This module is intentionally lightweight and single-panel oriented while reusing the +layout and formatting helpers from the sea-ice plotting stack to keep appearances aligned. +""" + +from __future__ import annotations + +import logging +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Literal + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import animation +from matplotlib.colors import TwoSlopeNorm + +from ice_station_zebra.exceptions import InvalidArrayError +from ice_station_zebra.visualisations.animation_helper import hold_anim, release_anim +from ice_station_zebra.visualisations.layout import ( + build_single_panel_figure, + format_linear_ticks, + format_symmetric_ticks, + set_suptitle_with_box, +) +from ice_station_zebra.visualisations.plotting_core import ( + VariableStyle, + colourmap_with_bad, + create_normalisation, + load_land_mask, + safe_filename, + save_figure, + style_for_variable, +) + +if TYPE_CHECKING: + import io + from collections.abc import Sequence + from pathlib import Path + + from matplotlib.text import Text + from PIL.ImageFile import ImageFile + + from ice_station_zebra.types import PlotSpec + +logger = logging.getLogger(__name__) + + +def _format_title( + variable: str, hemisphere: str | None, when: date | datetime, units: str | None +) -> str: + """Format a title string for a raw input variable plot. + + Args: + variable: Variable name. + hemisphere: Hemisphere ("north" or "south"), if applicable. + when: Date or datetime of the data. + units: Display units for the variable. + + Returns: + Formatted title string. + + """ + hemi = f" ({hemisphere.capitalize()})" if hemisphere else "" + units_s = f" [{units}]" if units else "" + shown = when.date().isoformat() if isinstance(when, datetime) else when.isoformat() + return f"{variable}{units_s}{hemi} Shown: {shown}" + + +def plot_raw_inputs_for_timestep( # noqa: PLR0913, C901, PLR0912, PLR0915 + *, + channel_arrays: list[np.ndarray], + channel_names: list[str], + when: date | datetime, + plot_spec_base: PlotSpec, + land_mask: np.ndarray | None = None, + styles: dict[str, dict[str, Any]] | None = None, + save_dir: Path | None = None, +) -> list[tuple[str, ImageFile, Path | None]]: + """Plot one image per input channel as a static map. + + Returns list of (name, PIL.ImageFile, saved_path|None). 
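+
+    Example:
+        # Illustrative call - the channel arrays, names and "spec" are made up.
+        results = plot_raw_inputs_for_timestep(
+            channel_arrays=[temp_2d, ice_2d],  # each [H, W]
+            channel_names=["era5:2t", "osisaf-south:ice_conc"],
+            when=date(2020, 1, 15),
+            plot_spec_base=spec,
+        )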
+ """ + if len(channel_arrays) != len(channel_names): + msg = ( + f"Channels count mismatch: arrays={len(channel_arrays)} " + f"names={len(channel_names)}" + ) + raise InvalidArrayError(msg) + + # Ensure we pick up styles either from the explicit argument or from the PlotSpec + if styles is None: + # some callers may embed variable_styles inside the plot_spec dataclass + styles = getattr(plot_spec_base, "variable_styles", None) + + results: list[tuple[str, ImageFile, Path | None]] = [] + land_mask_cache: dict[tuple[int, int], np.ndarray | None] = {} + + from . import convert # noqa: PLC0415 # local import to avoid circulars + + expected_ndim = 2 + for arr, name in zip(channel_arrays, channel_names, strict=True): + if arr.ndim != expected_ndim: + msg = f"Expected 2D [H,W] for channel '{name}', got {arr.shape}" + raise InvalidArrayError(msg) + + # Apply land mask (mask out land to NaN) + active_mask = land_mask + if active_mask is None and plot_spec_base.land_mask_path: + shape = arr.shape + if shape not in land_mask_cache: + try: + land_mask_cache[shape] = load_land_mask( + plot_spec_base.land_mask_path, shape + ) + except InvalidArrayError: + logger.exception( + "Failed to load land mask for '%s' with shape %s", name, shape + ) + land_mask_cache[shape] = None + active_mask = land_mask_cache.get(shape) + + # Apply mask if available, creating a new variable to avoid overwriting loop variable + arr_to_plot = arr + if active_mask is not None: + if active_mask.shape == arr.shape: + arr_to_plot = np.where(active_mask, np.nan, arr) + else: + logger.debug( + "Skipping land mask for '%s': mask shape %s != array shape %s", + name, + active_mask.shape, + arr.shape, + ) + + # Prefer an explicit styles dict, otherwise fall back to the PlotSpec attribute if present + styles_effective = ( + styles + if styles is not None + else getattr(plot_spec_base, "variable_styles", None) + ) + logger.debug( + "plot_raw_inputs: incoming name=%r; using styles keys=%s", + name, + list(styles_effective.keys()) if styles_effective else None, + ) + style = style_for_variable(name, styles_effective) + logger.debug( + "plot_raw_inputs: resolved style for %r => cmap=%r, vmin=%r, vmax=%r, origin=%r", + name, + style.cmap, + style.vmin, + style.vmax, + style.origin, + ) + + # Create normalisation using shared function + norm, vmin, vmax = create_normalisation( + arr_to_plot, + vmin=style.vmin, + vmax=style.vmax, + centre=style.two_slope_centre, + ) + + # Build figure and axis + height, width = arr_to_plot.shape + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=plot_spec_base.colourbar_location, + ) + + # Render + cmap_name = style.cmap or plot_spec_base.colourmap + cmap = colourmap_with_bad(cmap_name, bad_color="lightgrey") + origin = style.origin or "lower" + image = ax.imshow( + arr_to_plot, cmap=cmap, norm=norm, origin=origin, interpolation="nearest" + ) + + # Colourbar + orientation = plot_spec_base.colourbar_location + cbar = fig.colorbar(image, ax=ax, cax=cax, orientation=orientation) + is_vertical = orientation == "vertical" + decimals = style.decimals if style.decimals is not None else 2 + if isinstance(norm, TwoSlopeNorm): + format_symmetric_ticks( + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + centre=norm.vcenter, + ) + else: + format_linear_ticks( + cbar, vmin=vmin, vmax=vmax, decimals=decimals, is_vertical=is_vertical + ) + + # Title + try: + set_suptitle_with_box( + fig, + _format_title(name, plot_spec_base.hemisphere, when, 
style.units), + ) + except (ValueError, AttributeError, RuntimeError) as err: + logger.debug( + "Failed to draw raw-inputs title: %s; continuing without title.", err + ) + + # Save to disk if requested (colons in variable names replaced) + file_base = name.replace(":", "__") + saved_path = save_figure(fig, save_dir, file_base) + + # Convert to PIL + try: + pil_img = convert.image_from_figure(fig) + finally: + plt.close(fig) + + results.append((name, pil_img, saved_path)) + + return results + + +def video_raw_input_for_variable( # noqa: PLR0913, C901, PLR0915 + *, + variable_name: str, + data_stream: np.ndarray, + dates: Sequence[date | datetime], + plot_spec_base: PlotSpec, + land_mask: np.ndarray | None = None, + style: VariableStyle | None = None, + styles: dict[str, dict[str, Any]] | None = None, + fps: int = 2, + video_format: Literal["mp4", "gif"] = "gif", + save_path: Path | None = None, +) -> io.BytesIO: + """Create animation showing temporal evolution of a single variable. + + This function creates a video animation of a single input variable over time, + with consistent colour scaling across all frames to avoid visual "breathing". + + Args: + variable_name: Name of the variable (e.g., "era5:2t", "osisaf-south:ice_conc"). + data_stream: 3D array of data over time [T, H, W]. + dates: Sequence of dates corresponding to each timestep (length must match T). + plot_spec_base: Base plotting specification (colourmap, hemisphere, etc.). + land_mask: Optional 2D boolean array marking land areas [H, W]. + style: Optional pre-computed VariableStyle for this variable. + styles: Optional dictionary of style configurations for matching. + fps: Frames per second for the output video. + video_format: Output format, either "mp4" or "gif". + save_path: Optional path to save the video file to disk. + + Returns: + BytesIO buffer containing the encoded video data. + + Raises: + InvalidArrayError: If data_stream is not 3D or dates length mismatches. + VideoRenderError: If video encoding fails. + + """ + from . 
import convert # noqa: PLC0415 # Local import to avoid circulars + + # Validate input + if data_stream.ndim != 3: # noqa: PLR2004 + msg = f"Expected 3D data [T,H,W], got shape {data_stream.shape}" + raise InvalidArrayError(msg) + + n_timesteps, height, width = data_stream.shape + if len(dates) != n_timesteps: + msg = f"Number of dates ({len(dates)}) != number of timesteps ({n_timesteps})" + raise InvalidArrayError(msg) + + # Apply land mask if provided + if land_mask is not None: + if land_mask.shape != (height, width): + logger.debug( + "Land mask shape %s doesn't match data shape (%d, %d), skipping mask", + land_mask.shape, + height, + width, + ) + else: + # Mask out land areas by setting them to NaN + data_stream = np.where(land_mask, np.nan, data_stream) + + # Get styling for this variable + if style is None: + style = style_for_variable(variable_name, styles) + + # Create stable normalisation across all frames + # Use global min/max to prevent colour scale "breathing" + data_min = float(np.nanmin(data_stream)) if np.isfinite(data_stream).any() else 0.0 + data_max = float(np.nanmax(data_stream)) if np.isfinite(data_stream).any() else 1.0 + + # Override with style limits if provided + effective_vmin = style.vmin if style.vmin is not None else data_min + effective_vmax = style.vmax if style.vmax is not None else data_max + + # Create normalisation using first frame to establish type + norm, vmin, vmax = create_normalisation( + data_stream[0], + vmin=effective_vmin, + vmax=effective_vmax, + centre=style.two_slope_centre, + ) + + # Build figure with single panel + colourbar + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=plot_spec_base.colourbar_location, + ) + + # Render initial frame + cmap_name = style.cmap or plot_spec_base.colourmap + cmap = colourmap_with_bad(cmap_name, bad_color="lightgrey") + origin = style.origin or "lower" + image = ax.imshow( + data_stream[0], cmap=cmap, norm=norm, origin=origin, interpolation="nearest" + ) + + # Create colourbar (stable across all frames) + orientation = plot_spec_base.colourbar_location + cbar = fig.colorbar(image, ax=ax, cax=cax, orientation=orientation) + is_vertical = orientation == "vertical" + + # Format colourbar ticks based on normalisation type + decimals = style.decimals if style.decimals is not None else 2 + if isinstance(norm, TwoSlopeNorm): + format_symmetric_ticks( + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + centre=norm.vcenter, + ) + else: + format_linear_ticks( + cbar, vmin=vmin, vmax=vmax, decimals=decimals, is_vertical=is_vertical + ) + + # Create title (will be updated each frame) + title_text: Text | None = None + try: + title_text = set_suptitle_with_box( + fig, + _format_title( + variable_name, plot_spec_base.hemisphere, dates[0], style.units + ), + ) + except (ValueError, AttributeError, RuntimeError) as err: + logger.debug("Failed to draw title: %s; continuing without title.", err) + + # Animation update function + def animate(tt: int) -> tuple[()]: + """Update function for each frame of the animation.""" + # Update image data + image.set_data(data_stream[tt]) + + # Update title with current date + if title_text is not None: + title_text.set_text( + _format_title( + variable_name, plot_spec_base.hemisphere, dates[tt], style.units + ) + ) + + return () + + # Create animation object + anim = animation.FuncAnimation( + fig, + animate, + frames=n_timesteps, + interval=1000 // fps, + blit=False, + repeat=True, + ) + + # Keep strong 
reference to prevent garbage collection during save + hold_anim(anim) + + try: + # Save to BytesIO buffer + video_buffer = convert.save_animation(anim, fps=fps, video_format=video_format) + + # Optionally save to disk + if save_path is not None: + save_path.parent.mkdir(parents=True, exist_ok=True) + save_path.write_bytes(video_buffer.getvalue()) + logger.debug("Saved animation to %s", save_path) + # Reset buffer position after writing + video_buffer.seek(0) + + return video_buffer + + finally: + # Clean up: remove from cache and close figure + release_anim(anim) + plt.close(fig) + + +def video_raw_inputs_for_timesteps( # noqa: PLR0913 + *, + channel_arrays_stream: list[np.ndarray], + channel_names: list[str], + dates: Sequence[date | datetime], + plot_spec_base: PlotSpec, + land_mask: np.ndarray | None = None, + styles: dict[str, dict[str, Any]] | None = None, + fps: int = 2, + video_format: Literal["mp4", "gif"] = "gif", + save_dir: Path | None = None, +) -> list[tuple[str, io.BytesIO, Path | None]]: + """Create animations for multiple input variables over time. + + This is a convenience wrapper around `video_raw_input_for_variable()` that + processes multiple variables in a batch, applying consistent styling and + land masking to all variables. + + Args: + channel_arrays_stream: List of 3D arrays, one per variable [T, H, W]. + channel_names: List of variable names corresponding to each array. + dates: Sequence of dates for each timestep (length must match T). + plot_spec_base: Base plotting specification. + land_mask: Optional 2D land mask to apply to all variables. + styles: Optional dictionary of style configurations. + fps: Frames per second for videos. + video_format: Output format ("mp4" or "gif"). + save_dir: Optional directory to save videos to disk. + + Returns: + List of tuples (variable_name, video_buffer, saved_path). + + Raises: + InvalidArrayError: If arrays/names count mismatch or invalid shapes. 
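+
+    Example:
+        # Illustrative call - the streams, dates and "spec" are made up.
+        results = video_raw_inputs_for_timesteps(
+            channel_arrays_stream=[temps, winds],  # each [T, H, W]
+            channel_names=["era5:2t", "era5:10u"],
+            dates=dates,
+            plot_spec_base=spec,
+            video_format="gif",
+        )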
+ + """ + if len(channel_arrays_stream) != len(channel_names): + msg = ( + f"Channel count mismatch: {len(channel_arrays_stream)} arrays, " + f"{len(channel_names)} names" + ) + raise InvalidArrayError(msg) + + results: list[tuple[str, io.BytesIO, Path | None]] = [] + + for data_stream, var_name in zip(channel_arrays_stream, channel_names, strict=True): + logger.debug("Creating animation for variable: %s", var_name) + + # Determine save path if save_dir is provided + save_path: Path | None = None + if save_dir is not None: + # Sanitise variable name for filename + file_base = var_name.replace(":", "__") + suffix = ".gif" if video_format == "gif" else ".mp4" + save_path = save_dir / f"{safe_filename(file_base)}{suffix}" + + try: + # Create animation for this variable + video_buffer = video_raw_input_for_variable( + variable_name=var_name, + data_stream=data_stream, + dates=dates, + plot_spec_base=plot_spec_base, + land_mask=land_mask, + styles=styles, + fps=fps, + video_format=video_format, + save_path=save_path, + ) + + results.append((var_name, video_buffer, save_path)) + + except (InvalidArrayError, ValueError, MemoryError, OSError): + logger.exception( + "Failed to create animation for variable %s, skipping", var_name + ) + continue + + return results From ed0179f83d6a005489c3393f9be69509fd3026cc Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 14:40:36 +0000 Subject: [PATCH 29/49] Update fixtures for raw plot testing --- tests/plotting/conftest.py | 123 +++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index 11d1ef6b..f31895b2 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -1,10 +1,12 @@ import warnings +from collections.abc import Iterator from datetime import date, timedelta from pathlib import Path from typing import Any, Protocol, cast import hydra import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np import pytest import torch @@ -12,6 +14,7 @@ from omegaconf import errors as oc_errors from ice_station_zebra.data_loaders import ZebraDataModule +from ice_station_zebra.types import PlotSpec from tests.conftest import make_varying_sic_stream mpl.use("Agg") @@ -25,6 +28,15 @@ ) TEST_DATE = date(2020, 1, 15) +TEST_HEIGHT = 48 +TEST_WIDTH = 48 + + +@pytest.fixture(autouse=True) +def close_all_figures() -> Iterator[None]: + """Automatically close all matplotlib figures after each test to prevent warnings.""" + yield + plt.close("all") @pytest.fixture @@ -259,3 +271,114 @@ def example_checkpoint_path(pytestconfig: pytest.Config) -> Path | None: return ckpt.resolve() return None + + +# --- Raw Inputs Fixtures --- + + +@pytest.fixture +def base_plot_spec() -> PlotSpec: + """Base plotting specification for raw inputs.""" + return PlotSpec( + variable="raw_inputs", + colourmap="viridis", + colourbar_location="vertical", + hemisphere="south", + ) + + +@pytest.fixture +def test_dates_short() -> list[date]: + """Generate a short sequence of test dates for animations (4 days).""" + return [TEST_DATE + timedelta(days=i) for i in range(4)] + + +@pytest.fixture +def land_mask_2d() -> np.ndarray: + """Create a simple circular land mask for testing [H, W].""" + dist = make_central_distance_grid(TEST_HEIGHT, TEST_WIDTH) + radius = min(TEST_HEIGHT, TEST_WIDTH) * 0.25 + return (dist < radius).astype(bool) + + +@pytest.fixture +def era5_temperature_2d() -> np.ndarray: + """Generate synthetic ERA5 2m temperature data (K) [H, W].""" + rng = 
np.random.default_rng(100) + # Temperature centreed around 273.15K (0°C) with realistic variation + base_temp = 273.15 + rng.normal(0, 10, size=(TEST_HEIGHT, TEST_WIDTH)) + return base_temp.astype(np.float32) + + +@pytest.fixture +def era5_humidity_2d() -> np.ndarray: + """Generate synthetic ERA5 specific humidity data (kg/kg) [H, W].""" + rng = np.random.default_rng(100) + # Humidity values are very small (0.001 to 0.01) + humidity = rng.uniform(0.0005, 0.015, size=(TEST_HEIGHT, TEST_WIDTH)) + return humidity.astype(np.float32) + + +@pytest.fixture +def era5_wind_u_2d() -> np.ndarray: + """Generate synthetic ERA5 u-wind component (m/s) [H, W].""" + rng = np.random.default_rng(100) + # Wind centreed around 0 with realistic variation + wind = rng.normal(0, 5, size=(TEST_HEIGHT, TEST_WIDTH)) + return wind.astype(np.float32) + + +@pytest.fixture +def osisaf_ice_conc_2d() -> np.ndarray: + """Generate synthetic OSISAF sea ice concentration data (fraction 0-1) [H, W].""" + rng = np.random.default_rng(100) + # Ice concentration between 0 and 1 + ice_conc = rng.uniform(0.0, 1.0, size=(TEST_HEIGHT, TEST_WIDTH)) + return ice_conc.astype(np.float32) + + +@pytest.fixture +def era5_temperature_3d(test_dates_short: list[date]) -> np.ndarray: + """Generate synthetic 3D temperature stream [T, H, W].""" + rng = np.random.default_rng(100) + n_timesteps = len(test_dates_short) + # Temperature evolving over time + data = np.zeros((n_timesteps, TEST_HEIGHT, TEST_WIDTH), dtype=np.float32) + for t in range(n_timesteps): + data[t] = 273.15 + rng.normal(0, 10, size=(TEST_HEIGHT, TEST_WIDTH)) + return data + + +@pytest.fixture +def multi_channel_data() -> tuple[list[np.ndarray], list[str]]: + """Generate multiple channels of raw input data.""" + rng = np.random.default_rng(100) + channels = [ + rng.uniform(270, 280, size=(TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), # temperature + rng.normal(0, 5, size=(TEST_HEIGHT, TEST_WIDTH)).astype(np.float32), # u-wind + rng.normal(0, 5, size=(TEST_HEIGHT, TEST_WIDTH)).astype(np.float32), # v-wind + rng.uniform(0, 1, size=(TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), # ice conc + ] + names = ["era5:2t", "era5:10u", "era5:10v", "osisaf-south:ice_conc"] + return channels, names + + +@pytest.fixture +def variable_styles() -> dict[str, dict[str, Any]]: + """Sample variable styling configuration for raw inputs.""" + return { + "era5:2t": { + "cmap": "RdBu_r", + "two_slope_centre": 273.15, + "units": "K", + "decimals": 1, + }, + "era5:10u": {"cmap": "RdBu_r", "two_slope_centre": 0.0, "units": "m/s"}, + "era5:10v": {"cmap": "RdBu_r", "two_slope_centre": 0.0, "units": "m/s"}, + "era5:q_10": {"cmap": "viridis", "decimals": 4, "units": "kg/kg"}, + "osisaf-south:ice_conc": {"cmap": "Blues_r"}, + } From 9be6ddc6e825030506f3cd2ae7e4cb6d28f34f67 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 15:14:34 +0000 Subject: [PATCH 30/49] Tests for api updated --- tests/plotting/test_api.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/plotting/test_api.py b/tests/plotting/test_api.py index 5e5fea66..b670902b 100644 --- a/tests/plotting/test_api.py +++ b/tests/plotting/test_api.py @@ -88,14 +88,18 @@ def test_plot_maps_emits_warning_badge( @pytest.fixture def fake_save_animation(monkeypatch: pytest.MonkeyPatch) -> Callable[..., io.BytesIO]: - """Monkeypatch _save_animation so video_maps runs fast in tests.""" + """Monkeypatch save_animation so video_maps runs fast in tests.""" def _fake_save( - _anim: object, *, _fps: int, 
_video_format: str = "gif" + _anim: object, + *, + fps: int = 2, # noqa: ARG001 + video_format: str = "gif", # noqa: ARG001 ) -> io.BytesIO: + # Parameters match real save_animation signature but are unused in fake return io.BytesIO(b"fake-video-data") - monkeypatch.setattr(convert, "_save_animation", _fake_save) + monkeypatch.setattr(convert, "save_animation", _fake_save) return _fake_save @@ -134,10 +138,10 @@ def test_plot_maps_with_land_mask( ground_truth, prediction, date = sic_pair_2d height, width = ground_truth.shape - # Create a simple land mask with land in the center + # Create a simple land mask with land in the centre land_mask = np.zeros((height, width), dtype=bool) - center_h, center_w = height // 2, width // 2 - land_mask[center_h - 5 : center_h + 5, center_w - 5 : center_w + 5] = True + centre_h, centre_w = height // 2, width // 2 + land_mask[centre_h - 5 : centre_h + 5, centre_w - 5 : centre_w + 5] = True # Save land mask to temporary file with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file: From afe72fb82d1e3cac825ce3881309dcb464907bb5 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 15:16:07 +0000 Subject: [PATCH 31/49] Add layout tests for single panel --- tests/plotting/test_layout.py | 158 +++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 12 deletions(-) diff --git a/tests/plotting/test_layout.py b/tests/plotting/test_layout.py index 2ebc823f..50ba3760 100644 --- a/tests/plotting/test_layout.py +++ b/tests/plotting/test_layout.py @@ -9,13 +9,15 @@ import matplotlib.pyplot as plt import pytest -from ice_station_zebra.visualisations.layout import _build_layout, _set_axes_limits -from ice_station_zebra.visualisations.plotting_maps import ( - DEFAULT_SIC_SPEC, - _draw_badge_with_box, - _set_footer_with_box, - _set_suptitle_with_box, +from ice_station_zebra.visualisations.layout import ( + _set_axes_limits, + build_layout, + build_single_panel_figure, + draw_badge_with_box, + set_footer_with_box, + set_suptitle_with_box, ) +from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC from .test_helper_plot_layout import axis_rectangle, rectangles_overlap @@ -59,7 +61,7 @@ def test_no_axes_overlap( colourbar_strategy=colourbar_strategy, # type: ignore[arg-type] ) - _, axes, colourbar_axes = _build_layout( + _, axes, colourbar_axes = build_layout( plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] ) @@ -96,7 +98,7 @@ def test_axes_have_reasonable_gaps( colourbar_strategy=colourbar_strategy, # type: ignore[arg-type] ) - _, axes, colourbar_axes = _build_layout( + _, axes, colourbar_axes = build_layout( plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] ) @@ -201,15 +203,15 @@ def test_figure_text_boxes_do_not_overlap( include_difference=include_difference, ) - fig, axes, caxes = _build_layout( + fig, axes, caxes = build_layout( plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] ) # Add figure-level title, warning badge (synthetic), and footer - title = _set_suptitle_with_box(fig, "Title") + title = set_suptitle_with_box(fig, "Title") ty = title.get_position()[1] - badge = _draw_badge_with_box(fig, 0.5, max(ty - 0.05, 0.0), "Warnings: example") - footer = _set_footer_with_box(fig, "Footer metadata") + badge = draw_badge_with_box(fig, 0.5, max(ty - 0.05, 0.0), "Warnings: example") + footer = set_footer_with_box(fig, "Footer metadata") # Collect rectangles: panels, colourbar axes, and figure texts rectangles = [axis_rectangle(ax) for ax 
in axes] @@ -223,3 +225,135 @@ def test_figure_text_boxes_do_not_overlap( assert not rectangles_overlap(rect_a, rect_b), ( f"Found overlap between rectangles {rect_a} and {rect_b}" ) + + +# --- Single Panel Layout Tests --- + + +@pytest.mark.parametrize("colourbar_location", ["horizontal", "vertical"]) +def test_single_panel_no_overlap( + era5_temperature_2d: np.ndarray, + *, + colourbar_location: str, +) -> None: + """Single panel layout: main panel and colourbar must not overlap.""" + height, width = era5_temperature_2d.shape + + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=colourbar_location, # type: ignore[arg-type] + ) + + # Get rectangles for panel and colourbar + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + + # They should not overlap + assert not rectangles_overlap(panel_rect, cbar_rect), ( + f"Panel and colourbar overlap: panel={panel_rect}, cbar={cbar_rect}" + ) + + plt.close(fig) + + +@pytest.mark.parametrize("colourbar_location", ["horizontal", "vertical"]) +def test_single_panel_has_reasonable_gap( + era5_temperature_2d: np.ndarray, + *, + colourbar_location: str, +) -> None: + """Single panel layout: require a minimum gap between panel and colourbar.""" + height, width = era5_temperature_2d.shape + + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=colourbar_location, # type: ignore[arg-type] + ) + + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + + pl, pb, pr, pt = panel_rect + cl, cb, cr, ct = cbar_rect + + if colourbar_location == "vertical": + # Vertical colorbar should be to the right of panel + gap = cl - pr + assert gap >= 0.005, ( + f"Expected ≥0.5% gap between panel and vertical colorbar, got {gap:.5f}" + ) + else: # horizontal + # Horizontal colorbar should be below panel + gap = pb - ct + assert gap >= 0.005, ( + f"Expected ≥0.5% gap between panel and horizontal colorbar, got {gap:.5f}" + ) + + plt.close(fig) + + +def test_single_panel_with_text_annotations( + era5_temperature_2d: np.ndarray, +) -> None: + """Single panel: title and annotations should not overlap panel or colourbar.""" + height, width = era5_temperature_2d.shape + + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location="vertical", + ) + + # Add text annotations + title = set_suptitle_with_box(fig, "Test Title") + footer = set_footer_with_box(fig, "Test Footer") + + # Collect all rectangles + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + title_rect = _text_rectangle(fig, title) + footer_rect = _text_rectangle(fig, footer) + + rectangles = [panel_rect, cbar_rect, title_rect, footer_rect] + + # No overlaps + for rect_a, rect_b in combinations(rectangles, 2): + assert not rectangles_overlap(rect_a, rect_b), ( + f"Found overlap between {rect_a} and {rect_b}" + ) + + plt.close(fig) + + +@pytest.mark.parametrize( + ("height", "width"), + [ + (48, 48), # Square + (181, 720), # Wide (ERA5-like) + (432, 432), # Square (OSISAF-like) + (100, 200), # Wide + (200, 100), # Tall + ], +) +def test_single_panel_various_aspect_ratios( + height: int, + width: int, +) -> None: + """Single panel layout should handle various aspect ratios without overlap.""" + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location="vertical", + ) + + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + + # No overlap + assert not rectangles_overlap(panel_rect, cbar_rect), ( + f"Overlap 
for {height}x{width}: panel={panel_rect}, cbar={cbar_rect}" + ) + + plt.close(fig) From a837b26d097ac9d30edee565d9f6e37a236ff66a Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 15:22:51 +0000 Subject: [PATCH 32/49] Tests for raw plots --- tests/plotting/test_raw_inputs.py | 518 ++++++++++++++++++++++++++++++ 1 file changed, 518 insertions(+) create mode 100644 tests/plotting/test_raw_inputs.py diff --git a/tests/plotting/test_raw_inputs.py b/tests/plotting/test_raw_inputs.py new file mode 100644 index 00000000..d25ab90c --- /dev/null +++ b/tests/plotting/test_raw_inputs.py @@ -0,0 +1,518 @@ +"""Tests for raw input plotting functionality. + +This module tests both static plots and animations of raw input variables, +covering ERA5 weather data, OSISAF sea ice concentration, and various +styling configurations. +""" + +from __future__ import annotations + +import io +from datetime import date, timedelta +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from PIL import Image + +from ice_station_zebra.exceptions import InvalidArrayError +from ice_station_zebra.types import PlotSpec +from ice_station_zebra.visualisations.plotting_core import style_for_variable +from ice_station_zebra.visualisations.plotting_raw_inputs import ( + plot_raw_inputs_for_timestep, + video_raw_input_for_variable, + video_raw_inputs_for_timesteps, +) + +# Import test constants from conftest +from .conftest import TEST_DATE, TEST_HEIGHT, TEST_WIDTH + +if TYPE_CHECKING: + from pathlib import Path + +# --- Tests for Static Plotting --- + + +def test_plot_single_channel_basic( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test basic single channel plotting.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + assert len(results) == 1 + name, pil_image, saved_path = results[0] + assert name == "era5:2t" + assert isinstance(pil_image, Image.Image) + assert saved_path is None # No save_dir provided + + +def test_plot_with_land_mask( + era5_temperature_2d: np.ndarray, + land_mask_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test plotting with land mask applied.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + land_mask=land_mask_2d, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == "era5:2t" + assert isinstance(pil_image, Image.Image) + + +def test_plot_with_custom_styles( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test plotting with custom variable styling.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=variable_styles, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == "era5:2t" + assert isinstance(pil_image, Image.Image) + + +def test_plot_multiple_channels( + multi_channel_data: tuple[list[np.ndarray], list[str]], + base_plot_spec: PlotSpec, +) -> None: + """Test plotting multiple channels at once.""" + channel_arrays, channel_names = multi_channel_data + + results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + assert 
len(results) == len(channel_names) + for (name, pil_image, _), expected_name in zip(results, channel_names, strict=True): + assert name == expected_name + assert isinstance(pil_image, Image.Image) + + +def test_plot_to_disk( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, + tmp_path: Path, +) -> None: + """Test saving plots to disk.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + save_dir=tmp_path, + ) + + assert len(results) == 1 + _, _, saved_path = results[0] + assert saved_path is not None + assert saved_path.exists() + assert saved_path.suffix == ".png" + + +@pytest.mark.parametrize("colourbar_location", ["vertical", "horizontal"]) +def test_plot_colourbar_locations( + era5_temperature_2d: np.ndarray, + colourbar_location: str, +) -> None: + """Test plotting with different colorbar orientations.""" + plot_spec = PlotSpec( + variable="raw_inputs", + colourmap="viridis", + colourbar_location=colourbar_location, + ) + + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=plot_spec, + ) + + assert len(results) == 1 + + +@pytest.mark.parametrize( + ("var_name", "fixture_name"), + [ + ("era5:2t", "era5_temperature_2d"), + ("era5:q_10", "era5_humidity_2d"), + ("era5:10u", "era5_wind_u_2d"), + ("osisaf-south:ice_conc", "osisaf_ice_conc_2d"), + ], +) +def test_plot_different_variables( + var_name: str, + fixture_name: str, + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], + request: pytest.FixtureRequest, +) -> None: + """Test plotting different types of variables with appropriate styling.""" + data = request.getfixturevalue(fixture_name) + + results = plot_raw_inputs_for_timestep( + channel_arrays=[data], + channel_names=[var_name], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=variable_styles, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == var_name + assert isinstance(pil_image, Image.Image) + + +# --- Tests for Style Resolution --- + + +def test_style_for_variable_exact_match( + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test exact variable name matching in styling.""" + style = style_for_variable("era5:2t", variable_styles) + + assert style.cmap == "RdBu_r" + assert style.two_slope_centre == 273.15 + assert style.units == "K" + assert style.decimals == 1 + + +def test_style_for_variable_wildcard_match( + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test wildcard pattern matching in styling.""" + # Add wildcard pattern + styles_with_wildcard = { + **variable_styles, + "era5:q_*": {"cmap": "viridis", "decimals": 4}, + } + + style = style_for_variable("era5:q_500", styles_with_wildcard) + + assert style.cmap == "viridis" + assert style.decimals == 4 + + +def test_style_for_variable_no_match() -> None: + """Test default styling when no match found.""" + style = style_for_variable("unknown:variable", {}) + + # Should use defaults + assert style.cmap is None + assert style.decimals is None + + +# --- Tests for Animations --- + + +def test_video_single_variable_basic( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, +) -> None: + """Test basic video creation for a single variable.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + 
plot_spec_base=base_plot_spec, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + assert len(video_buffer.read()) > 1000 # Reasonable GIF size + + +def test_video_with_land_mask( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + land_mask_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test video creation with land mask.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + land_mask=land_mask_2d, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + assert len(video_buffer.read()) > 1000 + + +def test_video_with_custom_style( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test video creation with custom styling.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + styles=variable_styles, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + assert len(video_buffer.read()) > 1000 + + +def test_video_save_to_disk( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, + tmp_path: Path, +) -> None: + """Test saving video to disk.""" + save_path = tmp_path / "test_animation.gif" + + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + save_path=save_path, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + assert save_path.exists() + assert save_path.stat().st_size > 1000 + + +@pytest.mark.parametrize("video_format", ["gif", "mp4"]) +def test_video_formats( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, + video_format: str, +) -> None: + """Test creating videos in different formats.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + fps=2, + video_format=video_format, + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + content = video_buffer.read() + assert len(content) > 1000 + + # Check file signature + if video_format == "gif": + assert content[:6] == b"GIF89a" # GIF header + elif video_format == "mp4": + # MP4 typically has ftyp box early + assert b"ftyp" in content[:100] + + +def test_video_multiple_variables( + test_dates_short: list[date], + base_plot_spec: PlotSpec, +) -> None: + """Test batch video creation for multiple variables.""" + rng = np.random.default_rng(100) + n_timesteps = len(test_dates_short) + + # Create data streams for multiple variables + channel_arrays_stream = [ + rng.uniform(270, 280, size=(n_timesteps, TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), + rng.normal(0, 5, size=(n_timesteps, TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), + rng.uniform(0, 1, size=(n_timesteps, TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), + ] + channel_names = ["era5:2t", "era5:10u", "osisaf-south:ice_conc"] + + results = video_raw_inputs_for_timesteps( + channel_arrays_stream=channel_arrays_stream, + channel_names=channel_names, + dates=test_dates_short, + 
plot_spec_base=base_plot_spec, + fps=2, + video_format="gif", + ) + + assert len(results) == 3 + for (name, video_buffer, saved_path), expected_name in zip( + results, channel_names, strict=True + ): + assert name == expected_name + assert isinstance(video_buffer, io.BytesIO) + assert saved_path is None # No save_dir provided + + +# --- Error Handling Tests --- + + +def test_plot_mismatched_arrays_and_names( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test error when arrays and names count mismatch.""" + with pytest.raises(InvalidArrayError, match="Channels count mismatch"): + plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d, era5_temperature_2d], + channel_names=["era5:2t"], # Only one name for two arrays + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + +def test_plot_wrong_dimension( + base_plot_spec: PlotSpec, +) -> None: + """Test error when input array is not 2D.""" + rng = np.random.default_rng(42) + wrong_dim_array = rng.random((5, 5, 5)).astype(np.float32) # 3D instead of 2D + + with pytest.raises(InvalidArrayError, match="Expected 2D"): + plot_raw_inputs_for_timestep( + channel_arrays=[wrong_dim_array], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + +def test_video_wrong_dimension( + test_dates_short: list[date], + base_plot_spec: PlotSpec, +) -> None: + """Test error when video data is not 3D.""" + rng = np.random.default_rng(42) + wrong_dim_array = rng.random((5, 5)).astype(np.float32) # 2D instead of 3D + + with pytest.raises(InvalidArrayError, match="Expected 3D"): + video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=wrong_dim_array, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + ) + + +def test_video_mismatched_dates( + era5_temperature_3d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test error when number of dates doesn't match timesteps.""" + wrong_dates = [ + TEST_DATE, + TEST_DATE + timedelta(days=1), + ] # Only 2 dates for 4 timesteps + + with pytest.raises( + InvalidArrayError, match="Number of dates.*!= number of timesteps" + ): + video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=wrong_dates, + plot_spec_base=base_plot_spec, + ) + + +# --- Integration Tests --- + + +def test_full_workflow_static_and_video( + multi_channel_data: tuple[list[np.ndarray], list[str]], + test_dates_short: list[date], + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], + tmp_path: Path, +) -> None: + """Test complete workflow: static plots + videos for multiple variables.""" + channel_arrays, channel_names = multi_channel_data + + # 1. Create static plots + static_results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=variable_styles, + save_dir=tmp_path / "static", + ) + + assert len(static_results) == len(channel_names) + for _, _, saved_path in static_results: + assert saved_path is not None + assert saved_path.exists() + + # 2. Create 3D data streams + rng = np.random.default_rng(100) + n_timesteps = len(test_dates_short) + channel_arrays_stream = [ + rng.uniform(arr.min(), arr.max(), size=(n_timesteps, *arr.shape)).astype( + np.float32 + ) + for arr in channel_arrays + ] + + # 3. 
Create videos + video_results = video_raw_inputs_for_timesteps( + channel_arrays_stream=channel_arrays_stream, + channel_names=channel_names, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + styles=variable_styles, + fps=2, + video_format="gif", + save_dir=tmp_path / "videos", + ) + + assert len(video_results) == len(channel_names) + for _, video_buffer, saved_path in video_results: + assert isinstance(video_buffer, io.BytesIO) + assert saved_path is not None + assert saved_path.exists() + assert saved_path.suffix == ".gif" From 74e7062bfa16ea53ff6b6f78712bdb7b7d9ebd34 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 15:47:21 +0000 Subject: [PATCH 33/49] Fix path name --- ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml index cd6eee33..6d355515 100644 --- a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml +++ b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml @@ -3,14 +3,14 @@ raw_inputs: _target_: ice_station_zebra.callbacks.RawInputsCallback frequency: null # Plot once; set N to repeat every N batches timestep_index: 0 # Which history timestep to plot (0 = most recent) - save_dir: ./raw_input_plots # Directory to save plots (null to skip saving to disk) + save_dir: ./data/raw_input_plots # Directory to save plots (null to skip saving to disk) log_to_wandb: true # Whether to log plots to WandB (if false and save_dir null, nothing saved) # Animation configuration (following PlottingCallback pattern) make_video_plots: false # Set to true to create temporal animations video_fps: 2 # Frames per second for animations video_format: gif # mp4 or gif - video_save_dir: ./raw_input_animations # Directory to save animations (null to skip disk save) + video_save_dir: ./data/raw_input_animations # Directory to save animations (null to skip disk save) max_animation_frames: null # Limit frames (null = unlimited; 30 ≈ 1 month daily data) # Plot specification (colourmap, hemisphere, etc.) From d66ac301e846aff5d3c2d8bf09bdc7ab531c7af5 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 15:56:02 +0000 Subject: [PATCH 34/49] Add CLI for plotting raw input data --- README.md | 16 + ice_station_zebra/cli/main.py | 2 + ice_station_zebra/visualisations/__init__.py | 2 + ice_station_zebra/visualisations/cli.py | 353 +++++++++++++++++++ 4 files changed, 373 insertions(+) create mode 100644 ice_station_zebra/visualisations/cli.py diff --git a/README.md b/README.md index 605178a6..080f9c3c 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,22 @@ Run `uv run zebra train` to train using the datasets specified in the config. Run `uv run zebra evaluate --checkpoint PATH_TO_A_CHECKPOINT` to evaluate using a checkpoint from a training run. +### Visualisations + +Plot raw input variables from the test dataset: + +```bash +uv run zebra visualisations plot-raw-inputs --config-name .yaml --sample-idx 0 +``` + +Create animations of raw inputs over time: + +```bash +uv run zebra visualisations animate-raw-inputs --config-name .yaml +``` + +Settings (output directories, styling, animation parameters) are read from `config.evaluate.callbacks.raw_inputs` in your YAML config files. Command-line options can override config values if needed. 
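+
+For example, to override the animation settings from `raw_inputs.yaml` on the command line (illustrative values; substitute your own config name and output directory):
+
+```bash
+uv run zebra visualisations animate-raw-inputs --config-name <your-config>.yaml \
+  --output-dir ./my_raw_input_animations --fps 4 --video-format mp4
+```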
+ ## Adding a new model ### Background diff --git a/ice_station_zebra/cli/main.py b/ice_station_zebra/cli/main.py index 6ed99b86..196e6cb9 100644 --- a/ice_station_zebra/cli/main.py +++ b/ice_station_zebra/cli/main.py @@ -4,6 +4,7 @@ from ice_station_zebra.data_processors import datasets_cli from ice_station_zebra.evaluation import evaluation_cli from ice_station_zebra.training import training_cli +from ice_station_zebra.visualisations import visualisations_cli # Configure hydra logging simple_stdout_log_config() @@ -17,6 +18,7 @@ app.add_typer(datasets_cli, name="datasets") app.add_typer(evaluation_cli) app.add_typer(training_cli) +app.add_typer(visualisations_cli, name="visualisations") if __name__ == "__main__": diff --git a/ice_station_zebra/visualisations/__init__.py b/ice_station_zebra/visualisations/__init__.py index ea88bf9a..e106afdc 100644 --- a/ice_station_zebra/visualisations/__init__.py +++ b/ice_station_zebra/visualisations/__init__.py @@ -6,6 +6,7 @@ from ice_station_zebra.types import PlotSpec +from .cli import visualisations_cli from .plotting_core import detect_land_mask_path from .plotting_maps import ( DEFAULT_SIC_SPEC, @@ -21,4 +22,5 @@ "detect_land_mask_path", "plot_maps", "video_maps", + "visualisations_cli", ] diff --git a/ice_station_zebra/visualisations/cli.py b/ice_station_zebra/visualisations/cli.py new file mode 100644 index 00000000..71b309e5 --- /dev/null +++ b/ice_station_zebra/visualisations/cli.py @@ -0,0 +1,353 @@ +"""CLI commands for visualisation tasks.""" + +import logging +from pathlib import Path +from typing import Annotated, Any + +import hydra +import numpy as np +import typer +from omegaconf import DictConfig, OmegaConf + +from ice_station_zebra.callbacks.raw_inputs_callback import RawInputsCallback +from ice_station_zebra.cli import hydra_adaptor +from ice_station_zebra.data_loaders import CombinedDataset, ZebraDataModule +from ice_station_zebra.visualisations.plotting_raw_inputs import ( + plot_raw_inputs_for_timestep, + video_raw_inputs_for_timesteps, +) + +# Create the typer app +visualisations_cli = typer.Typer(help="Visualisation commands") + +log = logging.getLogger(__name__) + + +def _extract_channel_arrays( + batch: dict[str, Any], test_dataset: CombinedDataset, timestep_idx: int = 0 +) -> list[np.ndarray]: + """Extract channel arrays from a dataset item. + + Args: + batch: Dataset item (returns 4D arrays [T, C, H, W]). + test_dataset: Test dataset instance. + timestep_idx: Index of timestep to extract (default: 0). + + Returns: + List of 2D channel arrays. + + """ + channel_arrays = [] + for ds in test_dataset.inputs: + if ds.name not in batch: + log.warning("Dataset %s not found in batch", ds.name) + continue + + input_data = batch[ds.name] # Shape: [T, C, H, W] - numpy array + timestep_data = input_data[timestep_idx] # Shape: [C, H, W] + channel_arrays.extend([timestep_data[c] for c in range(timestep_data.shape[0])]) + + return channel_arrays + + +def _collect_temporal_data( + test_dataset: CombinedDataset, start_idx: int, n_frames: int +) -> tuple[list[np.ndarray], list[str], list[Any]]: + """Collect temporal data for animation. + + Args: + test_dataset: Test dataset instance. + start_idx: Starting sample index. + n_frames: Number of frames to collect. + + Returns: + Tuple of (data_streams, channel_names, dates). 
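+        Each element of data_streams is stacked to shape [T, H, W], where T == n_frames == len(dates).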
+ + """ + channel_names = test_dataset.input_variable_names + n_vars = len(channel_names) + + # Initialize data collection + temporal_data_per_var = [[] for _ in range(n_vars)] + dates = [] + + for idx in range(start_idx, start_idx + n_frames): + batch = test_dataset[idx] + date = test_dataset.date_from_index(idx) + + # Collect data from all input datasets + var_idx = 0 + for ds in test_dataset.inputs: + if ds.name not in batch: + log.warning("Dataset %s not found in batch", ds.name) + continue + + input_data = batch[ds.name] # Shape: [T, C, H, W] + timestep_data = input_data[0] # Take first timestep, Shape: [C, H, W] + + # Add each channel + for c in range(timestep_data.shape[0]): + var_data = timestep_data[c] + temporal_data_per_var[var_idx].append(var_data) + var_idx += 1 + + dates.append(date) + + # Stack into list of 3D arrays [T, H, W] for each variable + data_streams = [ + np.stack(var_frames, axis=0) for var_frames in temporal_data_per_var + ] + + return data_streams, channel_names, dates + + +@visualisations_cli.command() +@hydra_adaptor +def plot_raw_inputs( + config: DictConfig, + sample_idx: Annotated[ + int, + typer.Option(help="Index of the sample to plot (default: 0)"), + ] = 0, + output_dir: Annotated[ + str | None, + typer.Option(help="Directory to save plots (overrides config if provided)"), + ] = None, +) -> None: + r"""Plot raw inputs for a single timestep from the test dataset. + + This command creates static plots of all input variables for a single sample + from the test dataset. Settings are read from config.evaluate.callbacks.raw_inputs + in your YAML config files. + + Args: + config: Hydra config (provided via --config-name option). + sample_idx: Index of the sample to plot (default: 0). + output_dir: Directory to save plots (overrides config if provided). + + Note: + You must specify --config-name to use your local config file (e.g., lfrance.local.yaml). + Settings are read from config.evaluate.callbacks.raw_inputs in that config. + + Example: + uv run zebra visualisations plot-raw-inputs \\ + --config-name lfrance.local.yaml \\ + --sample-idx 0 + + """ + # Instantiate callback from config to get all settings + raw_inputs_cfg = config.get("evaluate", {}).get("callbacks", {}).get("raw_inputs") + if raw_inputs_cfg is None: + # Create minimal config if not found + raw_inputs_cfg = {} + callback = ( + hydra.utils.instantiate(raw_inputs_cfg) + if raw_inputs_cfg + else RawInputsCallback() + ) + callback.config = OmegaConf.to_container(config, resolve=True) + + # Create data module and prepare data + data_module = ZebraDataModule(config) + data_module.prepare_data() + data_module.setup("test") + + # Get the test dataset + test_dataloader = data_module.test_dataloader() + test_dataset = test_dataloader.dataset + if test_dataset is None: + msg = "No test dataset available!" 
+ raise ValueError(msg) + + log.info("Test dataset has %d samples", len(test_dataset)) + + # Validate sample index + if sample_idx < 0 or sample_idx >= len(test_dataset): + msg = f"Sample index {sample_idx} out of range [0, {len(test_dataset)})" + raise ValueError(msg) + + # Get a sample from the dataset + batch = test_dataset[sample_idx] + date = test_dataset.date_from_index(sample_idx) + + log.info("Plotting raw inputs for date: %s", date) + + # Extract channel arrays + channel_arrays = _extract_channel_arrays(batch, test_dataset) + channel_names = test_dataset.input_variable_names + log.info("Total channels: %d", len(channel_names)) + + # Use callback's settings (with command-line override for output_dir) + plot_spec = callback.plot_spec + variable_styles = callback.variable_styles + save_dir = ( + Path(output_dir) + if output_dir + else callback.save_dir or Path("./data/raw_input_plots") + ) + + # Plot the raw inputs + results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=date, + plot_spec_base=plot_spec, + land_mask=None, # Will be auto-loaded if available + styles=variable_styles, + save_dir=save_dir, + ) + + log.info("Successfully plotted %d variables to %s", len(results), save_dir) + for var_name, _pil_img, saved_path in results: + if saved_path: + log.info(" - %s: %s", var_name, saved_path) + + +@visualisations_cli.command() +@hydra_adaptor +def animate_raw_inputs( + config: DictConfig, + n_frames: Annotated[ + int | None, + typer.Option( + help="Number of frames to include in animation (overrides config if provided)" + ), + ] = None, + output_dir: Annotated[ + str | None, + typer.Option( + help="Directory to save animations (overrides config if provided)" + ), + ] = None, + fps: Annotated[ + int | None, + typer.Option( + help="Frames per second for animation (overrides config if provided)" + ), + ] = None, + video_format: Annotated[ + str | None, + typer.Option( + help="Video format: 'gif' or 'mp4' (overrides config if provided)" + ), + ] = None, + start_idx: Annotated[ + int, + typer.Option(help="Starting sample index (default: 0)"), + ] = 0, +) -> None: + r"""Create animations of raw inputs over time from the test dataset. + + This command creates temporal animations showing how individual input variables + evolve over time. Settings are read from config.evaluate.callbacks.raw_inputs + in your YAML config files. + + Args: + config: Hydra config (provided via --config-name option). + n_frames: Number of frames to include in animation (overrides config if provided). + output_dir: Directory to save animations (overrides config if provided). + fps: Frames per second for animation (overrides config if provided). + video_format: Video format: 'gif' or 'mp4' (overrides config if provided). + start_idx: Starting sample index (default: 0). + + Note: + You must specify --config-name to use your local config file (e.g., lfrance.local.yaml). + Settings are read from config.evaluate.callbacks.raw_inputs in that config. 
+ + Example: + uv run zebra visualisations animate-raw-inputs \\ + --config-name lfrance.local.yaml \\ + --start-idx 0 + + """ + # Instantiate callback from config to get all settings + raw_inputs_cfg = config.get("evaluate", {}).get("callbacks", {}).get("raw_inputs") + if raw_inputs_cfg is None: + raw_inputs_cfg = {} + callback = ( + hydra.utils.instantiate(raw_inputs_cfg) + if raw_inputs_cfg + else RawInputsCallback() + ) + callback.config = OmegaConf.to_container(config, resolve=True) + + # Get settings from callback (with command-line overrides) + n_frames = n_frames or callback.max_animation_frames or 30 + fps = fps or callback.video_fps + video_format = video_format or callback.video_format + save_dir = ( + Path(output_dir) + if output_dir + else callback.video_save_dir or Path("./data/raw_input_animations") + ) + + # Validate video format + if video_format not in ("gif", "mp4"): + msg = f"Video format must be 'gif' or 'mp4', got '{video_format}'" + raise ValueError(msg) + + # Create data module and prepare data + data_module = ZebraDataModule(config) + data_module.prepare_data() + data_module.setup("test") + + # Get the test dataset + test_dataloader = data_module.test_dataloader() + test_dataset = test_dataloader.dataset + if test_dataset is None: + msg = "No test dataset available!" + raise ValueError(msg) + + log.info("Test dataset has %d samples", len(test_dataset)) + + # Determine number of frames + max_frames = len(test_dataset) - start_idx + n_frames = min(n_frames, max_frames) + if n_frames <= 0: + msg = f"Not enough samples (start_idx={start_idx}, dataset_size={len(test_dataset)})" + raise ValueError(msg) + + log.info( + "Creating animations with %d frames starting from index %d", n_frames, start_idx + ) + + # Collect temporal data for all variables + data_streams, channel_names, dates = _collect_temporal_data( + test_dataset, start_idx, n_frames + ) + + log.info( + "Collected data streams: %d variables x %s", + len(data_streams), + data_streams[0].shape if data_streams else "N/A", + ) + + # Use callback's settings + plot_spec = callback.plot_spec + variable_styles = callback.variable_styles + + # Create animations for all variables + log.info("Creating batch animations...") + results = video_raw_inputs_for_timesteps( + channel_arrays_stream=data_streams, + channel_names=channel_names, + dates=dates, + plot_spec_base=plot_spec, + styles=variable_styles, + fps=fps, + video_format=video_format, # type: ignore[arg-type] + save_dir=save_dir, + ) + + log.info("Successfully created %d animations:", len(results)) + for var_name, video_buffer, save_path in results: + log.info( + " - %s: %s (size: %d bytes)", + var_name, + save_path, + len(video_buffer.getvalue()), + ) + + +if __name__ == "__main__": + visualisations_cli() From a2cf47a5fd1486dfa8ed8c77274c12b74fd0840b Mon Sep 17 00:00:00 2001 From: Lydia France Date: Sat, 22 Nov 2025 16:10:00 +0000 Subject: [PATCH 35/49] Fix type checks for mypy --- ice_station_zebra/visualisations/cli.py | 16 +++++++++------- tests/plotting/test_raw_inputs.py | 6 +++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ice_station_zebra/visualisations/cli.py b/ice_station_zebra/visualisations/cli.py index 71b309e5..cac44cad 100644 --- a/ice_station_zebra/visualisations/cli.py +++ b/ice_station_zebra/visualisations/cli.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast import hydra import numpy as np @@ -68,8 +68,8 @@ def _collect_temporal_data( n_vars 
= len(channel_names) # Initialize data collection - temporal_data_per_var = [[] for _ in range(n_vars)] - dates = [] + temporal_data_per_var: list[list[np.ndarray]] = [[] for _ in range(n_vars)] + dates: list[Any] = [] for idx in range(start_idx, start_idx + n_frames): batch = test_dataset[idx] @@ -154,10 +154,11 @@ def plot_raw_inputs( # Get the test dataset test_dataloader = data_module.test_dataloader() - test_dataset = test_dataloader.dataset - if test_dataset is None: + test_dataset_raw = test_dataloader.dataset + if test_dataset_raw is None: msg = "No test dataset available!" raise ValueError(msg) + test_dataset = cast("CombinedDataset", test_dataset_raw) log.info("Test dataset has %d samples", len(test_dataset)) @@ -293,10 +294,11 @@ def animate_raw_inputs( # Get the test dataset test_dataloader = data_module.test_dataloader() - test_dataset = test_dataloader.dataset - if test_dataset is None: + test_dataset_raw = test_dataloader.dataset + if test_dataset_raw is None: msg = "No test dataset available!" raise ValueError(msg) + test_dataset = cast("CombinedDataset", test_dataset_raw) log.info("Test dataset has %d samples", len(test_dataset)) diff --git a/tests/plotting/test_raw_inputs.py b/tests/plotting/test_raw_inputs.py index d25ab90c..347f83ce 100644 --- a/tests/plotting/test_raw_inputs.py +++ b/tests/plotting/test_raw_inputs.py @@ -9,7 +9,7 @@ import io from datetime import date, timedelta -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import numpy as np import pytest @@ -136,7 +136,7 @@ def test_plot_to_disk( @pytest.mark.parametrize("colourbar_location", ["vertical", "horizontal"]) def test_plot_colourbar_locations( era5_temperature_2d: np.ndarray, - colourbar_location: str, + colourbar_location: Literal["vertical", "horizontal"], ) -> None: """Test plotting with different colorbar orientations.""" plot_spec = PlotSpec( @@ -324,7 +324,7 @@ def test_video_formats( era5_temperature_3d: np.ndarray, test_dates_short: list[date], base_plot_spec: PlotSpec, - video_format: str, + video_format: Literal["gif", "mp4"], ) -> None: """Test creating videos in different formats.""" video_buffer = video_raw_input_for_variable( From 545de3d3a35a2ca7c6e6e5be1b5faebdeda32612 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 11:08:36 +0000 Subject: [PATCH 36/49] Fix circular import --- ice_station_zebra/callbacks/raw_inputs_callback.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ice_station_zebra/callbacks/raw_inputs_callback.py b/ice_station_zebra/callbacks/raw_inputs_callback.py index 35b470c8..eebdd270 100644 --- a/ice_station_zebra/callbacks/raw_inputs_callback.py +++ b/ice_station_zebra/callbacks/raw_inputs_callback.py @@ -21,8 +21,11 @@ from ice_station_zebra.data_loaders import CombinedDataset from ice_station_zebra.exceptions import InvalidArrayError, VideoRenderError from ice_station_zebra.types import PlotSpec -from ice_station_zebra.visualisations import DEFAULT_SIC_SPEC, detect_land_mask_path -from ice_station_zebra.visualisations.plotting_core import safe_filename +from ice_station_zebra.visualisations.plotting_core import ( + detect_land_mask_path, + safe_filename, +) +from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC from ice_station_zebra.visualisations.plotting_raw_inputs import ( plot_raw_inputs_for_timestep, video_raw_inputs_for_timesteps, From 3122f5f3a2925d2abdad15ca0773f5971293862a Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 11:13:57 +0000 
Subject: [PATCH 37/49] fix circular import --- ice_station_zebra/cli/main.py | 2 +- ice_station_zebra/visualisations/__init__.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ice_station_zebra/cli/main.py b/ice_station_zebra/cli/main.py index 196e6cb9..877653b8 100644 --- a/ice_station_zebra/cli/main.py +++ b/ice_station_zebra/cli/main.py @@ -4,7 +4,7 @@ from ice_station_zebra.data_processors import datasets_cli from ice_station_zebra.evaluation import evaluation_cli from ice_station_zebra.training import training_cli -from ice_station_zebra.visualisations import visualisations_cli +from ice_station_zebra.visualisations.cli import visualisations_cli # Configure hydra logging simple_stdout_log_config() diff --git a/ice_station_zebra/visualisations/__init__.py b/ice_station_zebra/visualisations/__init__.py index e106afdc..ea88bf9a 100644 --- a/ice_station_zebra/visualisations/__init__.py +++ b/ice_station_zebra/visualisations/__init__.py @@ -6,7 +6,6 @@ from ice_station_zebra.types import PlotSpec -from .cli import visualisations_cli from .plotting_core import detect_land_mask_path from .plotting_maps import ( DEFAULT_SIC_SPEC, @@ -22,5 +21,4 @@ "detect_land_mask_path", "plot_maps", "video_maps", - "visualisations_cli", ] From ff11b2e56a4ff275ff6f779d1a5a5d36637bca36 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 11:45:30 +0000 Subject: [PATCH 38/49] Fix issues for ruff --- .../visualisations/plotting_maps.py | 162 +++++++++++------- 1 file changed, 96 insertions(+), 66 deletions(-) diff --git a/ice_station_zebra/visualisations/plotting_maps.py b/ice_station_zebra/visualisations/plotting_maps.py index 33be5820..f74efe8e 100644 --- a/ice_station_zebra/visualisations/plotting_maps.py +++ b/ice_station_zebra/visualisations/plotting_maps.py @@ -19,6 +19,8 @@ from matplotlib import animation from matplotlib.axes import Axes from matplotlib.colors import ListedColormap +from matplotlib.figure import Figure +from matplotlib.text import Text from PIL.ImageFile import ImageFile from ice_station_zebra.exceptions import InvalidArrayError @@ -88,6 +90,86 @@ def _safe_linspace(vmin: float, vmax: float, n: int) -> np.ndarray: return np.linspace(vmin, vmax, n) +def _prepare_static_plot( + plot_spec: PlotSpec, + ground_truth: np.ndarray, + prediction: np.ndarray, +) -> tuple[int, int, np.ndarray | None, LayoutConfig | None, list[str], np.ndarray]: + """Validate arrays and compute helpers needed for plotting. + + Returns: + Tuple of (height, width, land_mask, layout_config, warnings, contour levels). 
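+        layout_config is None unless range-check warnings were produced; extra title space is then reserved.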
+ + """ + height, width = validate_2d_pair(ground_truth, prediction) + land_mask = load_land_mask(plot_spec.land_mask_path, (height, width)) + + (gt_min, gt_max), (_pred_min, _pred_max) = compute_display_ranges( + ground_truth, prediction, plot_spec + ) + range_check_report = compute_range_check_report( + ground_truth, + prediction, + vmin=gt_min, + vmax=gt_max, + outside_warn=getattr(plot_spec, "outside_warn", 0.05), + severe_outside=getattr(plot_spec, "severe_outside", 0.20), + include_shared_range_mismatch_check=getattr( + plot_spec, "include_shared_range_mismatch_check", True + ), + ) + warnings = range_check_report.warnings + layout_config = None + if warnings: + layout_config = LayoutConfig(title_footer=TitleFooterConfig(title_space=0.10)) + + levels = levels_from_spec(plot_spec) + return height, width, land_mask, layout_config, warnings, levels + + +def _prepare_difference( + plot_spec: PlotSpec, + ground_truth: np.ndarray, + prediction: np.ndarray, +) -> tuple[np.ndarray | None, DiffColourmapSpec | ListedColormap | None]: + """Compute difference arrays and colour scales if requested.""" + if not plot_spec.include_difference: + return None, None + difference = compute_difference(ground_truth, prediction, plot_spec.diff_mode) + diff_colour_scale = make_diff_colourmap(difference, mode=plot_spec.diff_mode) + return difference, diff_colour_scale + + +def _draw_warning_badge( + fig: Figure, + title_text: Text | None, + warnings: Sequence[str], +) -> None: + """Render a warning badge close to the title.""" + if not warnings: + return + badge = "Warnings: " + ", ".join(warnings) + if title_text is not None: + _, title_y = title_text.get_position() + n_lines = title_text.get_text().count("\n") + 1 + warning_y = max(title_y - (0.05 + 0.02 * (n_lines - 1)), 0.0) + else: + warning_y = 0.90 + draw_badge_with_box(fig, 0.5, warning_y, badge) + + +def _maybe_add_footer(fig: Figure, plot_spec: PlotSpec) -> None: + """Attach footer metadata when enabled.""" + if not getattr(plot_spec, "include_footer_metadata", True): + return + try: + footer_text = _build_footer_static(plot_spec) + if footer_text: + set_footer_with_box(fig, footer_text) + except Exception: + logger.exception("Failed to draw footer; continuing without footer.") + + # --- Static Map Plot --- def plot_maps( plot_spec: PlotSpec, @@ -117,32 +199,14 @@ def plot_maps( InvalidArrayError: If ground_truth and prediction arrays have incompatible shapes. 
""" - # Check the shapes of the arrays - height, width = validate_2d_pair(ground_truth, prediction) - - # Load land mask if specified - land_mask = load_land_mask(plot_spec.land_mask_path, (height, width)) - - # Pre-compute range check to decide top spacing (warning badge may need extra room) - (gt_min, gt_max), (pred_min, pred_max) = compute_display_ranges( - ground_truth, prediction, plot_spec - ) - range_check_report = compute_range_check_report( - ground_truth, - prediction, - vmin=gt_min, - vmax=gt_max, - outside_warn=getattr(plot_spec, "outside_warn", 0.05), - severe_outside=getattr(plot_spec, "severe_outside", 0.20), - include_shared_range_mismatch_check=getattr( - plot_spec, "include_shared_range_mismatch_check", True - ), - ) - - # Increase title space if warnings are present to avoid overlap with axes titles - layout_config = None - if range_check_report.warnings: - layout_config = LayoutConfig(title_footer=TitleFooterConfig(title_space=0.10)) + ( + height, + width, + land_mask, + layout_config, + warnings, + levels, + ) = _prepare_static_plot(plot_spec, ground_truth, prediction) # Initialise the figure and axes with dynamic top spacing if needed fig, axs, cbar_axes = build_layout( @@ -151,13 +215,11 @@ def plot_maps( width=width, layout_config=layout_config, ) - levels = levels_from_spec(plot_spec) # Prepare difference rendering parameters if needed - diff_colour_scale = None - if plot_spec.include_difference: - difference = compute_difference(ground_truth, prediction, plot_spec.diff_mode) - diff_colour_scale = make_diff_colourmap(difference, mode=plot_spec.diff_mode) + difference, diff_colour_scale = _prepare_difference( + plot_spec, ground_truth, prediction + ) # Draw the ground truth and prediction map images image_groundtruth, image_prediction, image_difference, _ = _draw_frame( @@ -166,7 +228,7 @@ def plot_maps( prediction, plot_spec, diff_colour_scale, - precomputed_difference=difference if plot_spec.include_difference else None, + precomputed_difference=difference, levels_override=levels, land_mask=land_mask, ) @@ -191,40 +253,8 @@ def plot_maps( logger.exception("Failed to draw suptitle; continuing without title.") title_text = None - # Include range_check report (already computed above) - badge = ( - "" - if not range_check_report.warnings - else "Warnings: " + ", ".join(range_check_report.warnings) - ) - if badge: - # Place the warning just below the title - if title_text is not None: - _, title_y = title_text.get_position() - n_lines = title_text.get_text().count("\n") + 1 - # Reduced gap to avoid overlapping axes titles; keep badge close to title - warning_y = max(title_y - (0.05 + 0.02 * (n_lines - 1)), 0.0) - else: - warning_y = 0.90 - draw_badge_with_box(fig, 0.5, warning_y, badge) - - # Footer metadata at the bottom - if getattr(plot_spec, "include_footer_metadata", True): - try: - footer_text = _build_footer_static(plot_spec) - if footer_text: - set_footer_with_box(fig, footer_text) - except Exception: - logger.exception("Failed to draw footer; continuing without footer.") - - # Footer metadata at the bottom - if getattr(plot_spec, "include_footer_metadata", True): - try: - footer_text = _build_footer_static(plot_spec) - if footer_text: - _set_footer_with_box(fig, footer_text) - except Exception: - logger.exception("Failed to draw footer; continuing without footer.") + _draw_warning_badge(fig, title_text, warnings) + _maybe_add_footer(fig, plot_spec) try: return {"sea-ice_concentration-static-maps": [convert.image_from_figure(fig)]} From 
0060475d56aca66631792f9f84d84d143fa18a2f Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 11:49:06 +0000 Subject: [PATCH 39/49] Fix issues for mypy --- ice_station_zebra/visualisations/plotting_maps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ice_station_zebra/visualisations/plotting_maps.py b/ice_station_zebra/visualisations/plotting_maps.py index f74efe8e..8ba698af 100644 --- a/ice_station_zebra/visualisations/plotting_maps.py +++ b/ice_station_zebra/visualisations/plotting_maps.py @@ -131,7 +131,7 @@ def _prepare_difference( plot_spec: PlotSpec, ground_truth: np.ndarray, prediction: np.ndarray, -) -> tuple[np.ndarray | None, DiffColourmapSpec | ListedColormap | None]: +) -> tuple[np.ndarray | None, DiffColourmapSpec | None]: """Compute difference arrays and colour scales if requested.""" if not plot_spec.include_difference: return None, None From 058491cc0de882d6c61f12bd13980da0da5af8fa Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 12:01:25 +0000 Subject: [PATCH 40/49] Add ffmpeg to github action to test video saving --- .github/workflows/test_code.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test_code.yaml b/.github/workflows/test_code.yaml index 6524781d..1e9eb023 100644 --- a/.github/workflows/test_code.yaml +++ b/.github/workflows/test_code.yaml @@ -36,6 +36,9 @@ jobs: uses: astral-sh/setup-uv@v6 with: version: "0.8.3" + - name: Install ffmpeg (so matplotlib can call it) + run: | + sudo apt-get update && sudo apt-get install -y ffmpeg - name: Run pytest run: uv run --group dev pytest - name: Run mypy From 008d00166321e5e5606b007fa9b2385c5713ef51 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:35:23 +0000 Subject: [PATCH 41/49] Change saving locally to base path root --- .../callbacks/raw_inputs_callback.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/ice_station_zebra/callbacks/raw_inputs_callback.py b/ice_station_zebra/callbacks/raw_inputs_callback.py index eebdd270..5876753a 100644 --- a/ice_station_zebra/callbacks/raw_inputs_callback.py +++ b/ice_station_zebra/callbacks/raw_inputs_callback.py @@ -82,7 +82,22 @@ def __init__( # noqa: PLR0913 self.frequency = None else: self.frequency = int(max(1, frequency)) - self.save_dir = Path(save_dir) if save_dir else None + + self.config = config or {} + + # Get base_path from config to use as root folder + base_path = Path(self.config.get("base_path", "../ice-station-zebra/data")) + + # Resolve save_dir relative to base_path if it's a relative path + if save_dir: + save_dir_path = Path(save_dir) + if save_dir_path.is_absolute(): + self.save_dir = save_dir_path + else: + self.save_dir = (base_path / save_dir_path).resolve() + else: + self.save_dir = None + self.timestep_index = timestep_index self.variable_styles = variable_styles or {} self._has_plotted = False @@ -91,7 +106,17 @@ def __init__( # noqa: PLR0913 self.make_video_plots = make_video_plots self.video_fps = video_fps self.video_format = video_format - self.video_save_dir = Path(video_save_dir) if video_save_dir else self.save_dir + + # Resolve video_save_dir relative to base_path if it's a relative path + if video_save_dir: + video_save_dir_path = Path(video_save_dir) + if video_save_dir_path.is_absolute(): + self.video_save_dir = video_save_dir_path + else: + self.video_save_dir = (base_path / video_save_dir_path).resolve() + else: + self.video_save_dir = self.save_dir + self.max_animation_frames = max_animation_frames # 
WandB logging control @@ -102,8 +127,6 @@ def __init__( # noqa: PLR0913 self.plot_spec = DEFAULT_SIC_SPEC else: self.plot_spec = plot_spec - - self.config = config or {} self._land_mask_path_detected = False self._land_mask_array: np.ndarray | None = None From 6ad10fbb4f4e7ed828cc1bc612ac0720f81a921d Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:36:53 +0000 Subject: [PATCH 42/49] Fix path names --- ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml | 4 ++-- ice_station_zebra/visualisations/cli.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml index 6d355515..cd6eee33 100644 --- a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml +++ b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml @@ -3,14 +3,14 @@ raw_inputs: _target_: ice_station_zebra.callbacks.RawInputsCallback frequency: null # Plot once; set N to repeat every N batches timestep_index: 0 # Which history timestep to plot (0 = most recent) - save_dir: ./data/raw_input_plots # Directory to save plots (null to skip saving to disk) + save_dir: ./raw_input_plots # Directory to save plots (null to skip saving to disk) log_to_wandb: true # Whether to log plots to WandB (if false and save_dir null, nothing saved) # Animation configuration (following PlottingCallback pattern) make_video_plots: false # Set to true to create temporal animations video_fps: 2 # Frames per second for animations video_format: gif # mp4 or gif - video_save_dir: ./data/raw_input_animations # Directory to save animations (null to skip disk save) + video_save_dir: ./raw_input_animations # Directory to save animations (null to skip disk save) max_animation_frames: null # Limit frames (null = unlimited; 30 ≈ 1 month daily data) # Plot specification (colourmap, hemisphere, etc.) diff --git a/ice_station_zebra/visualisations/cli.py b/ice_station_zebra/visualisations/cli.py index cac44cad..c2c3548f 100644 --- a/ice_station_zebra/visualisations/cli.py +++ b/ice_station_zebra/visualisations/cli.py @@ -184,7 +184,7 @@ def plot_raw_inputs( save_dir = ( Path(output_dir) if output_dir - else callback.save_dir or Path("./data/raw_input_plots") + else callback.save_dir or Path("./raw_input_plots") ) # Plot the raw inputs @@ -279,7 +279,7 @@ def animate_raw_inputs( save_dir = ( Path(output_dir) if output_dir - else callback.video_save_dir or Path("./data/raw_input_animations") + else callback.video_save_dir or Path("./raw_input_animations") ) # Validate video format From b3b1bbe6b87aeb599e64a57b66193fad0cf19dd6 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:39:10 +0000 Subject: [PATCH 43/49] Option for scientific notation on colourbar --- ice_station_zebra/visualisations/layout.py | 26 +++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/ice_station_zebra/visualisations/layout.py b/ice_station_zebra/visualisations/layout.py index 613bca1e..4c383565 100644 --- a/ice_station_zebra/visualisations/layout.py +++ b/ice_station_zebra/visualisations/layout.py @@ -951,10 +951,20 @@ def format_linear_ticks( vmax: float | None = None, decimals: int = 1, is_vertical: bool, + use_scientific_notation: bool = False, ) -> None: """Format a linear colourbar with 5 ticks. If vmin/vmax are not provided, derive them from the colourbar's mappable. + + Args: + colourbar: Colorbar to format. + vmin: Minimum value. + vmax: Maximum value. 
+ decimals: Number of decimal places for tick labels. + is_vertical: Whether the colorbar is vertical. + use_scientific_notation: Whether to format tick labels in scientific notation. + """ axis = colourbar.ax.yaxis if is_vertical else colourbar.ax.xaxis @@ -967,13 +977,17 @@ def format_linear_ticks( float(vmin), float(vmax), _DEFAULT_LAYOUT_CONFIG.formatting.num_ticks_linear ) colourbar.set_ticks([float(t) for t in ticks]) - axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) + + if use_scientific_notation: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}e}")) + else: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) if not is_vertical: colourbar.ax.xaxis.set_tick_params(pad=1) apply_monospace_to_cbar_text(colourbar) -def format_symmetric_ticks( +def format_symmetric_ticks( # noqa: PLR0913 colourbar: Colorbar, *, vmin: float, @@ -981,6 +995,7 @@ def format_symmetric_ticks( decimals: int = 2, is_vertical: bool, centre: float | None = None, + use_scientific_notation: bool = False, ) -> None: """Format symmetric diverging ticks with a centred midpoint. @@ -993,6 +1008,7 @@ def format_symmetric_ticks( decimals: Number of decimal places for tick labels. is_vertical: Whether the colorbar is vertical. centre: centre value for diverging colourmap (default: 0.0). + use_scientific_notation: Whether to format tick labels in scientific notation. """ axis = colourbar.ax.yaxis if is_vertical else colourbar.ax.xaxis @@ -1006,7 +1022,11 @@ def format_symmetric_ticks( vmax, ] colourbar.set_ticks([float(t) for t in ticks]) - axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) + + if use_scientific_notation: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}e}")) + else: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) if not is_vertical: colourbar.ax.xaxis.set_tick_params(pad=1) apply_monospace_to_cbar_text(colourbar) From e4601be246ec0793d8fee9df8f719e92d6ebc6d7 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:40:20 +0000 Subject: [PATCH 44/49] altering colourbar format for raw plots --- .../visualisations/plotting_raw_inputs.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/ice_station_zebra/visualisations/plotting_raw_inputs.py b/ice_station_zebra/visualisations/plotting_raw_inputs.py index 4577e7b6..eae2dc1c 100644 --- a/ice_station_zebra/visualisations/plotting_raw_inputs.py +++ b/ice_station_zebra/visualisations/plotting_raw_inputs.py @@ -186,6 +186,11 @@ def plot_raw_inputs_for_timestep( # noqa: PLR0913, C901, PLR0912, PLR0915 cbar = fig.colorbar(image, ax=ax, cax=cax, orientation=orientation) is_vertical = orientation == "vertical" decimals = style.decimals if style.decimals is not None else 2 + use_scientific = ( + style.use_scientific_notation + if style.use_scientific_notation is not None + else False + ) if isinstance(norm, TwoSlopeNorm): format_symmetric_ticks( cbar, @@ -194,10 +199,16 @@ def plot_raw_inputs_for_timestep( # noqa: PLR0913, C901, PLR0912, PLR0915 decimals=decimals, is_vertical=is_vertical, centre=norm.vcenter, + use_scientific_notation=use_scientific, ) else: format_linear_ticks( - cbar, vmin=vmin, vmax=vmax, decimals=decimals, is_vertical=is_vertical + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + use_scientific_notation=use_scientific, ) # Title @@ -332,6 +343,11 @@ def video_raw_input_for_variable( # noqa: PLR0913, C901, PLR0915 # Format 
colourbar ticks based on normalisation type decimals = style.decimals if style.decimals is not None else 2 + use_scientific = ( + style.use_scientific_notation + if style.use_scientific_notation is not None + else False + ) if isinstance(norm, TwoSlopeNorm): format_symmetric_ticks( cbar, @@ -340,10 +356,16 @@ def video_raw_input_for_variable( # noqa: PLR0913, C901, PLR0915 decimals=decimals, is_vertical=is_vertical, centre=norm.vcenter, + use_scientific_notation=use_scientific, ) else: format_linear_ticks( - cbar, vmin=vmin, vmax=vmax, decimals=decimals, is_vertical=is_vertical + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + use_scientific_notation=use_scientific, ) # Create title (will be updated each frame) From c44fd65ae9dffed7582eea28a25ac138a798f43a Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:44:15 +0000 Subject: [PATCH 45/49] add scientific notation to variable style --- ice_station_zebra/visualisations/plotting_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ice_station_zebra/visualisations/plotting_core.py b/ice_station_zebra/visualisations/plotting_core.py index df522959..59a6668e 100644 --- a/ice_station_zebra/visualisations/plotting_core.py +++ b/ice_station_zebra/visualisations/plotting_core.py @@ -36,6 +36,7 @@ class VariableStyle: units: Display units for the variable (e.g., "K", "m/s"). origin: Imshow origin override ("upper" keeps north-up, "lower" keeps south-up). decimals: Number of decimal places for colourbar tick labels (default: 2). + use_scientific_notation: Whether to format colourbar tick labels in scientific notation (default: False). """ @@ -47,6 +48,7 @@ class VariableStyle: units: str | None = None origin: Literal["upper", "lower"] | None = None decimals: int | None = None + use_scientific_notation: bool | None = None def colourmap_with_bad( From e005fe6796039a840c04014f5ae70bb27b77bea6 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:45:09 +0000 Subject: [PATCH 46/49] add test for scientific notation --- tests/plotting/test_raw_inputs.py | 51 +++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/plotting/test_raw_inputs.py b/tests/plotting/test_raw_inputs.py index 347f83ce..eefcd2a0 100644 --- a/tests/plotting/test_raw_inputs.py +++ b/tests/plotting/test_raw_inputs.py @@ -219,6 +219,57 @@ def test_style_for_variable_wildcard_match( assert style.decimals == 4 +def test_style_for_variable_scientific_notation( + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test scientific notation option in styling.""" + # Add style with scientific notation + styles_with_scientific = { + **variable_styles, + "era5:q_10": { + "cmap": "viridis", + "decimals": 2, + "units": "kg/kg", + "use_scientific_notation": True, + }, + } + + style = style_for_variable("era5:q_10", styles_with_scientific) + + assert style.cmap == "viridis" + assert style.decimals == 2 + assert style.units == "kg/kg" + assert style.use_scientific_notation is True + + +def test_plot_with_scientific_notation( + era5_humidity_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test plotting with scientific notation enabled.""" + styles_with_scientific = { + "era5:q_10": { + "cmap": "viridis", + "decimals": 2, + "units": "kg/kg", + "use_scientific_notation": True, + }, + } + + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_humidity_2d], + channel_names=["era5:q_10"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=styles_with_scientific, + ) 
+ + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == "era5:q_10" + assert isinstance(pil_image, Image.Image) + + def test_style_for_variable_no_match() -> None: """Test default styling when no match found.""" style = style_for_variable("unknown:variable", {}) From 586a800128299e90d5298607de12ffe0f965d441 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:46:10 +0000 Subject: [PATCH 47/49] Change humidity to scientific notation (other option is more decimals) --- .../config/evaluate/callbacks/raw_inputs.yaml | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml index cd6eee33..f5c7ba02 100644 --- a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml +++ b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml @@ -34,12 +34,22 @@ raw_inputs: "era5:2t": { cmap: "RdBu_r", two_slope_centre: 273.15, units: "K" } # 2m temperature "era5:t_*": { cmap: "RdBu_r", two_slope_centre: 273.15, units: "K" } # temperature at various levels "era5:msl": { cmap: "RdYlBu_r", units: "Pa" } # mean sea level pressure - "era5:sp": { cmap: "RdYlBu_r" } # surface pressure + "era5:sp": { cmap: "RdYlBu_r", units: "Pa" } # surface pressure + + "era5:q_10": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } # specific humidity at various levels (scientific notation handles small values better) + "era5:q_250": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } + "era5:q_500": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } + "era5:q_1000": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } + + # Without scientific notation + # "era5:q_10": { cmap: "viridis", decimals: 7, units: "kg/kg" } # specific humidity at various levels (small values need more decimals) + # "era5:q_250": { cmap: "viridis", decimals: 5, units: "kg/kg" } + # "era5:q_500": { cmap: "viridis", decimals: 4, units: "kg/kg" } + # "era5:q_1000": { cmap: "viridis", decimals: 4, units: "kg/kg" } - "era5:q_*": { cmap: "viridis", decimals: 4 } # specific humidity at various levels (small values need more decimals) "era5:z_*": { cmap: "plasma", units: "m" } # geopotential at various levels - "era5:u_*": { cmap: "RdBu_r", two_slope_centre: 0.0 } # u-wind component at various levels - "era5:v_*": { cmap: "RdBu_r", two_slope_centre: 0.0 } # v-wind component at various levels + "era5:u_*": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # u-wind component at various levels + "era5:v_*": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # v-wind component at various levels # Sea ice concentration "osisaf-south:ice_conc": { cmap: "Blues_r" } From 447db5cbabc76cf69c4489465c0f5c977ed63cfa Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:48:57 +0000 Subject: [PATCH 48/49] Fix problem with mypy --- .../callbacks/raw_inputs_callback.py | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/ice_station_zebra/callbacks/raw_inputs_callback.py b/ice_station_zebra/callbacks/raw_inputs_callback.py index 5876753a..426ce589 100644 --- a/ice_station_zebra/callbacks/raw_inputs_callback.py +++ b/ice_station_zebra/callbacks/raw_inputs_callback.py @@ -43,7 +43,7 @@ class RawInputsCallback(Callback): """A callback to plot raw input variables during evaluation.""" - def __init__( # noqa: PLR0913 + 
def __init__( # noqa: PLR0913, PLR0912, PLR0915 self, *, frequency: int | None = None, @@ -88,6 +88,44 @@ def __init__( # noqa: PLR0913 # Get base_path from config to use as root folder base_path = Path(self.config.get("base_path", "../ice-station-zebra/data")) + # Resolve save_dir relative to base_path if it's a relative path + if save_dir: + save_dir_path = Path(save_dir) + if save_dir_path.is_absolute(): + self.save_dir: Path | None = save_dir_path + else: + self.save_dir = (base_path / save_dir_path).resolve() + else: + self.save_dir: Path | None = None + """Create raw input plots and/or animations during evaluation. + + Args: + frequency: Create plots every `frequency` batches; `None` plots once per run. + save_dir: Directory to save static plots to. If None and log_to_wandb=False, no plots saved. + plot_spec: Plotting specification (colourmap, hemisphere, etc.). + config: Configuration dictionary for land mask detection. + timestep_index: Which history timestep to plot (0 = most recent). + variable_styles: Per-variable styling overrides (cmap, vmin/vmax, units, etc.). + make_video_plots: Whether to create temporal animations of raw inputs. + video_fps: Frames per second for animations. + video_format: Video format ("mp4" or "gif"). + video_save_dir: Directory to save animations. If None and log_to_wandb=False, no videos saved. + max_animation_frames: Maximum number of frames to include in animations (None = unlimited). + Limits temporal accumulation to control memory and file size. + log_to_wandb: Whether to log plots and animations to WandB (default: True). + + """ + super().__init__() + if frequency is None: + self.frequency = None + else: + self.frequency = int(max(1, frequency)) + + self.config = config or {} + + # Get base_path from config to use as root folder + base_path = Path(self.config.get("base_path", "../ice-station-zebra/data")) + # Resolve save_dir relative to base_path if it's a relative path if save_dir: save_dir_path = Path(save_dir) @@ -111,11 +149,11 @@ def __init__( # noqa: PLR0913 if video_save_dir: video_save_dir_path = Path(video_save_dir) if video_save_dir_path.is_absolute(): - self.video_save_dir = video_save_dir_path + self.video_save_dir: Path | None = video_save_dir_path else: self.video_save_dir = (base_path / video_save_dir_path).resolve() else: - self.video_save_dir = self.save_dir + self.video_save_dir: Path | None = self.save_dir self.max_animation_frames = max_animation_frames From c8d98c3058428fcec8345422c6d99573fb464dc4 Mon Sep 17 00:00:00 2001 From: Lydia France Date: Thu, 27 Nov 2025 20:58:26 +0000 Subject: [PATCH 49/49] Fix mypy problem with save dir --- ice_station_zebra/callbacks/raw_inputs_callback.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ice_station_zebra/callbacks/raw_inputs_callback.py b/ice_station_zebra/callbacks/raw_inputs_callback.py index 426ce589..01ac195a 100644 --- a/ice_station_zebra/callbacks/raw_inputs_callback.py +++ b/ice_station_zebra/callbacks/raw_inputs_callback.py @@ -43,6 +43,9 @@ class RawInputsCallback(Callback): """A callback to plot raw input variables during evaluation.""" + save_dir: Path | None + video_save_dir: Path | None + def __init__( # noqa: PLR0913, PLR0912, PLR0915 self, *, @@ -92,11 +95,11 @@ def __init__( # noqa: PLR0913, PLR0912, PLR0915 if save_dir: save_dir_path = Path(save_dir) if save_dir_path.is_absolute(): - self.save_dir: Path | None = save_dir_path + self.save_dir = save_dir_path else: self.save_dir = (base_path / save_dir_path).resolve() else: - 
self.save_dir: Path | None = None + self.save_dir = None """Create raw input plots and/or animations during evaluation. Args: @@ -149,11 +152,11 @@ def __init__( # noqa: PLR0913, PLR0912, PLR0915 if video_save_dir: video_save_dir_path = Path(video_save_dir) if video_save_dir_path.is_absolute(): - self.video_save_dir: Path | None = video_save_dir_path + self.video_save_dir = video_save_dir_path else: self.video_save_dir = (base_path / video_save_dir_path).resolve() else: - self.video_save_dir: Path | None = self.save_dir + self.video_save_dir = self.save_dir self.max_animation_frames = max_animation_frames
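
Note for reviewers: to see the effect of the new use_scientific_notation style without running the evaluation callback, here is a minimal, self-contained sketch. It deliberately does not call the repository's format_linear_ticks helper; the sci_label/fixed_label helpers, the sample humidity-like data, and the output filename are illustrative assumptions, and only the matplotlib/numpy calls are standard API.

    # Minimal sketch (not the repository's format_linear_ticks): how a
    # use_scientific_notation flag can switch colourbar tick labels between
    # fixed-point and scientific formatting. Data, helper names and the
    # output filename are illustrative, not taken from the codebase.
    import matplotlib

    matplotlib.use("Agg")  # render off-screen, no display needed

    import matplotlib.pyplot as plt
    import numpy as np


    def sci_label(value: float, decimals: int) -> str:
        """Format like '1.23e-05'."""
        return f"{value:.{decimals}e}"


    def fixed_label(value: float, decimals: int) -> str:
        """Format like '0.00'."""
        return f"{value:.{decimals}f}"


    use_scientific_notation = True  # mirrors the new VariableStyle field
    decimals = 2                    # mirrors VariableStyle.decimals

    # Specific-humidity-like magnitudes (~1e-5), where fixed-point labels
    # with two decimals all collapse to "0.00".
    data = np.random.default_rng(0).uniform(1e-6, 2e-5, size=(32, 32))

    fig, ax = plt.subplots()
    image = ax.imshow(data, cmap="viridis")
    cbar = fig.colorbar(image, ax=ax)

    ticks = np.linspace(data.min(), data.max(), 5)
    label = sci_label if use_scientific_notation else fixed_label
    cbar.set_ticks(ticks)
    cbar.set_ticklabels([label(tick, decimals) for tick in ticks])

    fig.savefig("humidity_colourbar_demo.png")

With decimals: 2 and fixed-point labels, every tick on a field of order 1e-5 kg/kg reads "0.00"; switching the era5:q_* entries to scientific notation (or raising decimals to 4-7, as in the commented-out alternative in raw_inputs.yaml) keeps the colourbar legible.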
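
The last two commits in the series change typing, not behaviour. Assuming the mypy complaint was about the duplicated per-branch annotations of self.save_dir and self.video_save_dir (which is what the two diffs remove), the pattern and its fix can be illustrated standalone; the class and attribute names below are placeholders, not the real RawInputsCallback.

    # Sketch of the typing issue behind patches 48 and 49. BeforeFix,
    # AfterFix and save_dir are placeholder names for illustration only.
    from pathlib import Path


    class BeforeFix:
        def __init__(self, save_dir: str | None) -> None:
            if save_dir:
                self.save_dir: Path | None = Path(save_dir)
            else:
                # mypy flags re-annotating the same attribute in a second
                # branch (the "problem with mypy" the commit messages
                # mention), even though this runs fine at runtime.
                self.save_dir: Path | None = None


    class AfterFix:
        # Declare the attribute once at class level; every assignment in
        # __init__ is then checked against Path | None without per-branch
        # annotations, matching the change in the final patch.
        save_dir: Path | None

        def __init__(self, save_dir: str | None) -> None:
            if save_dir:
                self.save_dir = Path(save_dir)
            else:
                self.save_dir = None

The class-level declaration also documents the attribute's type in one place, which is why the final patch prefers it over annotating only the first assignment.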