diff --git a/.github/workflows/test_code.yaml b/.github/workflows/test_code.yaml index 6524781d..1e9eb023 100644 --- a/.github/workflows/test_code.yaml +++ b/.github/workflows/test_code.yaml @@ -36,6 +36,9 @@ jobs: uses: astral-sh/setup-uv@v6 with: version: "0.8.3" + - name: Install ffmpeg (so matplotlib can call it) + run: | + sudo apt-get update && sudo apt-get install -y ffmpeg - name: Run pytest run: uv run --group dev pytest - name: Run mypy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c8f6c3a..5a969794 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: # Run the linter - id: ruff-check args: ["--fix", "--show-fixes"] - pass_filenames: false + pass_filenames: true # Run the formatter - id: ruff-format - pass_filenames: false + pass_filenames: true diff --git a/README.md b/README.md index 605178a6..080f9c3c 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,22 @@ Run `uv run zebra train` to train using the datasets specified in the config. Run `uv run zebra evaluate --checkpoint PATH_TO_A_CHECKPOINT` to evaluate using a checkpoint from a training run. +### Visualisations + +Plot raw input variables from the test dataset: + +```bash +uv run zebra visualisations plot-raw-inputs --config-name .yaml --sample-idx 0 +``` + +Create animations of raw inputs over time: + +```bash +uv run zebra visualisations animate-raw-inputs --config-name .yaml +``` + +Settings (output directories, styling, animation parameters) are read from `config.evaluate.callbacks.raw_inputs` in your YAML config files. Command-line options can override config values if needed. + ## Adding a new model ### Background diff --git a/ice_station_zebra/callbacks/__init__.py b/ice_station_zebra/callbacks/__init__.py index fe73afc5..6c784959 100644 --- a/ice_station_zebra/callbacks/__init__.py +++ b/ice_station_zebra/callbacks/__init__.py @@ -1,11 +1,13 @@ from .ema_weight_averaging_callback import EMAWeightAveragingCallback from .metric_summary_callback import MetricSummaryCallback from .plotting_callback import PlottingCallback +from .raw_inputs_callback import RawInputsCallback from .unconditional_checkpoint import UnconditionalCheckpoint __all__ = [ "EMAWeightAveragingCallback", "MetricSummaryCallback", "PlottingCallback", + "RawInputsCallback", "UnconditionalCheckpoint", ] diff --git a/ice_station_zebra/callbacks/plotting_callback.py b/ice_station_zebra/callbacks/plotting_callback.py index 0212b899..22d834c3 100644 --- a/ice_station_zebra/callbacks/plotting_callback.py +++ b/ice_station_zebra/callbacks/plotting_callback.py @@ -1,3 +1,8 @@ +# Set matplotlib backend before any plotting imports +import matplotlib as mpl + +mpl.use("Agg") + import dataclasses import logging from collections.abc import Mapping, Sequence @@ -207,7 +212,7 @@ def log_static_plots( ) except InvalidArrayError as err: logger.warning("Static plotting skipped due to invalid arrays: %s", err) - except (ValueError, MemoryError, OSError): + except Exception: logger.exception("Static plotting failed") def log_video_plots( diff --git a/ice_station_zebra/callbacks/raw_inputs_callback.py b/ice_station_zebra/callbacks/raw_inputs_callback.py new file mode 100644 index 00000000..01ac195a --- /dev/null +++ b/ice_station_zebra/callbacks/raw_inputs_callback.py @@ -0,0 +1,588 @@ +"""Callback for plotting raw input variables during evaluation.""" + +# Set matplotlib backend before any plotting imports +import matplotlib as mpl + +mpl.use("Agg") + +import dataclasses +import gc +import logging +from 
collections.abc import Mapping, Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, cast + +import numpy as np +from lightning import LightningModule, Trainer +from lightning.pytorch import Callback +from lightning.pytorch.loggers import Logger as LightningLogger + +from ice_station_zebra.callbacks.metadata import infer_hemisphere +from ice_station_zebra.data_loaders import CombinedDataset +from ice_station_zebra.exceptions import InvalidArrayError, VideoRenderError +from ice_station_zebra.types import PlotSpec +from ice_station_zebra.visualisations.plotting_core import ( + detect_land_mask_path, + safe_filename, +) +from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC +from ice_station_zebra.visualisations.plotting_raw_inputs import ( + plot_raw_inputs_for_timestep, + video_raw_inputs_for_timesteps, +) + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + +# Constants +EXPECTED_INPUT_NDIM = 5 # Expected input data shape: [B, T, C, H, W] + + +class RawInputsCallback(Callback): + """A callback to plot raw input variables during evaluation.""" + + save_dir: Path | None + video_save_dir: Path | None + + def __init__( # noqa: PLR0913, PLR0912, PLR0915 + self, + *, + frequency: int | None = None, + save_dir: str | Path | None = None, + plot_spec: PlotSpec | None = None, + config: dict | None = None, + timestep_index: int = 0, + variable_styles: dict[str, dict[str, Any]] | None = None, + make_video_plots: bool = False, + video_fps: int = 2, + video_format: Literal["mp4", "gif"] = "gif", + video_save_dir: str | Path | None = None, + max_animation_frames: int | None = None, + log_to_wandb: bool = True, + ) -> None: + """Create raw input plots and/or animations during evaluation. + + Args: + frequency: Create plots every `frequency` batches; `None` plots once per run. + save_dir: Directory to save static plots to. If None and log_to_wandb=False, no plots saved. + plot_spec: Plotting specification (colourmap, hemisphere, etc.). + config: Configuration dictionary for land mask detection. + timestep_index: Which history timestep to plot (0 = most recent). + variable_styles: Per-variable styling overrides (cmap, vmin/vmax, units, etc.). + make_video_plots: Whether to create temporal animations of raw inputs. + video_fps: Frames per second for animations. + video_format: Video format ("mp4" or "gif"). + video_save_dir: Directory to save animations. If None and log_to_wandb=False, no videos saved. + max_animation_frames: Maximum number of frames to include in animations (None = unlimited). + Limits temporal accumulation to control memory and file size. + log_to_wandb: Whether to log plots and animations to WandB (default: True). + + """ + super().__init__() + if frequency is None: + self.frequency = None + else: + self.frequency = int(max(1, frequency)) + + self.config = config or {} + + # Get base_path from config to use as root folder + base_path = Path(self.config.get("base_path", "../ice-station-zebra/data")) + + # Resolve save_dir relative to base_path if it's a relative path + if save_dir: + save_dir_path = Path(save_dir) + if save_dir_path.is_absolute(): + self.save_dir = save_dir_path + else: + self.save_dir = (base_path / save_dir_path).resolve() + else: + self.save_dir = None
+ + self.timestep_index = timestep_index + self.variable_styles = variable_styles or {} + self._has_plotted = False + + # Animation settings + self.make_video_plots = make_video_plots + self.video_fps = video_fps + self.video_format = video_format + + # Resolve video_save_dir relative to base_path if it's a relative path + if video_save_dir: + video_save_dir_path = Path(video_save_dir) + if video_save_dir_path.is_absolute(): + self.video_save_dir = video_save_dir_path + else: + self.video_save_dir = (base_path / video_save_dir_path).resolve() + else: + self.video_save_dir = self.save_dir + + self.max_animation_frames = max_animation_frames + + # WandB logging control + self.log_to_wandb = log_to_wandb + + # Ensure plot_spec is a PlotSpec instance + if plot_spec is None: + self.plot_spec = DEFAULT_SIC_SPEC + else: + self.plot_spec = plot_spec + self._land_mask_path_detected = False + self._land_mask_array: np.ndarray | None = None + + # Temporal data accumulation for animations + self._temporal_data: dict[ + str, list[np.ndarray] + ] = {} # var_name -> list of [H,W] arrays + self._temporal_dates: list[Any] = [] + self._dataset_ref: CombinedDataset | None = None + + def on_test_start(self, _trainer: Trainer, _module: LightningModule) -> None: + """Called when the test loop starts.""" + logger.info("RawInputsCallback: Test loop started") + + def _detect_land_mask_path( # noqa: C901 + self, + dataset: CombinedDataset, + ) -> None: + """Detect and set the land mask path based on the dataset configuration.""" + if self._land_mask_path_detected: + return + + # Get base path from callback config or use default + base_path = self.config.get("base_path", "../ice-station-zebra/data") + + # Try to get dataset name from the target dataset + dataset_name = None + if hasattr(dataset, "target") and hasattr(dataset.target, "name"): + dataset_name = dataset.target.name + + # Detect land mask path + land_mask_path = detect_land_mask_path(base_path, dataset_name) + + # Always try to infer hemisphere regardless of land mask presence + hemisphere: Literal["north", "south"] | None = None + if isinstance(dataset_name, str): + low = 
dataset_name.lower() + if "south" in low: + hemisphere = cast("Literal['north', 'south']", "south") + elif "north" in low: + hemisphere = cast("Literal['north', 'south']", "north") + if hemisphere is None: + hemi_candidate = infer_hemisphere(dataset) + if hemi_candidate in ("north", "south"): + hemisphere = cast("Literal['north', 'south']", hemi_candidate) + + # Update plot_spec pieces independently + if hemisphere is not None and self.plot_spec.hemisphere != hemisphere: + self.plot_spec = dataclasses.replace(self.plot_spec, hemisphere=hemisphere) + + if land_mask_path: + # Set land mask path when found + self.plot_spec = dataclasses.replace( + self.plot_spec, + land_mask_path=land_mask_path, + ) + logger.info("Auto-detected land mask: %s", land_mask_path) + + # Load the land mask array + try: + self._land_mask_array = np.load(land_mask_path) + logger.debug( + "Loaded land mask array with shape: %s", self._land_mask_array.shape + ) + except Exception: + logger.exception("Failed to load land mask from %s", land_mask_path) + self._land_mask_array = None + else: + logger.debug("No land mask found for dataset: %s", dataset_name) + + self._land_mask_path_detected = True + + def on_test_batch_end( # noqa: C901, PLR0912 + self, + trainer: Trainer, + _module: LightningModule, + _outputs: Any, # noqa: ANN401 + batch: Any, # noqa: ANN401 + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Called when the test batch ends.""" + logger.debug( + "RawInputsCallback.on_test_batch_end called for batch %d", batch_idx + ) + + # Get dataset and date + dl: DataLoader | list[DataLoader] | None = trainer.test_dataloaders + if dl is None: + logger.warning("No test dataloaders found, skipping raw inputs plotting.") + return + + dataset = (dl[dataloader_idx] if isinstance(dl, Sequence) else dl).dataset + if not isinstance(dataset, CombinedDataset): + logger.warning( + "Dataset is of type %s, skipping raw inputs plotting.", type(dataset) + ) + return + + # Detect land mask path if not already done + self._detect_land_mask_path(dataset) + + # Store dataset reference for later use + if self._dataset_ref is None: + self._dataset_ref = dataset + + # Get the date for this batch's first sample + date = dataset.date_from_index(batch_idx) + + # Determine if we should plot static plots this batch + should_plot_static = False + if self.frequency is None: + should_plot_static = not self._has_plotted + else: + should_plot_static = batch_idx % self.frequency == 0 + + # Always accumulate temporal data if animations are enabled + # (regardless of static plot frequency) + if self.make_video_plots: + # Check if we've reached the frame limit + if ( + self.max_animation_frames is not None + and len(self._temporal_dates) >= self.max_animation_frames + ): + logger.debug( + "Reached max animation frames (%d), skipping further accumulation", + self.max_animation_frames, + ) + else: + try: + # Extract data without plotting + channel_arrays, channel_names = self._extract_channel_data( + batch, dataset + ) + if channel_arrays and channel_names: + self._accumulate_temporal_data( + channel_arrays, channel_names, date + ) + except (ValueError, RuntimeError, IndexError, KeyError) as e: + logger.warning( + "Failed to accumulate temporal data for batch %d: %s", + batch_idx, + e, + ) + + # Create static plots according to frequency + if should_plot_static: + logger.debug("Starting raw inputs plotting for batch %d", batch_idx) + try: + self.log_raw_inputs(batch, dataset, date, trainer.loggers, batch_idx) + logger.info( + "Successfully 
completed raw inputs plotting for batch %d", batch_idx + ) + except Exception: + logger.exception("Raw inputs plotting failed for batch %d", batch_idx) + else: + if self.frequency is None: + self._has_plotted = True + else: + logger.debug( + "Skipping static plots for batch %d (frequency=%s)", + batch_idx, + self.frequency, + ) + + def _extract_channel_data( + self, + batch: Mapping[str, Any], + dataset: CombinedDataset, + ) -> tuple[list[np.ndarray], list[str]]: + """Extract channel data from batch without plotting. + + Returns: + Tuple of (channel_arrays, channel_names). + + """ + # Collect all input channel arrays + channel_arrays = [] + for ds in dataset.inputs: + if ds.name not in batch: + logger.warning("Dataset %s not found in batch, skipping", ds.name) + continue + + input_data = batch[ds.name] # Shape: [B, T, C, H, W] + + # Take first batch and specified timestep + if input_data.ndim != EXPECTED_INPUT_NDIM: + logger.warning( + "Expected 5D input data [B,T,C,H,W], got shape %s for %s", + input_data.shape, + ds.name, + ) + continue + + timestep_data = input_data[0, self.timestep_index] # Shape: [C, H, W] + + # Add each channel as a 2D array + for c in range(timestep_data.shape[0]): + channel_arr = timestep_data[c].detach().cpu().numpy() + channel_arrays.append(channel_arr) + + if not channel_arrays: + logger.warning("No input channels found in batch") + return [], [] + + # Get variable names from dataset + channel_names = dataset.input_variable_names + + if len(channel_arrays) != len(channel_names): + logger.warning( + "Mismatch: %d channel arrays but %d channel names. Using generic names.", + len(channel_arrays), + len(channel_names), + ) + channel_names = [f"channel_{i}" for i in range(len(channel_arrays))] + + return channel_arrays, channel_names + + def log_raw_inputs( + self, + batch: Mapping[str, Any], + dataset: CombinedDataset, + date: Any, # noqa: ANN401 + lightning_loggers: list[LightningLogger], + _batch_idx: int, + ) -> None: + """Extract and log raw input plots.""" + # Early return if nothing will be saved + if not self.log_to_wandb and self.save_dir is None: + logger.debug( + "Skipping raw inputs plotting: log_to_wandb=False and save_dir=None" + ) + return + + try: + # Extract data + channel_arrays, channel_names = self._extract_channel_data(batch, dataset) + + if not channel_arrays: + logger.warning( + "No input channels found in batch, skipping raw inputs plotting" + ) + return + + # Plot the raw inputs + results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=date, + plot_spec_base=self.plot_spec, + land_mask=self._land_mask_array, + styles=self.variable_styles, + save_dir=self.save_dir, + ) + + # Log to WandB if enabled + if self.log_to_wandb: + for lightning_logger in lightning_loggers: + if hasattr(lightning_logger, "log_image"): + # Group images by their name for logging + for var_name, pil_image, _saved_path in results: + safe_name = safe_filename(var_name.replace(":", "__")) + lightning_logger.log_image( + key=f"raw_inputs/{safe_name}", + images=[pil_image], + ) + else: + logger.debug( + "Logger %s does not support images.", + lightning_logger.name + if lightning_logger.name + else "unknown", + ) + + logger.debug( + "Plotted %d raw input variables (saved to disk: %s, logged to WandB: %s)", + len(results), + self.save_dir is not None, + self.log_to_wandb, + ) + + except Exception: + logger.exception("Failed to log raw inputs") + + def _accumulate_temporal_data( + self, + channel_arrays: list[np.ndarray], + 
channel_names: list[str], + date: Any, # noqa: ANN401 + ) -> None: + """Accumulate temporal data for creating animations later. + + Args: + channel_arrays: List of 2D arrays [H, W] for this timestep. + channel_names: Variable names for each channel. + date: Date/datetime for this timestep. + + """ + # Initialise temporal data storage for each variable on first call + if not self._temporal_data: + for name in channel_names: + self._temporal_data[name] = [] + + # Append this timestep's data for each variable + for arr, name in zip(channel_arrays, channel_names, strict=True): + if name in self._temporal_data: + self._temporal_data[name].append(arr) + + # Append date + self._temporal_dates.append(date) + + logger.debug( + "Accumulated temporal data: %d timesteps for %d variables", + len(self._temporal_dates), + len(self._temporal_data), + ) + + def on_test_end(self, trainer: Trainer, _module: LightningModule) -> None: + """Called when the test loop ends. Create and log animations.""" + if not self.make_video_plots: + logger.debug("Video plots disabled, skipping animation creation") + return + + if not self._temporal_data or not self._temporal_dates: + logger.warning("No temporal data collected for animations") + return + + logger.debug("Creating animations for %d variables", len(self._temporal_data)) + try: + self.log_video_plots(trainer.loggers) + finally: + # Clear temporal data to free memory + self._temporal_data.clear() + self._temporal_dates.clear() + # Force garbage collection to clean up animation resources + gc.collect() + + def log_video_plots(self, lightning_loggers: list[LightningLogger]) -> None: # noqa: C901 + """Create and log temporal animations of raw input variables.""" + if not self.make_video_plots: + return + + # Early return if nothing will be saved + if not self.log_to_wandb and self.video_save_dir is None: + logger.debug( + "Skipping video plotting: log_to_wandb=False and video_save_dir=None" + ) + return + + try: + # Convert accumulated data to 3D arrays [T, H, W] + channel_arrays_stream = [] + channel_names = [] + + for var_name, frames in self._temporal_data.items(): + if not frames: + continue + # Stack frames into [T, H, W] + data_stream = np.stack(frames, axis=0) + channel_arrays_stream.append(data_stream) + channel_names.append(var_name) + + if not channel_arrays_stream: + logger.warning("No data to create animations") + return + + logger.info( + "Creating animations: %d variables x %d timesteps", + len(channel_names), + len(self._temporal_dates), + ) + + # Create animations for all variables + results = video_raw_inputs_for_timesteps( + channel_arrays_stream=channel_arrays_stream, + channel_names=channel_names, + dates=self._temporal_dates, + plot_spec_base=self.plot_spec, + land_mask=self._land_mask_array, + styles=self.variable_styles, + fps=self.video_fps, + video_format=self.video_format, + save_dir=self.video_save_dir, + ) + + # Log to WandB if enabled + if self.log_to_wandb: + for lightning_logger in lightning_loggers: + if hasattr(lightning_logger, "log_video"): + for var_name, video_buffer, _saved_path in results: + # Ensure buffer is at start + video_buffer.seek(0) + safe_name = safe_filename(var_name.replace(":", "__")) + lightning_logger.log_video( + key=f"raw_inputs_video/{safe_name}", + videos=[video_buffer], + format=[self.video_format], + ) + logger.debug("Logged video for %s to WandB", var_name) + else: + logger.debug( + "Logger %s does not support videos.", + lightning_logger.name + if lightning_logger.name + else "unknown", + ) + + logger.info( 
+ "Successfully created %d animations (saved to disk: %s, logged to WandB: %s)", + len(results), + self.video_save_dir is not None, + self.log_to_wandb, + ) + + except (InvalidArrayError, VideoRenderError) as err: + logger.warning("Video plotting skipped: %s", err) + except (ValueError, MemoryError, OSError): + logger.exception("Video plotting failed") diff --git a/ice_station_zebra/cli/main.py b/ice_station_zebra/cli/main.py index 6ed99b86..877653b8 100644 --- a/ice_station_zebra/cli/main.py +++ b/ice_station_zebra/cli/main.py @@ -4,6 +4,7 @@ from ice_station_zebra.data_processors import datasets_cli from ice_station_zebra.evaluation import evaluation_cli from ice_station_zebra.training import training_cli +from ice_station_zebra.visualisations.cli import visualisations_cli # Configure hydra logging simple_stdout_log_config() @@ -17,6 +18,7 @@ app.add_typer(datasets_cli, name="datasets") app.add_typer(evaluation_cli) app.add_typer(training_cli) +app.add_typer(visualisations_cli, name="visualisations") if __name__ == "__main__": diff --git a/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml new file mode 100644 index 00000000..f5c7ba02 --- /dev/null +++ b/ice_station_zebra/config/evaluate/callbacks/raw_inputs.yaml @@ -0,0 +1,62 @@ +raw_inputs: + # Callback to plot raw input variables (ERA5 data, etc.) + _target_: ice_station_zebra.callbacks.RawInputsCallback + frequency: null # Plot once; set N to repeat every N batches + timestep_index: 0 # Which history timestep to plot (0 = most recent) + save_dir: ./raw_input_plots # Directory to save plots (null to skip saving to disk) + log_to_wandb: true # Whether to log plots to WandB (if false and save_dir null, nothing saved) + + # Animation configuration (following PlottingCallback pattern) + make_video_plots: false # Set to true to create temporal animations + video_fps: 2 # Frames per second for animations + video_format: gif # mp4 or gif + video_save_dir: ./raw_input_animations # Directory to save animations (null to skip disk save) + max_animation_frames: null # Limit frames (null = unlimited; 30 ≈ 1 month daily data) + + # Plot specification (colourmap, hemisphere, etc.) 
+ plot_spec: + _target_: ice_station_zebra.types.PlotSpec + variable: "raw_inputs" + colourmap: "viridis" + colourbar_location: "vertical" + # hemisphere will be auto-detected from dataset name + + # Per-variable styling overrides (optional) + # Uncomment and customize as needed: + variable_styles: + # Diverging (signed) + "era5:10u": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # 10m u-wind component + "era5:10v": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # 10m v-wind component + "era5:sin_julian_day": { cmap: "PuOr", two_slope_centre: 0.0 } # sin of Julian day + "era5:cos_julian_day": { cmap: "PuOr", two_slope_centre: 0.0 } # cos of Julian day + + # Sequential (physical) + "era5:2t": { cmap: "RdBu_r", two_slope_centre: 273.15, units: "K" } # 2m temperature + "era5:t_*": { cmap: "RdBu_r", two_slope_centre: 273.15, units: "K" } # temperature at various levels + "era5:msl": { cmap: "RdYlBu_r", units: "Pa" } # mean sea level pressure + "era5:sp": { cmap: "RdYlBu_r", units: "Pa" } # surface pressure + + "era5:q_10": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } # specific humidity at various levels (scientific notation handles small values better) + "era5:q_250": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } + "era5:q_500": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } + "era5:q_1000": { cmap: "viridis", decimals: 2, units: "kg/kg", use_scientific_notation: true } + + # Without scientific notation + # "era5:q_10": { cmap: "viridis", decimals: 7, units: "kg/kg" } # specific humidity at various levels (small values need more decimals) + # "era5:q_250": { cmap: "viridis", decimals: 5, units: "kg/kg" } + # "era5:q_500": { cmap: "viridis", decimals: 4, units: "kg/kg" } + # "era5:q_1000": { cmap: "viridis", decimals: 4, units: "kg/kg" } + + "era5:z_*": { cmap: "plasma", units: "m" } # geopotential at various levels + "era5:u_*": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # u-wind component at various levels + "era5:v_*": { cmap: "RdBu_r", two_slope_centre: 0.0, units: "m/s" } # v-wind component at various levels + + # Sea ice concentration + "osisaf-south:ice_conc": { cmap: "Blues_r" } + + # Default fallback + "_default": { cmap: "viridis" } + + # Wildcard fallback for ERA5 channels + "era5:*": + origin: "upper" diff --git a/ice_station_zebra/config/evaluate/default.yaml b/ice_station_zebra/config/evaluate/default.yaml index aa312e3a..999d28f7 100644 --- a/ice_station_zebra/config/evaluate/default.yaml +++ b/ice_station_zebra/config/evaluate/default.yaml @@ -2,4 +2,5 @@ defaults: - callbacks: - metric_summary - plotting + - raw_inputs - _self_ diff --git a/ice_station_zebra/data_loaders/combined_dataset.py b/ice_station_zebra/data_loaders/combined_dataset.py index 2e605af7..bd0265b8 100644 --- a/ice_station_zebra/data_loaders/combined_dataset.py +++ b/ice_station_zebra/data_loaders/combined_dataset.py @@ -117,3 +117,20 @@ def start_date(self) -> np.datetime64: msg = f"Datasets have {len(start_date)} different start dates" raise ValueError(msg) return start_date.pop() + + @property + def input_variable_names(self) -> list[str]: + """Return all input variable names across all input datasets. + + Variable names are prefixed with the dataset name for disambiguation. + Format: "{dataset_name}:{variable_name}" + + Returns: + List of variable names in the order they appear in the combined input channels. 
+ + """ + return [ + f"{ds.name}:{var_name}" + for ds in self.inputs + for var_name in ds.variable_names + ] diff --git a/ice_station_zebra/data_loaders/zebra_dataset.py b/ice_station_zebra/data_loaders/zebra_dataset.py index d6a682c0..46ae5b4a 100644 --- a/ice_station_zebra/data_loaders/zebra_dataset.py +++ b/ice_station_zebra/data_loaders/zebra_dataset.py @@ -94,6 +94,20 @@ def start_date(self) -> np.datetime64: """Return the start date of the dataset.""" return self.dates[0] + @cached_property + def variable_names(self) -> list[str]: + """Return the variable names for this dataset. + + The variable names are extracted from the underlying Anemoi dataset. + All datasets must have the same variables. + """ + # Check all datasets have the same variables + per_ds_variables = [tuple(ds.variables) for ds in self.datasets] + if len(set(per_ds_variables)) != 1: + msg = f"All datasets must have the same variables, found {len(set(per_ds_variables))} different sets" + raise ValueError(msg) + return list(self.datasets[0].variables) + def __len__(self) -> int: """Return the total length of the dataset.""" if self._len is None: diff --git a/ice_station_zebra/evaluation/evaluator.py b/ice_station_zebra/evaluation/evaluator.py index 11489388..5abc93d6 100644 --- a/ice_station_zebra/evaluation/evaluator.py +++ b/ice_station_zebra/evaluation/evaluator.py @@ -2,6 +2,11 @@ from pathlib import Path from typing import TYPE_CHECKING +# Set matplotlib backend BEFORE any imports that might use it +import matplotlib as mpl + +mpl.use("Agg") + import hydra from lightning.fabric.utilities import suggested_max_num_workers from omegaconf import DictConfig, OmegaConf diff --git a/ice_station_zebra/types/simple_datatypes.py b/ice_station_zebra/types/simple_datatypes.py index 4eba0fc4..5bdd77c7 100644 --- a/ice_station_zebra/types/simple_datatypes.py +++ b/ice_station_zebra/types/simple_datatypes.py @@ -49,7 +49,7 @@ class DiffColourmapSpec(NamedTuple): norm: Normalisation for mapping values to colours (e.g. TwoSlopeNorm for signed diffs). vmin: Lower bound if no norm is provided. vmax: Upper bound if no norm is provided. - cmap: Matplotlib colormap name. + cmap: Matplotlib colourmap name. """ @@ -69,7 +69,7 @@ class PlotSpec: title_prediction: Title above the prediction panel. title_difference: Title above the difference panel. n_contour_levels: Number of contour levels per panel. - colourmap: Colormap used for GT/prediction panels. + colourmap: colourmap used for GT/prediction panels. include_difference: Whether to draw a difference panel. diff_mode: Difference definition (e.g. "signed", "absolute", "smape"). diff_strategy: Strategy for animations (precompute, two-pass, per-frame). 
diff --git a/ice_station_zebra/visualisations/__init__.py b/ice_station_zebra/visualisations/__init__.py index e60635f0..ea88bf9a 100644 --- a/ice_station_zebra/visualisations/__init__.py +++ b/ice_station_zebra/visualisations/__init__.py @@ -1,3 +1,9 @@ +# Set non-interactive matplotlib backend early to prevent hangs on macOS +# make sure the import is not reordered by the formatter +import matplotlib as mpl + +mpl.use("Agg") + from ice_station_zebra.types import PlotSpec from .plotting_core import detect_land_mask_path diff --git a/ice_station_zebra/visualisations/animation_helper.py b/ice_station_zebra/visualisations/animation_helper.py new file mode 100644 index 00000000..2a84a423 --- /dev/null +++ b/ice_station_zebra/visualisations/animation_helper.py @@ -0,0 +1,40 @@ +"""Animation helper module for managing matplotlib animation lifecycle. + +Provides centralised cache management to prevent garbage collection of animation +objects during save operations, avoiding warnings and ensuring proper cleanup. +""" + +import contextlib + +from matplotlib import animation + +# Centralized cache to hold strong references to animation objects +_ANIM_CACHE: list[animation.FuncAnimation] = [] + + +def hold_anim(anim: animation.FuncAnimation) -> None: + """Add an animation object to the cache to prevent garbage collection. + + Call this immediately after creating a FuncAnimation object to ensure + it remains in memory during save operations. + + Args: + anim: The FuncAnimation object to cache. + + """ + _ANIM_CACHE.append(anim) + + +def release_anim(anim: animation.FuncAnimation) -> None: + """Remove an animation object from the cache after saving is complete. + + Call this in a finally block after saving the animation to allow + proper cleanup and garbage collection. + + Args: + anim: The FuncAnimation object to remove from cache. + + """ + with contextlib.suppress(ValueError): + # Animation already removed or never added - ignore + _ANIM_CACHE.remove(anim) diff --git a/ice_station_zebra/visualisations/cli.py b/ice_station_zebra/visualisations/cli.py new file mode 100644 index 00000000..c2c3548f --- /dev/null +++ b/ice_station_zebra/visualisations/cli.py @@ -0,0 +1,355 @@ +"""CLI commands for visualisation tasks.""" + +import logging +from pathlib import Path +from typing import Annotated, Any, cast + +import hydra +import numpy as np +import typer +from omegaconf import DictConfig, OmegaConf + +from ice_station_zebra.callbacks.raw_inputs_callback import RawInputsCallback +from ice_station_zebra.cli import hydra_adaptor +from ice_station_zebra.data_loaders import CombinedDataset, ZebraDataModule +from ice_station_zebra.visualisations.plotting_raw_inputs import ( + plot_raw_inputs_for_timestep, + video_raw_inputs_for_timesteps, +) + +# Create the typer app +visualisations_cli = typer.Typer(help="Visualisation commands") + +log = logging.getLogger(__name__) + + +def _extract_channel_arrays( + batch: dict[str, Any], test_dataset: CombinedDataset, timestep_idx: int = 0 +) -> list[np.ndarray]: + """Extract channel arrays from a dataset item. + + Args: + batch: Dataset item (returns 4D arrays [T, C, H, W]). + test_dataset: Test dataset instance. + timestep_idx: Index of timestep to extract (default: 0). + + Returns: + List of 2D channel arrays. 
+ + """ + channel_arrays = [] + for ds in test_dataset.inputs: + if ds.name not in batch: + log.warning("Dataset %s not found in batch", ds.name) + continue + + input_data = batch[ds.name] # Shape: [T, C, H, W] - numpy array + timestep_data = input_data[timestep_idx] # Shape: [C, H, W] + channel_arrays.extend([timestep_data[c] for c in range(timestep_data.shape[0])]) + + return channel_arrays + + +def _collect_temporal_data( + test_dataset: CombinedDataset, start_idx: int, n_frames: int +) -> tuple[list[np.ndarray], list[str], list[Any]]: + """Collect temporal data for animation. + + Args: + test_dataset: Test dataset instance. + start_idx: Starting sample index. + n_frames: Number of frames to collect. + + Returns: + Tuple of (data_streams, channel_names, dates). + + """ + channel_names = test_dataset.input_variable_names + n_vars = len(channel_names) + + # Initialize data collection + temporal_data_per_var: list[list[np.ndarray]] = [[] for _ in range(n_vars)] + dates: list[Any] = [] + + for idx in range(start_idx, start_idx + n_frames): + batch = test_dataset[idx] + date = test_dataset.date_from_index(idx) + + # Collect data from all input datasets + var_idx = 0 + for ds in test_dataset.inputs: + if ds.name not in batch: + log.warning("Dataset %s not found in batch", ds.name) + continue + + input_data = batch[ds.name] # Shape: [T, C, H, W] + timestep_data = input_data[0] # Take first timestep, Shape: [C, H, W] + + # Add each channel + for c in range(timestep_data.shape[0]): + var_data = timestep_data[c] + temporal_data_per_var[var_idx].append(var_data) + var_idx += 1 + + dates.append(date) + + # Stack into list of 3D arrays [T, H, W] for each variable + data_streams = [ + np.stack(var_frames, axis=0) for var_frames in temporal_data_per_var + ] + + return data_streams, channel_names, dates + + +@visualisations_cli.command() +@hydra_adaptor +def plot_raw_inputs( + config: DictConfig, + sample_idx: Annotated[ + int, + typer.Option(help="Index of the sample to plot (default: 0)"), + ] = 0, + output_dir: Annotated[ + str | None, + typer.Option(help="Directory to save plots (overrides config if provided)"), + ] = None, +) -> None: + r"""Plot raw inputs for a single timestep from the test dataset. + + This command creates static plots of all input variables for a single sample + from the test dataset. Settings are read from config.evaluate.callbacks.raw_inputs + in your YAML config files. + + Args: + config: Hydra config (provided via --config-name option). + sample_idx: Index of the sample to plot (default: 0). + output_dir: Directory to save plots (overrides config if provided). + + Note: + You must specify --config-name to use your local config file (e.g., lfrance.local.yaml). + Settings are read from config.evaluate.callbacks.raw_inputs in that config. 
+ + Example: + uv run zebra visualisations plot-raw-inputs \\ + --config-name lfrance.local.yaml \\ + --sample-idx 0 + + """ + # Instantiate callback from config to get all settings + raw_inputs_cfg = config.get("evaluate", {}).get("callbacks", {}).get("raw_inputs") + if raw_inputs_cfg is None: + # Create minimal config if not found + raw_inputs_cfg = {} + callback = ( + hydra.utils.instantiate(raw_inputs_cfg) + if raw_inputs_cfg + else RawInputsCallback() + ) + callback.config = OmegaConf.to_container(config, resolve=True) + + # Create data module and prepare data + data_module = ZebraDataModule(config) + data_module.prepare_data() + data_module.setup("test") + + # Get the test dataset + test_dataloader = data_module.test_dataloader() + test_dataset_raw = test_dataloader.dataset + if test_dataset_raw is None: + msg = "No test dataset available!" + raise ValueError(msg) + test_dataset = cast("CombinedDataset", test_dataset_raw) + + log.info("Test dataset has %d samples", len(test_dataset)) + + # Validate sample index + if sample_idx < 0 or sample_idx >= len(test_dataset): + msg = f"Sample index {sample_idx} out of range [0, {len(test_dataset)})" + raise ValueError(msg) + + # Get a sample from the dataset + batch = test_dataset[sample_idx] + date = test_dataset.date_from_index(sample_idx) + + log.info("Plotting raw inputs for date: %s", date) + + # Extract channel arrays + channel_arrays = _extract_channel_arrays(batch, test_dataset) + channel_names = test_dataset.input_variable_names + log.info("Total channels: %d", len(channel_names)) + + # Use callback's settings (with command-line override for output_dir) + plot_spec = callback.plot_spec + variable_styles = callback.variable_styles + save_dir = ( + Path(output_dir) + if output_dir + else callback.save_dir or Path("./raw_input_plots") + ) + + # Plot the raw inputs + results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=date, + plot_spec_base=plot_spec, + land_mask=None, # Will be auto-loaded if available + styles=variable_styles, + save_dir=save_dir, + ) + + log.info("Successfully plotted %d variables to %s", len(results), save_dir) + for var_name, _pil_img, saved_path in results: + if saved_path: + log.info(" - %s: %s", var_name, saved_path) + + +@visualisations_cli.command() +@hydra_adaptor +def animate_raw_inputs( + config: DictConfig, + n_frames: Annotated[ + int | None, + typer.Option( + help="Number of frames to include in animation (overrides config if provided)" + ), + ] = None, + output_dir: Annotated[ + str | None, + typer.Option( + help="Directory to save animations (overrides config if provided)" + ), + ] = None, + fps: Annotated[ + int | None, + typer.Option( + help="Frames per second for animation (overrides config if provided)" + ), + ] = None, + video_format: Annotated[ + str | None, + typer.Option( + help="Video format: 'gif' or 'mp4' (overrides config if provided)" + ), + ] = None, + start_idx: Annotated[ + int, + typer.Option(help="Starting sample index (default: 0)"), + ] = 0, +) -> None: + r"""Create animations of raw inputs over time from the test dataset. + + This command creates temporal animations showing how individual input variables + evolve over time. Settings are read from config.evaluate.callbacks.raw_inputs + in your YAML config files. + + Args: + config: Hydra config (provided via --config-name option). + n_frames: Number of frames to include in animation (overrides config if provided). 
+ output_dir: Directory to save animations (overrides config if provided). + fps: Frames per second for animation (overrides config if provided). + video_format: Video format: 'gif' or 'mp4' (overrides config if provided). + start_idx: Starting sample index (default: 0). + + Note: + You must specify --config-name to use your local config file (e.g., lfrance.local.yaml). + Settings are read from config.evaluate.callbacks.raw_inputs in that config. + + Example: + uv run zebra visualisations animate-raw-inputs \\ + --config-name lfrance.local.yaml \\ + --start-idx 0 + + """ + # Instantiate callback from config to get all settings + raw_inputs_cfg = config.get("evaluate", {}).get("callbacks", {}).get("raw_inputs") + if raw_inputs_cfg is None: + raw_inputs_cfg = {} + callback = ( + hydra.utils.instantiate(raw_inputs_cfg) + if raw_inputs_cfg + else RawInputsCallback() + ) + callback.config = OmegaConf.to_container(config, resolve=True) + + # Get settings from callback (with command-line overrides) + n_frames = n_frames or callback.max_animation_frames or 30 + fps = fps or callback.video_fps + video_format = video_format or callback.video_format + save_dir = ( + Path(output_dir) + if output_dir + else callback.video_save_dir or Path("./raw_input_animations") + ) + + # Validate video format + if video_format not in ("gif", "mp4"): + msg = f"Video format must be 'gif' or 'mp4', got '{video_format}'" + raise ValueError(msg) + + # Create data module and prepare data + data_module = ZebraDataModule(config) + data_module.prepare_data() + data_module.setup("test") + + # Get the test dataset + test_dataloader = data_module.test_dataloader() + test_dataset_raw = test_dataloader.dataset + if test_dataset_raw is None: + msg = "No test dataset available!" + raise ValueError(msg) + test_dataset = cast("CombinedDataset", test_dataset_raw) + + log.info("Test dataset has %d samples", len(test_dataset)) + + # Determine number of frames + max_frames = len(test_dataset) - start_idx + n_frames = min(n_frames, max_frames) + if n_frames <= 0: + msg = f"Not enough samples (start_idx={start_idx}, dataset_size={len(test_dataset)})" + raise ValueError(msg) + + log.info( + "Creating animations with %d frames starting from index %d", n_frames, start_idx + ) + + # Collect temporal data for all variables + data_streams, channel_names, dates = _collect_temporal_data( + test_dataset, start_idx, n_frames + ) + + log.info( + "Collected data streams: %d variables x %s", + len(data_streams), + data_streams[0].shape if data_streams else "N/A", + ) + + # Use callback's settings + plot_spec = callback.plot_spec + variable_styles = callback.variable_styles + + # Create animations for all variables + log.info("Creating batch animations...") + results = video_raw_inputs_for_timesteps( + channel_arrays_stream=data_streams, + channel_names=channel_names, + dates=dates, + plot_spec_base=plot_spec, + styles=variable_styles, + fps=fps, + video_format=video_format, # type: ignore[arg-type] + save_dir=save_dir, + ) + + log.info("Successfully created %d animations:", len(results)) + for var_name, video_buffer, save_path in results: + log.info( + " - %s: %s (size: %d bytes)", + var_name, + save_path, + len(video_buffer.getvalue()), + ) + + +if __name__ == "__main__": + visualisations_cli() diff --git a/ice_station_zebra/visualisations/convert.py b/ice_station_zebra/visualisations/convert.py index 9b093276..a0b69504 100644 --- a/ice_station_zebra/visualisations/convert.py +++ b/ice_station_zebra/visualisations/convert.py @@ -1,5 +1,9 @@ +import 
contextlib +import gc import io +import logging import tempfile +from collections.abc import Iterator from pathlib import Path from typing import Literal @@ -15,65 +19,90 @@ DEFAULT_DPI = 200 -def _image_from_figure(fig: Figure) -> ImageFile: - """Convert a matplotlib figure to a PIL image file.""" +@contextlib.contextmanager +def _suppress_mpl_animation_logs() -> Iterator[None]: + """Temporarily suppress matplotlib animation INFO log messages.""" + mpl_logger = logging.getLogger("matplotlib.animation") + original_level = mpl_logger.level + try: + mpl_logger.setLevel(logging.WARNING) + yield + finally: + mpl_logger.setLevel(original_level) + + +def image_from_figure(fig: Figure) -> ImageFile: + """Convert a matplotlib figure to a PIL image file. + + Uses the same save parameters as save_figure for consistency: + - dpi=300 for matching resolution + - bbox_inches="tight" to crop to content (matching disk saves) + """ buf = io.BytesIO() - fig.savefig(buf, format="png") + fig.savefig(buf, format="png", dpi=300, bbox_inches="tight") buf.seek(0) return Image.open(buf) -def _image_from_array(a: np.ndarray) -> ImageFile: +def image_from_array(a: np.ndarray) -> ImageFile: """Convert a numpy array to a PIL image file.""" fig, ax = plt.subplots(figsize=(6, 6)) ax.imshow(a) ax.axis("off") try: - return _image_from_figure(fig) + return image_from_figure(fig) finally: plt.close(fig) -def _save_animation( +def save_animation( anim: animation.FuncAnimation, *, - fps: int | None = None, + fps: int = 2, video_format: Literal["mp4", "gif"] = "gif", - _fps: int | None = None, - _video_format: Literal["mp4", "gif"] | None = None, ) -> io.BytesIO: """Save an animation to a temporary file and return BytesIO (with cleanup).""" - # Accept both standard and underscored names for test compatibility - fps_value: int = int(fps if fps is not None else (_fps if _fps is not None else 2)) - if _video_format is not None: - video_format = _video_format # prefer underscored override if provided + fps_value: int = int(fps) suffix = ".gif" if video_format.lower() == "gif" else ".mp4" - with tempfile.NamedTemporaryFile(suffix=suffix, delete=True) as tmp: - try: - # Save video to tempfile - writer: animation.AbstractMovieWriter = ( - animation.PillowWriter(fps=fps_value) - if suffix == ".gif" - else animation.FFMpegWriter( - fps=fps_value, - codec="libx264", - bitrate=1800, - # Ensure dimensions are compatible with yuv420p (even width/height) - # by applying a scale filter that truncates to the nearest even integers. - extra_args=[ - "-pix_fmt", - "yuv420p", - "-vf", - "scale=trunc(iw/2)*2:trunc(ih/2)*2", - ], + + writer: animation.AbstractMovieWriter | None = None + try: + with tempfile.NamedTemporaryFile(suffix=suffix, delete=True) as tmp: + try: + # Save video to tempfile + writer = ( + animation.PillowWriter(fps=fps_value) + if suffix == ".gif" + else animation.FFMpegWriter( + fps=fps_value, + codec="libx264", + bitrate=1800, + # Ensure dimensions are compatible with yuv420p (even width/height) + # by applying a scale filter that truncates to the nearest even integers. 
+ extra_args=[ + "-pix_fmt", + "yuv420p", + "-vf", + "scale=trunc(iw/2)*2:trunc(ih/2)*2", + ], + ) ) - ) - anim.save(tmp.name, writer=writer, dpi=DEFAULT_DPI) - # Load tempfile into a BytesIO buffer - with Path(tmp.name).open("rb") as fh: - buffer = io.BytesIO(fh.read()) - except (OSError, MemoryError) as err: - msg = f"Video encoding failed: {err!s}" - raise VideoRenderError(msg) from err + # Suppress matplotlib's INFO log message about writer selection + with _suppress_mpl_animation_logs(): + anim.save(tmp.name, writer=writer, dpi=DEFAULT_DPI) + # Load tempfile into a BytesIO buffer + with Path(tmp.name).open("rb") as fh: + buffer = io.BytesIO(fh.read()) + except (OSError, MemoryError) as err: + msg = f"Video encoding failed: {err!s}" + raise VideoRenderError(msg) from err + finally: + # Explicitly cleanup writer to prevent semaphore leaks + if writer is not None and hasattr(writer, "cleanup"): + with contextlib.suppress(OSError, RuntimeError, AttributeError): + writer.cleanup() + # Force garbage collection to clean up any remaining resources + gc.collect() + buffer.seek(0) return buffer diff --git a/ice_station_zebra/visualisations/layout.py b/ice_station_zebra/visualisations/layout.py index cde46fe0..4c383565 100644 --- a/ice_station_zebra/visualisations/layout.py +++ b/ice_station_zebra/visualisations/layout.py @@ -15,7 +15,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import contextlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal import matplotlib.pyplot as plt import numpy as np @@ -33,59 +35,317 @@ from .plotting_core import DiffColourmapSpec, PlotSpec -# Layout constants +import logging + +logger = logging.getLogger(__name__) + +# Panel index constants (used for layout logic) PREDICTION_PANEL_INDEX = 1 DIFFERENCE_PANEL_INDEX = 2 MIN_PANEL_COUNT_FOR_PREDICTION = 2 MIN_PANEL_COUNT_FOR_DIFFERENCE = 3 - -# -# --- LAYOUT CONFIGURATION CONSTANTS --- -# These constants define the default proportional spacing for the GridSpec layout. -# All values are expressed as fractions of the total figure dimensions to ensure -# consistent scaling across different figure sizes. 
- -# Spacing and margin constants (as fractions of figure dimensions) -DEFAULT_OUTER_MARGIN = 0.05 # Outer margin around entire figure (prevents clipping) -DEFAULT_GUTTER_HORIZONTAL = 0.03 # Smaller gaps when colourbar is below -DEFAULT_GUTTER_VERTICAL = 0.03 # Default for side-by-side with vertical bars -DEFAULT_CBAR_WIDTH = ( - 0.06 # Width allocated for colourbar slots (fraction of panel width) -) -DEFAULT_TITLE_SPACE = 0.07 # Reduced: simple title + warning badge + axes titles -DEFAULT_FOOTER_SPACE = 0.08 # Increased space reserved at bottom for metadata footer - -# Horizontal colourbar sizing (fractions of figure height) -DEFAULT_CBAR_HEIGHT = ( - 0.07 # Thickness of horizontal colourbar row (relative to plot height) -) -DEFAULT_CBAR_PAD = 0.03 # Vertical gap between plots and the colourbar row - -# Horizontal colourbar width within its slot (inset fraction within parent slot) +# Horizontal colourbar inset constants HCBAR_WIDTH_FRAC = 0.75 # centred 75% width HCBAR_LEFT_FRAC = (1.0 - HCBAR_WIDTH_FRAC) / 2.0 -# Default figure sizes for different panel configurations -DEFAULT_FIGSIZE_TWO_PANELS = (12, 6) # Ground truth + Prediction -DEFAULT_FIGSIZE_THREE_PANELS = (18, 6) # Ground truth + Prediction + Difference +# Small epsilon for numerical stability +EPSILON_SMALL = 1e-6 + +# --- Layout Configuration Dataclasses --- + + +@dataclass(frozen=True) +class ColourbarConfig: + """Colourbar sizing and clamp behaviour (fractions of figure). + + Example: + # Wider colourbar with higher max fraction + cbar_config = ColourbarConfig(default_width_frac=0.08, max_fraction_of_fig=0.15) + + """ + + default_width_frac: float = ( + 0.06 # Width allocated for colourbar slots (fraction of panel width) + ) + min_width_frac: float = 0.01 # Minimum colourbar width clamp + max_width_frac: float = 0.5 # Maximum colourbar width clamp + desired_physical_width_in: float = ( + 0.45 # Desired physical colourbar width in inches + ) + max_fraction_of_fig: float = 0.12 # Maximum colourbar width as fraction of figure + default_height_frac: float = 0.07 # Thickness of horizontal colourbar row + default_pad_frac: float = 0.03 # Vertical gap between plots and colourbar row + + def clamp_frac(self, frac: float, fig_width_in: float) -> float: + """Return a sane clamped fraction for the cbar slot based on figure width. + + Computes a physical-limit fraction derived from desired physical width, + then clamps the input fraction between min/max bounds. 
+ """ + # Compute a physical-limit fraction derived from desired physical width + phys_frac = min( + self.max_fraction_of_fig, + self.desired_physical_width_in / max(fig_width_in, EPSILON_SMALL), + ) + return float( + max(self.min_width_frac, min(frac, phys_frac, self.max_width_frac)) + ) + + +@dataclass(frozen=True) +class TitleFooterConfig: + """Title and footer spacing, positioning, and styling.""" + + title_space: float = 0.07 # Fraction of figure height reserved for title + footer_space: float = 0.08 # Fraction of figure height reserved for footer + title_fontsize: int = 12 # Font size for title + footer_fontsize: int = 11 # Font size for footer and badge + title_y: float = 0.98 # Y position for title (near top, in figure coordinates) + footer_y: float = 0.03 # Y position for footer (near bottom, in figure coordinates) + bbox_pad_title: float = 2.0 # Padding for title bbox + bbox_pad_badge: float = 1.5 # Padding for badge bbox + zorder_high: int = 1000 # High z-order for text overlays + + +@dataclass(frozen=True) +class GapConfig: + """Aspect-aware physical gap settings for single-panel layouts.""" + + base: float = 0.15 # Preferred gap for aspect≈1 panels + min_val: float = 0.10 # Hard minimum + max_val: float = 0.22 # Hard maximum + wide_limit: float = 4.0 # Aspect ratio at which we reach the "wide" gap + tall_limit: float = 0.5 # Aspect ratio at which we reach the "tall" gap + wide_gap: float = 0.11 # Gap to use for very wide panels + tall_gap: float = 0.19 # Gap to use for very tall panels + + +@dataclass(frozen=True) +class SinglePanelSpacing: + """Encapsulate spacing controls for standalone panels.""" + + gap: GapConfig = field(default_factory=GapConfig) + outer_buffer_in: float = 0.16 # Padding outside map/cbar (both sides) + edge_guard_in: float = 0.25 # Minimum blank space beyond colourbar for ticks + right_margin_scale: float = 0.5 # Scale factor for right margin adjustment + right_margin_offset: float = 0.02 # Offset for right margin adjustment + + +@dataclass(frozen=True) +class FormattingConfig: + """Tick formatting and fallback value constants.""" + + num_ticks_linear: int = 5 # Number of ticks for linear colourbars + midpoint_factor: float = ( + 0.5 # Factor for calculating midpoint values in symmetric ticks + ) + default_vmin_fallback: float = 0.0 # Fallback minimum value + default_vmax_fallback: float = 1.0 # Fallback maximum value + default_vmin_diff_fallback: float = -1.0 # Fallback minimum for difference plots + default_vmax_diff_fallback: float = 1.0 # Fallback maximum for difference plots + + +@dataclass(frozen=True) +class LayoutConfig: + """Top-level layout tuning surface - passed into build functions. + + Groups all layout-related configuration into a single, discoverable object. + This makes it easy to override defaults for testing or custom layouts. 
+ + Example: + # Wider colourbar and less title space + my_layout = LayoutConfig( + colourbar=ColourbarConfig(default_width_frac=0.08), + title_footer=TitleFooterConfig(title_space=0.04) + ) + fig, axs, cax = build_single_panel_figure( + height=200, width=300, layout_config=my_layout, colourbar_location="vertical" + ) + + """ + + base_height_in: float = 6.0 # Standard figure height in inches + outer_margin: float = 0.05 # Outer margin around entire figure (prevents clipping) + gutter_vertical: float = 0.03 # Default for side-by-side with vertical bars + gutter_horizontal: float = 0.03 # Smaller gaps when colourbar is below + colourbar: ColourbarConfig = field(default_factory=ColourbarConfig) + title_footer: TitleFooterConfig = field(default_factory=TitleFooterConfig) + single_panel_spacing: SinglePanelSpacing = field(default_factory=SinglePanelSpacing) + formatting: FormattingConfig = field(default_factory=FormattingConfig) + min_plot_fraction: float = ( + 0.2 # Minimum plot width/height as fraction of available space + ) + min_usable_height_fraction: float = ( + 0.6 # Minimum fraction of figure height for plotting area + ) + default_figsizes: dict[int, tuple[float, float]] = field( + default_factory=lambda: { + 1: (8, 6), # Single panel + 2: (12, 6), # Ground truth + Prediction + 3: (18, 6), # Ground truth + Prediction + Difference + } + ) + +# Default layout configuration instance (used when layout_config is None) +_DEFAULT_LAYOUT_CONFIG = LayoutConfig() # --- Main Layout Functions --- -def _build_layout( # noqa: PLR0913 +def build_single_panel_figure( # noqa: PLR0913, PLR0915 + *, + height: int, + width: int, + layout_config: LayoutConfig | None = None, + colourbar_location: Literal["vertical", "horizontal"] = "vertical", + cbar_width: float | None = None, + cbar_height: float | None = None, + cbar_pad: float | None = None, +) -> tuple[Figure, Axes, Axes]: + """Create a single-panel figure with layout consistent with multi-panel maps. + + Args: + height: Data height in pixels. + width: Data width in pixels. + layout_config: Optional LayoutConfig instance to override defaults. + If None, uses default configuration. + colourbar_location: Orientation of colourbar ("vertical" or "horizontal"). + cbar_width: Optional override for colourbar width fraction. + cbar_height: Optional override for colourbar height fraction (horizontal only). + cbar_pad: Optional override for colourbar padding. + + """ + if height <= 0 or width <= 0: + msg = "height and width must be positive for single panel layout." 
+ raise ValueError(msg) + + layout = layout_config or _DEFAULT_LAYOUT_CONFIG + outer_margin = layout.outer_margin + title_space = layout.title_footer.title_space + footer_space = layout.title_footer.footer_space + + base_h = layout.base_height_in + aspect = width / max(1, height) + fig_w = base_h * aspect + fig = plt.figure( + figsize=(fig_w, base_h), constrained_layout=False, facecolor="none" + ) + + top_val = max(layout.min_usable_height_fraction, 1.0 - (outer_margin + title_space)) + bottom_val = outer_margin + footer_space + usable_height = top_val - bottom_val + + if colourbar_location == "vertical": + # --- Interpret cbar_width as "fraction of panel width" --- + cw_panel = float( + layout.colourbar.default_width_frac if cbar_width is None else cbar_width + ) + fig_w_in = float(fig.get_size_inches()[0]) + cw_panel = layout.colourbar.clamp_frac(cw_panel, fig_w_in) + + spacing = layout.single_panel_spacing + + extra_left_frac = spacing.outer_buffer_in / max(fig_w_in, EPSILON_SMALL) + extra_right_frac = spacing.outer_buffer_in / max(fig_w_in, EPSILON_SMALL) + + if cbar_pad is None: + cbar_pad_inches = _default_vertical_gap_inches(aspect, spacing.gap) + logger.debug( + "single-panel vertical gap auto: w=%d h=%d aspect=%.3f gap=%.4fin", + width, + height, + aspect, + cbar_pad_inches, + ) + else: + base_pad_inches = float(cbar_pad) * fig_w_in + cbar_pad_inches = np.clip( + base_pad_inches, spacing.gap.min_val, spacing.gap.max_val + ) + logger.debug( + "single-panel vertical gap override: w=%d h=%d aspect=%.3f raw=%.4fin clipped=%.4fin", + width, + height, + aspect, + base_pad_inches, + cbar_pad_inches, + ) + cbar_pad = cbar_pad_inches / max(fig_w_in, EPSILON_SMALL) + + # --- Right margin (slightly smaller than left) with physical safeguard --- + right_margin = max( + outer_margin * spacing.right_margin_scale, + outer_margin - spacing.right_margin_offset, + ) + right_margin_inches = right_margin * fig_w_in + if right_margin_inches < spacing.edge_guard_in: + right_margin = spacing.edge_guard_in / max(fig_w_in, EPSILON_SMALL) + + # === 2) Compute plot width from panel-fraction rule === + available = 1.0 - ( + outer_margin + extra_left_frac + cbar_pad + right_margin + extra_right_frac + ) + plot_width_candidate = max( + layout.min_plot_fraction, available / (1.0 + cw_panel) + ) + aspect_width_target = usable_height + plot_width = min(plot_width_candidate, aspect_width_target) + + # initial colourbar width (fraction of figure) + cax_width = cw_panel * plot_width + + # === 3) CAP COLOURBAR PHYSICAL WIDTH === + # Compute maximum allowed fraction based on desired physical width + phys_frac = layout.colourbar.desired_physical_width_in / max( + fig_w_in, EPSILON_SMALL + ) + max_cax_frac = min(layout.colourbar.max_fraction_of_fig, phys_frac) + + if cax_width > max_cax_frac: + cax_width = max_cax_frac + plot_width = min( + aspect_width_target, + max( + layout.min_plot_fraction, + 1.0 - (outer_margin + cbar_pad + cax_width + right_margin), + ), + ) + + # === 4) Place axes === + plot_left = outer_margin + extra_left_frac + ax = fig.add_axes((plot_left, bottom_val, plot_width, usable_height)) + + cax_left = plot_left + plot_width + cbar_pad + cax = fig.add_axes((cax_left, bottom_val, cax_width, usable_height)) + + else: + cbar_height = ( + layout.colourbar.default_height_frac if cbar_height is None else cbar_height + ) + cbar_pad = layout.colourbar.default_pad_frac if cbar_pad is None else cbar_pad + + plot_left = outer_margin + plot_width = 1.0 - 2 * outer_margin + plot_height = usable_height - 
(cbar_height + cbar_pad) + plot_height = max(plot_height, layout.min_plot_fraction) + + ax_bottom = bottom_val + cbar_height + cbar_pad + ax = fig.add_axes((plot_left, ax_bottom, plot_width, plot_height)) + cax = fig.add_axes((plot_left, bottom_val, plot_width, cbar_height)) + + _style_axes([ax]) + _set_axes_limits([ax], width=width, height=height) + return fig, ax, cax + + +def build_layout( *, plot_spec: PlotSpec, height: int | None = None, width: int | None = None, - outer_margin: float = DEFAULT_OUTER_MARGIN, - gutter: float | None = None, - cbar_width: float = DEFAULT_CBAR_WIDTH, - title_space: float = DEFAULT_TITLE_SPACE, - footer_space: float = DEFAULT_FOOTER_SPACE, - cbar_height: float = DEFAULT_CBAR_HEIGHT, - cbar_pad: float = DEFAULT_CBAR_PAD, + layout_config: LayoutConfig | None = None, ) -> tuple[Figure, list[Axes], dict[str, Axes | None]]: """Create a GridSpec layout for multi-panel plots. @@ -106,19 +366,8 @@ def _build_layout( # noqa: PLR0913 height: Optional data height for aspect-ratio-aware figure sizing. If provided with width, the figure dimensions will be calculated to maintain proper data aspect ratios. width: Optional data width for aspect-ratio-aware figure sizing. - outer_margin: Fraction of figure dimensions reserved for outer margins (prevents - plot elements from being clipped at figure edges). - gutter: Fraction of panel width used as horizontal spacing between panel groups. - Only applied between prediction+colourbar and difference panel. - cbar_width: Fraction of panel width allocated for each colourbar slot. - title_space: Fraction of figure height reserved at the top for figure title - (prevents title from overlapping with plot content). - footer_space: Fraction of figure height reserved at the bottom for the metadata - footer so it does not overlap colourbars. - cbar_height: Height fraction for the horizontal colourbar row (row 2 when - orientation is 'horizontal'). Controls the bar thickness. - cbar_pad: Vertical gap fraction between the plot row and the colourbar row in - horizontal layouts. Set to 0.0 for flush rows. + layout_config: Optional LayoutConfig instance to override default layout parameters. + If None, uses default configuration. Returns: tuple containing: @@ -132,28 +381,35 @@ def _build_layout( # noqa: PLR0913 InvalidArrayError: If the arrays are not 2D or have different shapes. """ + layout = layout_config or _DEFAULT_LAYOUT_CONFIG + outer_margin = layout.outer_margin + title_space = layout.title_footer.title_space + footer_space = layout.title_footer.footer_space + cbar_width = layout.colourbar.default_width_frac + cbar_height = layout.colourbar.default_height_frac + cbar_pad = layout.colourbar.default_pad_frac + # Decide how many main panels are required and which orientation the colourbars use n_panels = 3 if plot_spec.include_difference else 2 orientation = plot_spec.colourbar_location - # Choose gutter default per orientation if not provided - if gutter is None: - gutter = ( - DEFAULT_GUTTER_HORIZONTAL - if orientation == "horizontal" - else DEFAULT_GUTTER_VERTICAL - ) + # Choose gutter default per orientation + gutter = ( + layout.gutter_horizontal + if orientation == "horizontal" + else layout.gutter_vertical + ) # Calculate top boundary: ensure title space does not consume too much of the figure. # At least 60% of the figure height is reserved for the plotting area. 
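To make the new signature concrete, here is a hedged usage sketch of `build_layout` with a `LayoutConfig` override; the 432x432 dimensions and the choice of `DEFAULT_SIC_SPEC` are example values, and the same pattern appears later in this diff in `plot_maps` and `video_maps`:

```python
from ice_station_zebra.visualisations.layout import (
    LayoutConfig,
    TitleFooterConfig,
    build_layout,
)
from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC

# Instead of passing footer_space=0.11 directly (the old _build_layout style),
# bundle the override into a LayoutConfig.
layout = LayoutConfig(title_footer=TitleFooterConfig(footer_space=0.11))

fig, axs, cbar_axes = build_layout(
    plot_spec=DEFAULT_SIC_SPEC,  # any PlotSpec works here
    height=432,                  # example data height in pixels
    width=432,                   # example data width in pixels
    layout_config=layout,        # None would fall back to _DEFAULT_LAYOUT_CONFIG
)
```

Bundling the overrides keeps call sites to a single `layout_config` argument rather than a growing keyword list.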
- top_val = max(0.6, 1.0 - (outer_margin + title_space)) + top_val = max(layout.min_usable_height_fraction, 1.0 - (outer_margin + title_space)) # Calculate bottom boundary, reserving footer space for metadata bottom_val = outer_margin + footer_space # Calculate figure size based on data aspect ratio or use defaults if height and width and height > 0: # Calculate panel width maintaining data aspect ratio - base_h = 6.0 # Standard height in inches + base_h = layout.base_height_in # Standard height in inches aspect = width / height if orientation == "vertical": @@ -177,11 +433,7 @@ def _build_layout( # noqa: PLR0913 fig_size = (fig_w, base_h) else: # Use predefined sizes when data dimensions are unknown - fig_size = ( - DEFAULT_FIGSIZE_THREE_PANELS - if plot_spec.include_difference - else DEFAULT_FIGSIZE_TWO_PANELS - ) + fig_size = layout.default_figsizes[n_panels] fig = plt.figure(figsize=fig_size, constrained_layout=False, facecolor="none") @@ -477,7 +729,7 @@ def _set_titles(axs: list[Axes], plot_spec: PlotSpec) -> None: bbox={ "facecolor": "white", "edgecolor": "none", - "pad": 2.0, + "pad": _DEFAULT_LAYOUT_CONFIG.title_footer.bbox_pad_title, "alpha": 1.0, }, ) @@ -525,7 +777,7 @@ def _set_axes_limits(axs: list[Axes], *, width: int, height: int) -> None: # --- Colourbar Functions --- -def _get_cbar_limits_from_mappable(cbar: Colorbar) -> tuple[float, float]: +def get_cbar_limits_from_mappable(cbar: Colorbar) -> tuple[float, float]: """Return (vmin, vmax) for a colourbar's mappable with robust fallbacks.""" vmin = vmax = None try: # Works for many matplotlib mappables @@ -535,7 +787,10 @@ def _get_cbar_limits_from_mappable(cbar: Colorbar) -> tuple[float, float]: vmin = getattr(norm, "vmin", None) vmax = getattr(norm, "vmax", None) if vmin is None or vmax is None: - vmin, vmax = 0.0, 1.0 + vmin, vmax = ( + _DEFAULT_LAYOUT_CONFIG.formatting.default_vmin_fallback, + _DEFAULT_LAYOUT_CONFIG.formatting.default_vmax_fallback, + ) return float(vmin), float(vmax) @@ -555,7 +810,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 1. Shared colourbar for ground truth and prediction (same data scale) 2. Separate colourbar for difference data (different scale, often symmetric) - The function works with the layout from _build_layout, using dedicated + The function works with the layout from build_layout, using dedicated colourbar axes when available, or falling back to automatic matplotlib placement. 
Colourbar Design: @@ -594,7 +849,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 colourbar_groundtruth = plt.colorbar( image_groundtruth, cax=cbar_axes["groundtruth"], orientation=orientation ) - _format_linear_ticks( + format_linear_ticks( colourbar_groundtruth, vmin=float(plot_spec.vmin) if plot_spec.vmin is not None else None, vmax=float(plot_spec.vmax) if plot_spec.vmax is not None else None, @@ -607,7 +862,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 colourbar_prediction = plt.colorbar( image_prediction, cax=cbar_axes["prediction"], orientation=orientation ) - _format_linear_ticks( + format_linear_ticks( colourbar_prediction, decimals=1, is_vertical=is_vertical ) else: @@ -624,7 +879,7 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 ) # Tick formatting - _format_linear_ticks( + format_linear_ticks( colourbar_truth, vmin=float(plot_spec.vmin) if plot_spec.vmin is not None else None, vmax=float(plot_spec.vmax) if plot_spec.vmax is not None else None, @@ -666,71 +921,118 @@ def _add_colourbars( # noqa: PLR0913, PLR0912 # Tick formatting: symmetric for TwoSlopeNorm, otherwise linear if isinstance(image_difference.norm, TwoSlopeNorm): - vmin = float(image_difference.norm.vmin or -1.0) - vmax = float(image_difference.norm.vmax or 1.0) - _format_symmetric_ticks( + vmin = float( + image_difference.norm.vmin + or _DEFAULT_LAYOUT_CONFIG.formatting.default_vmin_diff_fallback + ) + vmax = float( + image_difference.norm.vmax + or _DEFAULT_LAYOUT_CONFIG.formatting.default_vmax_diff_fallback + ) + format_symmetric_ticks( colourbar_diff, vmin=vmin, vmax=vmax, decimals=2, is_vertical=is_vertical, + centre=image_difference.norm.vcenter, ) else: - _format_linear_ticks(colourbar_diff, decimals=2, is_vertical=is_vertical) + format_linear_ticks(colourbar_diff, decimals=2, is_vertical=is_vertical) # --- Tick Formatting Functions --- -def _format_linear_ticks( +def format_linear_ticks( colourbar: Colorbar, *, vmin: float | None = None, vmax: float | None = None, decimals: int = 1, is_vertical: bool, + use_scientific_notation: bool = False, ) -> None: """Format a linear colourbar with 5 ticks. If vmin/vmax are not provided, derive them from the colourbar's mappable. + + Args: + colourbar: Colorbar to format. + vmin: Minimum value. + vmax: Maximum value. + decimals: Number of decimal places for tick labels. + is_vertical: Whether the colorbar is vertical. + use_scientific_notation: Whether to format tick labels in scientific notation. 
+ """ axis = colourbar.ax.yaxis if is_vertical else colourbar.ax.xaxis if vmin is None or vmax is None: - mvmin, mvmax = _get_cbar_limits_from_mappable(colourbar) + mvmin, mvmax = get_cbar_limits_from_mappable(colourbar) vmin = mvmin if vmin is None else vmin vmax = mvmax if vmax is None else vmax - ticks = np.linspace(float(vmin), float(vmax), 5) + ticks = np.linspace( + float(vmin), float(vmax), _DEFAULT_LAYOUT_CONFIG.formatting.num_ticks_linear + ) colourbar.set_ticks([float(t) for t in ticks]) - axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) + + if use_scientific_notation: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}e}")) + else: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) if not is_vertical: colourbar.ax.xaxis.set_tick_params(pad=1) - _apply_monospace_to_cbar_text(colourbar) + apply_monospace_to_cbar_text(colourbar) -def _format_symmetric_ticks( +def format_symmetric_ticks( # noqa: PLR0913 colourbar: Colorbar, *, vmin: float, vmax: float, decimals: int = 2, is_vertical: bool, + centre: float | None = None, + use_scientific_notation: bool = False, ) -> None: - """Format symmetric diverging ticks with a 0-centered midpoint. + """Format symmetric diverging ticks with a centred midpoint. + + Places five ticks: [vmin, midpoint to centre, centre, centre to midpoint, vmax]. + + Args: + colourbar: Colorbar to format. + vmin: Minimum value. + vmax: Maximum value. + decimals: Number of decimal places for tick labels. + is_vertical: Whether the colorbar is vertical. + centre: centre value for diverging colourmap (default: 0.0). + use_scientific_notation: Whether to format tick labels in scientific notation. - Places five ticks: [vmin, midpoint to 0, 0, 0 to midpoint, vmax]. 
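As a hedged sketch of the new `use_scientific_notation` flag on `format_linear_ticks` (the humidity-like values below are made up), small magnitudes render as, for example, `7.75e-03` rather than a rounded `0.01`:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, matching the rest of the plotting stack

import matplotlib.pyplot as plt
import numpy as np

from ice_station_zebra.visualisations.layout import format_linear_ticks

fig, ax = plt.subplots()
data = np.random.default_rng(0).uniform(0.0005, 0.015, size=(8, 8))  # e.g. kg/kg humidity
image = ax.imshow(data, vmin=0.0005, vmax=0.015)
cbar = fig.colorbar(image, ax=ax)  # vertical by default

format_linear_ticks(
    cbar,
    vmin=0.0005,
    vmax=0.015,
    decimals=2,
    is_vertical=True,
    use_scientific_notation=True,  # middle tick becomes "7.75e-03"
)
plt.close(fig)
```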
""" axis = colourbar.ax.yaxis if is_vertical else colourbar.ax.xaxis - ticks = [vmin, 0.5 * (vmin + 0.0), 0.0, 0.5 * (0.0 + vmax), vmax] + centre_val = centre if centre is not None else 0.0 + midpoint_factor = _DEFAULT_LAYOUT_CONFIG.formatting.midpoint_factor + ticks = [ + vmin, + midpoint_factor * (vmin + centre_val), + centre_val, + midpoint_factor * (centre_val + vmax), + vmax, + ] colourbar.set_ticks([float(t) for t in ticks]) - axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) + + if use_scientific_notation: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}e}")) + else: + axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.{decimals}f}")) if not is_vertical: colourbar.ax.xaxis.set_tick_params(pad=1) - _apply_monospace_to_cbar_text(colourbar) + apply_monospace_to_cbar_text(colourbar) -def _apply_monospace_to_cbar_text(colourbar: Colorbar) -> None: +def apply_monospace_to_cbar_text(colourbar: Colorbar) -> None: """Set tick labels and axis labels on a colourbar to monospace family.""" ax = colourbar.ax for label in list(ax.get_xticklabels()) + list(ax.get_yticklabels()): @@ -738,3 +1040,133 @@ def _apply_monospace_to_cbar_text(colourbar: Colorbar) -> None: # Ensure axis labels also use monospace if present ax.xaxis.label.set_fontfamily("monospace") ax.yaxis.label.set_fontfamily("monospace") + + +def _default_vertical_gap_inches(aspect: float, cfg: GapConfig) -> float: + """Return an aspect-aware physical gap between panel and vertical colourbar.""" + aspect = max(aspect, EPSILON_SMALL) + if aspect >= 1.0: + wide = min(aspect, cfg.wide_limit) + if np.isclose(cfg.wide_limit, 1.0): + return cfg.base + t = (wide - 1.0) / (cfg.wide_limit - 1.0) + gap = cfg.base - t * (cfg.base - cfg.wide_gap) + else: + tall_limit = max(cfg.tall_limit, EPSILON_SMALL) + tall = min(1.0 / aspect, 1.0 / tall_limit) + denom = (1.0 / tall_limit) - 1.0 + t = 0.0 if np.isclose(denom, 0.0) else (tall - 1.0) / denom + gap = cfg.base + t * (cfg.tall_gap - cfg.base) + return float(np.clip(gap, cfg.min_val, cfg.max_val)) + + +# --- Text and Box Annotation Functions --- + + +def set_suptitle_with_box(fig: Figure, text: str) -> plt.Text: + """Draw a fixed-position title with a white box that doesn't influence layout. + + Returns the Text artist so callers can update with set_text during animation. + This version avoids kwargs that are unsupported on older Matplotlib. + + Args: + fig: Matplotlib Figure object. + text: Title text to display. + + Returns: + Text artist for the title. + + """ + config = _DEFAULT_LAYOUT_CONFIG.title_footer + bbox = { + "facecolor": "white", + "edgecolor": "none", + "pad": config.bbox_pad_title, + "alpha": 1.0, + } + t = fig.text( + x=0.5, + y=config.title_y, + s=text, + ha="center", + va="top", + fontsize=config.title_fontsize, + fontfamily="monospace", + transform=fig.transFigure, + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(config.zorder_high) + return t + + +def set_footer_with_box(fig: Figure, text: str) -> plt.Text: + """Draw a fixed-position footer with a white box at bottom centre. + + Footer is intended for metadata and secondary information. + + Args: + fig: Matplotlib Figure object. + text: Footer text to display. + + Returns: + Text artist for the footer. 
+ + """ + config = _DEFAULT_LAYOUT_CONFIG.title_footer + bbox = { + "facecolor": "white", + "edgecolor": "none", + "pad": config.bbox_pad_title, + "alpha": 1.0, + } + t = fig.text( + x=0.5, + y=config.footer_y, + s=text, + ha="center", + va="bottom", + fontsize=config.footer_fontsize, + fontfamily="monospace", + transform=fig.transFigure, + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(config.zorder_high) + return t + + +def draw_badge_with_box(fig: Figure, x: float, y: float, text: str) -> plt.Text: + """Draw a warning/info badge with white background box at figure coords. + + Args: + fig: Matplotlib Figure object. + x: X position in figure coordinates (0-1). + y: Y position in figure coordinates (0-1). + text: Badge text to display. + + Returns: + Text artist for the badge. + + """ + config = _DEFAULT_LAYOUT_CONFIG.title_footer + bbox = { + "facecolor": "white", + "edgecolor": "none", + "pad": config.bbox_pad_badge, + "alpha": 1.0, + } + t = fig.text( + x=x, + y=y, + s=text, + fontsize=config.footer_fontsize, + fontfamily="monospace", + color="firebrick", + ha="center", + va="top", + bbox=bbox, + ) + with contextlib.suppress(Exception): + t.set_zorder(config.zorder_high) + return t diff --git a/ice_station_zebra/visualisations/plotting_core.py b/ice_station_zebra/visualisations/plotting_core.py index 75a9054f..59a6668e 100644 --- a/ice_station_zebra/visualisations/plotting_core.py +++ b/ice_station_zebra/visualisations/plotting_core.py @@ -1,15 +1,201 @@ +import logging +from collections.abc import Mapping +from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal +import matplotlib as mpl import numpy as np -from matplotlib.colors import TwoSlopeNorm +from matplotlib.colors import Normalize, TwoSlopeNorm from ice_station_zebra.exceptions import InvalidArrayError from ice_station_zebra.types import DiffColourmapSpec, DiffMode, DiffStrategy, PlotSpec +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + # Constants for land mask validation EXPECTED_LAND_MASK_DIMENSIONS = 2 +# --- Variable Styling --- + + +@dataclass +class VariableStyle: + """Styling configuration for individual variables. + + Attributes: + cmap: Matplotlib colourmap name (e.g., "viridis", "RdBu_r"). + colourbar_strategy: "shared" or "separate" (kept for compatibility). + vmin: Minimum value for colour scale. + vmax: Maximum value for colour scale. + two_slope_centre: Centre value for diverging colourmap (TwoSlopeNorm). + units: Display units for the variable (e.g., "K", "m/s"). + origin: Imshow origin override ("upper" keeps north-up, "lower" keeps south-up). + decimals: Number of decimal places for colourbar tick labels (default: 2). + use_scientific_notation: Whether to format colourbar tick labels in scientific notation (default: False). + + """ + + cmap: str | None = None + colourbar_strategy: str | None = None + vmin: float | None = None + vmax: float | None = None + two_slope_centre: float | None = None + units: str | None = None + origin: Literal["upper", "lower"] | None = None + decimals: int | None = None + use_scientific_notation: bool | None = None + + +def colourmap_with_bad( + cmap_name: str | None, bad_color: str = "#dcdcdc" +) -> mpl.colors.Colormap: + """Create a colourmap copy with a specified color for bad (NaN) values. + + This function copies the specified colourmap and sets the 'bad' color to handle + NaN values consistently, preventing white artifacts in visualisations. 
+ + Args: + cmap_name: Name of the matplotlib colourmap (e.g., "viridis", "RdBu_r"). + If None, defaults to "viridis". + bad_color: Color to use for NaN/bad values. Default is light grey (#dcdcdc). + + Returns: + A copy of the colourmap with set_bad() configured. + + """ + if cmap_name is None: + cmap = mpl.colormaps.get_cmap("viridis") + else: + cmap = mpl.colormaps.get_cmap(cmap_name) + + try: + cmap = cmap.copy() + except (AttributeError, TypeError): + # Some matplotlib versions return non-copyable colourmap; create new + cmap = mpl.colormaps.get_cmap(cmap.name) + + cmap.set_bad(bad_color) + return cmap + + +def safe_nanmin(arr: np.ndarray, default: float = 0.0) -> float: + """Safely compute nanmin with fallback for empty or all-NaN arrays. + + Args: + arr: Array to compute minimum from. + default: Default value if array is empty or all NaN. + + Returns: + Minimum value or default. + + """ + if np.isfinite(arr).any(): + result = np.nanmin(arr) + return float(result) if np.isfinite(result) else default + return default + + +def safe_nanmax(arr: np.ndarray, default: float = 1.0) -> float: + """Safely compute nanmax with fallback for empty or all-NaN arrays. + + Args: + arr: Array to compute maximum from. + default: Default value if array is empty or all NaN. + + Returns: + Maximum value or default. + + """ + if np.isfinite(arr).any(): + result = np.nanmax(arr) + return float(result) if np.isfinite(result) else default + return default + + +def style_for_variable( # noqa: C901, PLR0911 + var_name: str, styles: dict[str, dict[str, Any]] | None +) -> VariableStyle: + """Return best matching style for a variable from config styles dict. + + Matching priority: + 1) exact key + 2) wildcard prefix key ending with '*' + 3) _default + 4) empty style + Accepts any Mapping (so OmegaConf DictConfig works). + """ + + def _normalise_name(name: str) -> str: + # Convert double-underscore to colon (this maps 'era5__2t' -> 'era5:2t') + name = name.replace("__", ":") + # Treat hyphens as separators too: 'era5-2t' -> 'era5:2t' + name = name.replace("-", ":") + # Collapse accidental repeated '::' to single ':' + while "::" in name: + name = name.replace("::", ":") + # Keep single underscores (they are meaningful in some variable names) + return name + + if not styles: + return VariableStyle() + + # Accept Mapping-like configs (Dict, DictConfig, etc.) + if not isinstance(styles, Mapping): + logger.info("style_for_variable: styles is not a Mapping; ignoring styles") + return VariableStyle() + + # Quick exact match first (try raw var_name) + spec = styles.get(var_name) + if isinstance(spec, Mapping): + return VariableStyle(**{k: spec.get(k) for k in VariableStyle.__annotations__}) + + # Try normalised exact match + norm_var = _normalise_name(var_name) + if norm_var != var_name: + spec = styles.get(norm_var) + if isinstance(spec, Mapping): + return VariableStyle( + **{k: spec.get(k) for k in VariableStyle.__annotations__} + ) + + # Wildcard prefix match: scan keys ending with '*' (normalise the key before comparing) + # We iterate keys so keep original order (OmegaConf preserves insertion order). 
+ for key in styles: + if isinstance(key, str) and key.endswith("*"): + prefix = key[:-1] + prefix_norm = _normalise_name(prefix) + # If prefix_norm is empty (user wrote '*' only) skip it + if not prefix_norm: + continue + # Compare against both raw and normalised var names + if var_name.startswith(prefix) or norm_var.startswith(prefix_norm): + spec = styles.get(key) + if isinstance(spec, Mapping): + return VariableStyle( + **{k: spec.get(k) for k in VariableStyle.__annotations__} + ) + logger.info( + "style_for_variable: wildcard candidate %r not a dict (type=%s)", + key, + type(spec), + ) + + # Fallback to _default + spec = styles.get("_default") + if isinstance(spec, Mapping): + return VariableStyle(**{k: spec.get(k) for k in VariableStyle.__annotations__}) + + return VariableStyle() + + +# --- Colour Scale Generation --- + + def levels_from_spec(spec: PlotSpec) -> np.ndarray: """Generate contour levels from a plotting specification. @@ -30,6 +216,63 @@ def levels_from_spec(spec: PlotSpec) -> np.ndarray: return np.linspace(vmin, vmax, spec.n_contour_levels) +def create_normalisation( + data: np.ndarray, + *, + vmin: float | None = None, + vmax: float | None = None, + centre: float | None = None, +) -> tuple[Normalize | TwoSlopeNorm, float, float]: + """Create appropriate normalisation for data with optional centring. + + This function creates either a linear Normalize or a diverging TwoSlopeNorm + based on whether a centre value is provided. When a centre is specified, + the normalisation will be symmetric around that value. + + Args: + data: 2D array of data to normalise. + vmin: Minimum value for colour scale. If None, inferred from data. + vmax: Maximum value for colour scale. If None, inferred from data. + centre: Centre value for diverging colourmap. If provided, creates + a symmetric TwoSlopeNorm around this value. 
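To illustrate the matching rules of `style_for_variable` defined above, here is a hedged sketch; the mapping is illustrative and shaped like the `variable_styles` test fixture added later in this diff:

```python
from ice_station_zebra.visualisations.plotting_core import style_for_variable

styles = {
    "era5:2t": {"cmap": "RdBu_r", "two_slope_centre": 273.15, "units": "K"},
    "osisaf-south:*": {"cmap": "Blues_r"},  # wildcard prefix
    "_default": {"cmap": "viridis"},
}

style_for_variable("era5:2t", styles).cmap                # "RdBu_r"  (exact key)
style_for_variable("era5__2t", styles).cmap               # "RdBu_r"  ("__" normalised to ":")
style_for_variable("osisaf-south:ice_conc", styles).cmap  # "Blues_r" (wildcard prefix match)
style_for_variable("unknown:variable", styles).cmap       # "viridis" (_default fallback)
style_for_variable("unknown:variable", None).cmap         # None      (empty VariableStyle)
```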
+ + Returns: + Tuple of (normalisation, vmin, vmax) where: + - normalisation: Normalize or TwoSlopeNorm object + - vmin: Computed minimum value + - vmax: Computed maximum value + + """ + # Compute data range with robust handling of NaN/inf + data_min = float(np.nanmin(data)) if np.isfinite(data).any() else 0.0 + data_max = float(np.nanmax(data)) if np.isfinite(data).any() else 1.0 + + if centre is not None: + # Diverging colourmap centred at specified value + low = vmin if vmin is not None else data_min + high = vmax if vmax is not None else data_max + + # Make symmetric around the centre where possible + span_low = abs(centre - low) + span_high = abs(high - centre) + span = max(span_low, span_high, 1e-6) + + final_vmin = float(centre - span) + final_vmax = float(centre + span) + + norm: Normalize | TwoSlopeNorm = TwoSlopeNorm( + vmin=final_vmin, vcenter=float(centre), vmax=final_vmax + ) + return norm, final_vmin, final_vmax + + # Linear colourmap + final_vmin = float(vmin if vmin is not None else data_min) + final_vmax = float(vmax if vmax is not None else data_max) + + norm_linear: Normalize | TwoSlopeNorm = Normalize(vmin=final_vmin, vmax=final_vmax) + return norm_linear, final_vmin, final_vmax + + def compute_difference( ground_truth: np.ndarray, prediction: np.ndarray, diff_mode: DiffMode ) -> np.ndarray: @@ -87,13 +330,9 @@ def make_diff_colourmap( max_abs = max(1.0, float(abs(sample))) vmin, vmax = -max_abs, max_abs else: - # Find the min and max values of the sample array - vmin_data = float( - np.nanmin(sample) if np.nanmin(sample) is not None else -1.0 - ) - vmax_data = float( - np.nanmax(sample) if np.nanmax(sample) is not None else 1.0 - ) + # Find the min and max values of the sample array using safe helpers + vmin_data = safe_nanmin(sample, default=-1.0) + vmax_data = safe_nanmax(sample, default=1.0) # Find the maximum absolute value of the sample array max_abs = max(1.0, abs(vmin_data), abs(vmax_data)) vmin, vmax = -max_abs, max_abs @@ -110,7 +349,7 @@ def make_diff_colourmap( if isinstance(sample, (float, int)): vmax = max(1e-6, float(sample)) else: - vmax = max(1e-6, float(np.nanmax(sample) or 0.0)) + vmax = max(1e-6, safe_nanmax(sample, default=0.0)) return DiffColourmapSpec( norm=None, @@ -330,7 +569,7 @@ def compute_display_ranges_stream( return (groundtruth_min, groundtruth_max), (prediction_min, prediction_max) -def detect_land_mask_path( +def detect_land_mask_path( # noqa: C901 base_path: str | Path, dataset_name: str | None = None, hemisphere: str | None = None, @@ -338,12 +577,14 @@ def detect_land_mask_path( """Automatically detect the land mask path based on dataset configuration. This function looks for land mask files in the expected locations based on - the dataset name and hemisphere. It follows the pattern: + the dataset name and hemisphere. It accepts either the repository root + (containing a ``data`` directory) or the ``data`` directory itself as + ``base_path``. It follows the pattern: - {base_path}/data/preprocessing/{dataset_name}/IceNetSIC/data/masks/{hemisphere}/masks/land_mask.npy - {base_path}/data/preprocessing/IceNetSIC/data/masks/{hemisphere}/masks/land_mask.npy Args: - base_path: Base path to the data directory. + base_path: Base path to the project root **or** the ``data`` directory. dataset_name: Name of the dataset (e.g., 'samp-sicsouth-osisaf-25k-2017-2019-24h-v1'). hemisphere: Hemisphere ('north' or 'south'). 
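A short worked sketch of `create_normalisation` with hypothetical temperature values; when a `centre` is given, the limits become symmetric about it using the larger half-span:

```python
import numpy as np
from matplotlib.colors import Normalize, TwoSlopeNorm

from ice_station_zebra.visualisations.plotting_core import create_normalisation

# Hypothetical 2 m temperature field spanning roughly 263-293 K.
data = np.linspace(263.0, 293.0, 16).reshape(4, 4)

norm, vmin, vmax = create_normalisation(data, centre=273.15)
# Half-spans are |273.15 - 263| = 10.15 and |293 - 273.15| = 19.85; the larger
# one wins, so vmin is about 253.3 and vmax is about 293.0.
assert isinstance(norm, TwoSlopeNorm)

norm_linear, lo, hi = create_normalisation(data)  # no centre: plain Normalize over the data range
assert isinstance(norm_linear, Normalize)
```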
@@ -363,37 +604,54 @@ def detect_land_mask_path( if hemisphere is None: return None - # Try dataset-specific path first - if dataset_name is not None: - dataset_specific_path = ( - base_path - / "data" - / "preprocessing" - / dataset_name - / "IceNetSIC" - / "data" - / "masks" - / hemisphere - / "masks" - / "land_mask.npy" - ) - if dataset_specific_path.exists(): - return str(dataset_specific_path) - - # Try general IceNetSIC path - general_path = ( - base_path - / "data" - / "preprocessing" - / "IceNetSIC" - / "data" - / "masks" - / hemisphere - / "masks" - / "land_mask.npy" - ) - if general_path.exists(): - return str(general_path) + # Support repo-root, data-directory, and nested data/data layouts + candidate_data_roots: list[Path] = [] + if base_path.name != "data": + candidate_data_roots.append(base_path / "data") + candidate_data_roots.append(base_path) + + seen_roots: set[Path] = set() + for data_root in candidate_data_roots: + resolved_root = data_root.resolve() + if resolved_root in seen_roots: + continue + seen_roots.add(resolved_root) + + # Some deployments keep datasets in data/data/, + # so we probe both the root and an extra nested data/ layer. + preprocessing_roots = { + resolved_root / "preprocessing", + resolved_root / "data" / "preprocessing", + } + + for preproc_root in preprocessing_roots: + # Try dataset-specific path first + if dataset_name is not None: + dataset_specific_path = ( + preproc_root + / dataset_name + / "IceNetSIC" + / "data" + / "masks" + / hemisphere + / "masks" + / "land_mask.npy" + ) + if dataset_specific_path.exists(): + return str(dataset_specific_path) + + # Try general IceNetSIC path + general_path = ( + preproc_root + / "IceNetSIC" + / "data" + / "masks" + / hemisphere + / "masks" + / "land_mask.npy" + ) + if general_path.exists(): + return str(general_path) return None @@ -437,3 +695,55 @@ def load_land_mask( land_mask = land_mask.astype(bool) return land_mask + + +# --- File Utilities --- + + +def safe_filename(name: str) -> str: + """Sanitise a string for use as a filename. + + Replaces non-alphanumeric characters (except hyphens and underscores) + with hyphens and strips leading/trailing hyphens. + + Args: + name: Input string to sanitise. + + Returns: + Sanitised filename string. + + Examples: + >>> safe_filename("era5:2t") + 'era5-2t' + >>> safe_filename("My Variable Name!") + 'My-Variable-Name' + + """ + keep = [c if c.isalnum() or c in ("-", "_") else "-" for c in name.strip()] + return "".join(keep).strip("-") or "var" + + +def save_figure( + fig: "plt.Figure", save_dir: Path | None, base_name: str +) -> Path | None: + """Save a matplotlib figure to disk as PNG. + + Creates the save directory if it doesn't exist. Sanitises the filename + to avoid filesystem issues. + + Args: + fig: Matplotlib Figure object to save. + save_dir: Directory to save the figure in. If None, does not save. + base_name: Base name for the file (will be sanitised). + + Returns: + Path to the saved file, or None if save_dir was None. 
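To illustrate the widened `detect_land_mask_path` behaviour (the paths below are hypothetical; a string is returned only if the mask file actually exists on disk):

```python
from ice_station_zebra.visualisations.plotting_core import detect_land_mask_path

# Both call styles now probe the same candidate locations, e.g.
#   <root>/data/preprocessing/IceNetSIC/data/masks/south/masks/land_mask.npy
#   <root>/data/data/preprocessing/IceNetSIC/data/masks/south/masks/land_mask.npy
mask_from_repo_root = detect_land_mask_path(
    "/workspace/ice-station-zebra", hemisphere="south"
)
mask_from_data_dir = detect_land_mask_path(
    "/workspace/ice-station-zebra/data", hemisphere="south"
)
# Either call returns the first existing candidate as a string, or None.
```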
+ + """ + if save_dir is None: + return None + + save_dir.mkdir(parents=True, exist_ok=True) + path = save_dir / f"{safe_filename(base_name)}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + return path diff --git a/ice_station_zebra/visualisations/plotting_maps.py b/ice_station_zebra/visualisations/plotting_maps.py index 429f2976..8ba698af 100644 --- a/ice_station_zebra/visualisations/plotting_maps.py +++ b/ice_station_zebra/visualisations/plotting_maps.py @@ -8,7 +8,6 @@ """ -import contextlib import io import logging from collections.abc import Sequence @@ -20,6 +19,7 @@ from matplotlib import animation from matplotlib.axes import Axes from matplotlib.colors import ListedColormap +from matplotlib.figure import Figure from matplotlib.text import Text from PIL.ImageFile import ImageFile @@ -27,8 +27,20 @@ from ice_station_zebra.types import DiffColourmapSpec, PlotSpec from . import convert -from .layout import _add_colourbars, _build_layout, _set_axes_limits, _set_titles +from .animation_helper import hold_anim, release_anim +from .layout import ( + LayoutConfig, + TitleFooterConfig, + _add_colourbars, + _set_axes_limits, + _set_titles, + build_layout, + draw_badge_with_box, + set_footer_with_box, + set_suptitle_with_box, +) from .plotting_core import ( + colourmap_with_bad, compute_difference, compute_display_ranges, compute_display_ranges_stream, @@ -43,10 +55,6 @@ logger = logging.getLogger(__name__) - -# Keep strong references to animation objects during save to avoid GC-related warnings -_ANIM_CACHE: list[animation.FuncAnimation] = [] - #: Default plotting specification for sea ice concentration visualisation DEFAULT_SIC_SPEC = PlotSpec( variable="sea_ice_concentration", @@ -65,6 +73,103 @@ ) +def _safe_linspace(vmin: float, vmax: float, n: int) -> np.ndarray: + """Return an increasing linspace even if vmin==vmax or the inputs are swapped.""" + # ensure float + vmin = float(vmin) + vmax = float(vmax) + if not np.isfinite(vmin) or not np.isfinite(vmax): + # fallback to a stable interval + vmin, vmax = 0.0, 1.0 + if vmax < vmin: + vmin, vmax = vmax, vmin + if vmax == vmin: + # create a tiny range so contourf accepts the levels + eps = max(1e-6, abs(vmin) * 1e-6) + vmax = vmin + eps + return np.linspace(vmin, vmax, n) + + +def _prepare_static_plot( + plot_spec: PlotSpec, + ground_truth: np.ndarray, + prediction: np.ndarray, +) -> tuple[int, int, np.ndarray | None, LayoutConfig | None, list[str], np.ndarray]: + """Validate arrays and compute helpers needed for plotting. + + Returns: + Tuple of (height, width, land_mask, layout_config, warnings, contour levels). 
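A quick check of the `_safe_linspace` guard introduced above (it is module-private, so importing it outside the module is for illustration only): degenerate or swapped limits still yield strictly increasing contour levels.

```python
import numpy as np

from ice_station_zebra.visualisations.plotting_maps import _safe_linspace

levels = _safe_linspace(0.5, 0.5, 5)   # constant field: a tiny epsilon range is created
assert np.all(np.diff(levels) > 0)

swapped = _safe_linspace(1.0, 0.0, 5)  # swapped limits are reordered
assert swapped[0] == 0.0 and swapped[-1] == 1.0
```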
+ + """ + height, width = validate_2d_pair(ground_truth, prediction) + land_mask = load_land_mask(plot_spec.land_mask_path, (height, width)) + + (gt_min, gt_max), (_pred_min, _pred_max) = compute_display_ranges( + ground_truth, prediction, plot_spec + ) + range_check_report = compute_range_check_report( + ground_truth, + prediction, + vmin=gt_min, + vmax=gt_max, + outside_warn=getattr(plot_spec, "outside_warn", 0.05), + severe_outside=getattr(plot_spec, "severe_outside", 0.20), + include_shared_range_mismatch_check=getattr( + plot_spec, "include_shared_range_mismatch_check", True + ), + ) + warnings = range_check_report.warnings + layout_config = None + if warnings: + layout_config = LayoutConfig(title_footer=TitleFooterConfig(title_space=0.10)) + + levels = levels_from_spec(plot_spec) + return height, width, land_mask, layout_config, warnings, levels + + +def _prepare_difference( + plot_spec: PlotSpec, + ground_truth: np.ndarray, + prediction: np.ndarray, +) -> tuple[np.ndarray | None, DiffColourmapSpec | None]: + """Compute difference arrays and colour scales if requested.""" + if not plot_spec.include_difference: + return None, None + difference = compute_difference(ground_truth, prediction, plot_spec.diff_mode) + diff_colour_scale = make_diff_colourmap(difference, mode=plot_spec.diff_mode) + return difference, diff_colour_scale + + +def _draw_warning_badge( + fig: Figure, + title_text: Text | None, + warnings: Sequence[str], +) -> None: + """Render a warning badge close to the title.""" + if not warnings: + return + badge = "Warnings: " + ", ".join(warnings) + if title_text is not None: + _, title_y = title_text.get_position() + n_lines = title_text.get_text().count("\n") + 1 + warning_y = max(title_y - (0.05 + 0.02 * (n_lines - 1)), 0.0) + else: + warning_y = 0.90 + draw_badge_with_box(fig, 0.5, warning_y, badge) + + +def _maybe_add_footer(fig: Figure, plot_spec: PlotSpec) -> None: + """Attach footer metadata when enabled.""" + if not getattr(plot_spec, "include_footer_metadata", True): + return + try: + footer_text = _build_footer_static(plot_spec) + if footer_text: + set_footer_with_box(fig, footer_text) + except Exception: + logger.exception("Failed to draw footer; continuing without footer.") + + # --- Static Map Plot --- def plot_maps( plot_spec: PlotSpec, @@ -94,50 +199,27 @@ def plot_maps( InvalidArrayError: If ground_truth and prediction arrays have incompatible shapes. 
""" - # Check the shapes of the arrays - height, width = validate_2d_pair(ground_truth, prediction) - - # Load land mask if specified - land_mask = load_land_mask(plot_spec.land_mask_path, (height, width)) - - # Pre-compute range check to decide top spacing (warning badge may need extra room) - (gt_min, gt_max), (pred_min, pred_max) = compute_display_ranges( - ground_truth, prediction, plot_spec - ) - range_check_report = compute_range_check_report( - ground_truth, - prediction, - vmin=gt_min, - vmax=gt_max, - outside_warn=getattr(plot_spec, "outside_warn", 0.05), - severe_outside=getattr(plot_spec, "severe_outside", 0.20), - include_shared_range_mismatch_check=getattr( - plot_spec, "include_shared_range_mismatch_check", True - ), - ) - - # Increase title space if warnings are present to avoid overlap with axes titles - title_space_override = 0.10 if range_check_report.warnings else None + ( + height, + width, + land_mask, + layout_config, + warnings, + levels, + ) = _prepare_static_plot(plot_spec, ground_truth, prediction) # Initialise the figure and axes with dynamic top spacing if needed - if title_space_override is not None: - fig, axs, cbar_axes = _build_layout( - plot_spec=plot_spec, - height=height, - width=width, - title_space=title_space_override, - ) - else: - fig, axs, cbar_axes = _build_layout( - plot_spec=plot_spec, height=height, width=width - ) - levels = levels_from_spec(plot_spec) + fig, axs, cbar_axes = build_layout( + plot_spec=plot_spec, + height=height, + width=width, + layout_config=layout_config, + ) # Prepare difference rendering parameters if needed - diff_colour_scale = None - if plot_spec.include_difference: - difference = compute_difference(ground_truth, prediction, plot_spec.diff_mode) - diff_colour_scale = make_diff_colourmap(difference, mode=plot_spec.diff_mode) + difference, diff_colour_scale = _prepare_difference( + plot_spec, ground_truth, prediction + ) # Draw the ground truth and prediction map images image_groundtruth, image_prediction, image_difference, _ = _draw_frame( @@ -146,7 +228,7 @@ def plot_maps( prediction, plot_spec, diff_colour_scale, - precomputed_difference=difference if plot_spec.include_difference else None, + precomputed_difference=difference, levels_override=levels, land_mask=land_mask, ) @@ -166,39 +248,16 @@ def plot_maps( _set_axes_limits(axs, width=width, height=height) try: - title_text = _set_suptitle_with_box(fig, _build_title_static(plot_spec, date)) + title_text = set_suptitle_with_box(fig, _build_title_static(plot_spec, date)) except Exception: logger.exception("Failed to draw suptitle; continuing without title.") title_text = None - # Include range_check report (already computed above) - badge = ( - "" - if not range_check_report.warnings - else "Warnings: " + ", ".join(range_check_report.warnings) - ) - if badge: - # Place the warning just below the title - if title_text is not None: - _, title_y = title_text.get_position() - n_lines = title_text.get_text().count("\n") + 1 - # Reduced gap to avoid overlapping axes titles; keep badge close to title - warning_y = max(title_y - (0.05 + 0.02 * (n_lines - 1)), 0.0) - else: - warning_y = 0.90 - _draw_badge_with_box(fig, 0.5, warning_y, badge) - - # Footer metadata at the bottom - if getattr(plot_spec, "include_footer_metadata", True): - try: - footer_text = _build_footer_static(plot_spec) - if footer_text: - _set_footer_with_box(fig, footer_text) - except Exception: - logger.exception("Failed to draw footer; continuing without footer.") + _draw_warning_badge(fig, title_text, 
warnings) + _maybe_add_footer(fig, plot_spec) try: - return {"sea-ice_concentration-static-maps": [convert._image_from_figure(fig)]} + return {"sea-ice_concentration-static-maps": [convert.image_from_figure(fig)]} finally: plt.close(fig) @@ -266,8 +325,12 @@ def video_maps( prediction_stream = np.where(land_mask, np.nan, prediction_stream) # Initialise the figure and axes with a larger footer space for videos - fig, axs, cbar_axes = _build_layout( - plot_spec=plot_spec, height=height, width=width, footer_space=0.11 + layout_config = LayoutConfig(title_footer=TitleFooterConfig(footer_space=0.11)) + fig, axs, cbar_axes = build_layout( + plot_spec=plot_spec, + height=height, + width=width, + layout_config=layout_config, ) levels = levels_from_spec(plot_spec) @@ -322,9 +385,7 @@ def video_maps( ) _set_axes_limits(axs, width=width, height=height) try: - title_text = _set_suptitle_with_box( - fig, _build_title_video(plot_spec, dates, 0) - ) + title_text = set_suptitle_with_box(fig, _build_title_video(plot_spec, dates, 0)) except Exception: logger.exception("Failed to draw suptitle; continuing without title.") title_text = None @@ -334,7 +395,7 @@ def video_maps( try: footer_text = _build_footer_video(plot_spec, dates) if footer_text: - _set_footer_with_box(fig, footer_text) + set_footer_with_box(fig, footer_text) except Exception: logger.exception("Failed to draw footer; continuing without footer.") @@ -372,19 +433,18 @@ def animate(tt: int) -> tuple[()]: blit=False, repeat=True, ) - # Keep a strong reference without touching figure attributes (ruff-fix) - _ANIM_CACHE.append(animation_object) + # Keep a strong reference to prevent garbage collection during save + hold_anim(animation_object) # Save -> BytesIO and clean up temp file try: - video_buffer = convert._save_animation( - animation_object, _fps=fps, _video_format=video_format + video_buffer = convert.save_animation( + animation_object, fps=fps, video_format=video_format ) return {"sea-ice_concentration-video-maps": video_buffer} finally: # Drop strong reference now that saving is done - with contextlib.suppress(ValueError): - _ANIM_CACHE.remove(animation_object) + release_anim(animation_object) plt.close(fig) @@ -473,19 +533,22 @@ def _draw_main_panels( # noqa: PLR0913 # Use PlotSpec levels for ground_truth and prediction unless overridden levels = levels_from_spec(plot_spec) if levels_override is None else levels_override + # Create colourmap with bad color handling for NaN values + cmap = colourmap_with_bad(plot_spec.colourmap, bad_color="lightgrey") + if plot_spec.colourbar_strategy == "separate": # For separate strategy, use explicit levels to prevent breathing - groundtruth_levels = np.linspace( + groundtruth_levels = _safe_linspace( groundtruth_vmin, groundtruth_vmax, plot_spec.n_contour_levels ) - prediction_levels = np.linspace( + prediction_levels = _safe_linspace( prediction_vmin, prediction_vmax, plot_spec.n_contour_levels ) image_groundtruth = axs[0].contourf( ground_truth, levels=groundtruth_levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=groundtruth_vmin, vmax=groundtruth_vmax, origin="lower", @@ -493,7 +556,7 @@ def _draw_main_panels( # noqa: PLR0913 image_prediction = axs[1].contourf( prediction, levels=prediction_levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=prediction_vmin, vmax=prediction_vmax, origin="lower", @@ -503,7 +566,7 @@ def _draw_main_panels( # noqa: PLR0913 image_groundtruth = axs[0].contourf( ground_truth, levels=levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=groundtruth_vmin, 
vmax=groundtruth_vmax, origin="lower", @@ -511,7 +574,7 @@ def _draw_main_panels( # noqa: PLR0913 image_prediction = axs[1].contourf( prediction, levels=levels, - cmap=plot_spec.colourmap, + cmap=cmap, vmin=prediction_vmin, vmax=prediction_vmax, origin="lower", @@ -603,16 +666,21 @@ def _draw_frame( # noqa: PLR0913 error_msg = "diff_colour_scale must be provided when including difference" raise InvalidArrayError(error_msg) + # Create colourmap with bad color handling for NaN values + diff_cmap = colourmap_with_bad(diff_colour_scale.cmap, bad_color="lightgrey") + if diff_colour_scale.norm is not None: # Signed differences with TwoSlopeNorm - use explicit levels to ensure consistency diff_vmin = diff_colour_scale.norm.vmin or 0.0 diff_vmax = diff_colour_scale.norm.vmax or 1.0 - diff_levels = np.linspace(diff_vmin, diff_vmax, plot_spec.n_contour_levels) + diff_levels = _safe_linspace( + diff_vmin, diff_vmax, plot_spec.n_contour_levels + ) image_difference = axs[2].contourf( difference, levels=diff_levels, - cmap=diff_colour_scale.cmap, + cmap=diff_cmap, vmin=diff_vmin, vmax=diff_vmax, origin="lower", @@ -621,7 +689,7 @@ def _draw_frame( # noqa: PLR0913 # Non-negative differences with vmin/vmax vmin = diff_colour_scale.vmin or 0.0 vmax = diff_colour_scale.vmax or 1.0 - diff_levels = np.linspace( + diff_levels = _safe_linspace( vmin, vmax, plot_spec.n_contour_levels, @@ -630,7 +698,7 @@ def _draw_frame( # noqa: PLR0913 image_difference = axs[2].contourf( difference, levels=diff_levels, - cmap=diff_colour_scale.cmap, + cmap=diff_cmap, vmin=vmin, vmax=vmax, origin="lower", @@ -652,7 +720,7 @@ def _overlay_nans(ax: Axes, arr: np.ndarray, land_color: str = "white") -> None: if np.isnan(arr).any(): # Create overlay for NaN areas (land mask) nan_mask = np.isnan(arr).astype(float) - # Create a custom colormap: 0=transparent, 1=land color + # Create a custom colourmap: 0=transparent, 1=land color # Land colour options: typically 'white' or 'black' (or any valid Matplotlib color) colors = ["white", land_color] # 0=white (transparent), 1=land color @@ -691,70 +759,6 @@ def _clear_plot(ax: Axes) -> None: ax.set_title("") -def _set_suptitle_with_box(fig: plt.Figure, text: str) -> Text: - """Draw a fixed-position title with a white box that doesn't influence layout. - - Returns the Text artist so callers can update with set_text during animation. - This version avoids kwargs that are unsupported on older Matplotlib. - """ - bbox = {"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0} - t = fig.text( - x=0.5, - y=0.98, - s=text, - ha="center", - va="top", - fontsize=12, - fontfamily="monospace", - transform=fig.transFigure, - bbox=bbox, - ) - with contextlib.suppress(Exception): - t.set_zorder(1000) - return t - - -def _set_footer_with_box(fig: plt.Figure, text: str) -> Text: - """Draw a fixed-position footer with a white box at bottom center. - - Footer is intended for metadata and secondary information. 
- """ - bbox = {"facecolor": "white", "edgecolor": "none", "pad": 2.0, "alpha": 1.0} - t = fig.text( - x=0.5, - y=0.03, - s=text, - ha="center", - va="bottom", - fontsize=11, - fontfamily="monospace", - transform=fig.transFigure, - bbox=bbox, - ) - with contextlib.suppress(Exception): - t.set_zorder(1000) - return t - - -def _draw_badge_with_box(fig: plt.Figure, x: float, y: float, text: str) -> Text: - """Draw a warning/info badge with white background box at figure coords.""" - bbox = {"facecolor": "white", "edgecolor": "none", "pad": 1.5, "alpha": 1.0} - t = fig.text( - x=x, - y=y, - s=text, - fontsize=11, - fontfamily="monospace", - color="firebrick", - ha="center", - va="top", - bbox=bbox, - ) - with contextlib.suppress(Exception): - t.set_zorder(1000) - return t - - # --- Title helpers --- def _formatted_variable_name(variable: str) -> str: """Return a human-friendly variable name for titles. diff --git a/ice_station_zebra/visualisations/plotting_raw_inputs.py b/ice_station_zebra/visualisations/plotting_raw_inputs.py new file mode 100644 index 00000000..eae2dc1c --- /dev/null +++ b/ice_station_zebra/visualisations/plotting_raw_inputs.py @@ -0,0 +1,510 @@ +"""Raw input variable visualisation for a single timestep or animation over time. + +Creates one static map per input channel with optional per-variable style overrides, +or animations showing temporal evolution of individual variables. + +This module is intentionally lightweight and single-panel oriented while reusing the +layout and formatting helpers from the sea-ice plotting stack to keep appearances aligned. +""" + +from __future__ import annotations + +import logging +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Literal + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import animation +from matplotlib.colors import TwoSlopeNorm + +from ice_station_zebra.exceptions import InvalidArrayError +from ice_station_zebra.visualisations.animation_helper import hold_anim, release_anim +from ice_station_zebra.visualisations.layout import ( + build_single_panel_figure, + format_linear_ticks, + format_symmetric_ticks, + set_suptitle_with_box, +) +from ice_station_zebra.visualisations.plotting_core import ( + VariableStyle, + colourmap_with_bad, + create_normalisation, + load_land_mask, + safe_filename, + save_figure, + style_for_variable, +) + +if TYPE_CHECKING: + import io + from collections.abc import Sequence + from pathlib import Path + + from matplotlib.text import Text + from PIL.ImageFile import ImageFile + + from ice_station_zebra.types import PlotSpec + +logger = logging.getLogger(__name__) + + +def _format_title( + variable: str, hemisphere: str | None, when: date | datetime, units: str | None +) -> str: + """Format a title string for a raw input variable plot. + + Args: + variable: Variable name. + hemisphere: Hemisphere ("north" or "south"), if applicable. + when: Date or datetime of the data. + units: Display units for the variable. + + Returns: + Formatted title string. 
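For example, following the conventions of the test fixtures (a hedged illustration rather than a doctest shipped in the patch):

```python
from datetime import date

from ice_station_zebra.visualisations.plotting_raw_inputs import _format_title

_format_title("era5:2t", "south", date(2020, 1, 15), "K")
# -> 'era5:2t [K] (South) Shown: 2020-01-15'
```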
+ + """ + hemi = f" ({hemisphere.capitalize()})" if hemisphere else "" + units_s = f" [{units}]" if units else "" + shown = when.date().isoformat() if isinstance(when, datetime) else when.isoformat() + return f"{variable}{units_s}{hemi} Shown: {shown}" + + +def plot_raw_inputs_for_timestep( # noqa: PLR0913, C901, PLR0912, PLR0915 + *, + channel_arrays: list[np.ndarray], + channel_names: list[str], + when: date | datetime, + plot_spec_base: PlotSpec, + land_mask: np.ndarray | None = None, + styles: dict[str, dict[str, Any]] | None = None, + save_dir: Path | None = None, +) -> list[tuple[str, ImageFile, Path | None]]: + """Plot one image per input channel as a static map. + + Returns list of (name, PIL.ImageFile, saved_path|None). + """ + if len(channel_arrays) != len(channel_names): + msg = ( + f"Channels count mismatch: arrays={len(channel_arrays)} " + f"names={len(channel_names)}" + ) + raise InvalidArrayError(msg) + + # Ensure we pick up styles either from the explicit argument or from the PlotSpec + if styles is None: + # some callers may embed variable_styles inside the plot_spec dataclass + styles = getattr(plot_spec_base, "variable_styles", None) + + results: list[tuple[str, ImageFile, Path | None]] = [] + land_mask_cache: dict[tuple[int, int], np.ndarray | None] = {} + + from . import convert # noqa: PLC0415 # local import to avoid circulars + + expected_ndim = 2 + for arr, name in zip(channel_arrays, channel_names, strict=True): + if arr.ndim != expected_ndim: + msg = f"Expected 2D [H,W] for channel '{name}', got {arr.shape}" + raise InvalidArrayError(msg) + + # Apply land mask (mask out land to NaN) + active_mask = land_mask + if active_mask is None and plot_spec_base.land_mask_path: + shape = arr.shape + if shape not in land_mask_cache: + try: + land_mask_cache[shape] = load_land_mask( + plot_spec_base.land_mask_path, shape + ) + except InvalidArrayError: + logger.exception( + "Failed to load land mask for '%s' with shape %s", name, shape + ) + land_mask_cache[shape] = None + active_mask = land_mask_cache.get(shape) + + # Apply mask if available, creating a new variable to avoid overwriting loop variable + arr_to_plot = arr + if active_mask is not None: + if active_mask.shape == arr.shape: + arr_to_plot = np.where(active_mask, np.nan, arr) + else: + logger.debug( + "Skipping land mask for '%s': mask shape %s != array shape %s", + name, + active_mask.shape, + arr.shape, + ) + + # Prefer an explicit styles dict, otherwise fall back to the PlotSpec attribute if present + styles_effective = ( + styles + if styles is not None + else getattr(plot_spec_base, "variable_styles", None) + ) + logger.debug( + "plot_raw_inputs: incoming name=%r; using styles keys=%s", + name, + list(styles_effective.keys()) if styles_effective else None, + ) + style = style_for_variable(name, styles_effective) + logger.debug( + "plot_raw_inputs: resolved style for %r => cmap=%r, vmin=%r, vmax=%r, origin=%r", + name, + style.cmap, + style.vmin, + style.vmax, + style.origin, + ) + + # Create normalisation using shared function + norm, vmin, vmax = create_normalisation( + arr_to_plot, + vmin=style.vmin, + vmax=style.vmax, + centre=style.two_slope_centre, + ) + + # Build figure and axis + height, width = arr_to_plot.shape + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=plot_spec_base.colourbar_location, + ) + + # Render + cmap_name = style.cmap or plot_spec_base.colourmap + cmap = colourmap_with_bad(cmap_name, bad_color="lightgrey") + origin = style.origin or 
"lower" + image = ax.imshow( + arr_to_plot, cmap=cmap, norm=norm, origin=origin, interpolation="nearest" + ) + + # Colourbar + orientation = plot_spec_base.colourbar_location + cbar = fig.colorbar(image, ax=ax, cax=cax, orientation=orientation) + is_vertical = orientation == "vertical" + decimals = style.decimals if style.decimals is not None else 2 + use_scientific = ( + style.use_scientific_notation + if style.use_scientific_notation is not None + else False + ) + if isinstance(norm, TwoSlopeNorm): + format_symmetric_ticks( + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + centre=norm.vcenter, + use_scientific_notation=use_scientific, + ) + else: + format_linear_ticks( + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + use_scientific_notation=use_scientific, + ) + + # Title + try: + set_suptitle_with_box( + fig, + _format_title(name, plot_spec_base.hemisphere, when, style.units), + ) + except (ValueError, AttributeError, RuntimeError) as err: + logger.debug( + "Failed to draw raw-inputs title: %s; continuing without title.", err + ) + + # Save to disk if requested (colons in variable names replaced) + file_base = name.replace(":", "__") + saved_path = save_figure(fig, save_dir, file_base) + + # Convert to PIL + try: + pil_img = convert.image_from_figure(fig) + finally: + plt.close(fig) + + results.append((name, pil_img, saved_path)) + + return results + + +def video_raw_input_for_variable( # noqa: PLR0913, C901, PLR0915 + *, + variable_name: str, + data_stream: np.ndarray, + dates: Sequence[date | datetime], + plot_spec_base: PlotSpec, + land_mask: np.ndarray | None = None, + style: VariableStyle | None = None, + styles: dict[str, dict[str, Any]] | None = None, + fps: int = 2, + video_format: Literal["mp4", "gif"] = "gif", + save_path: Path | None = None, +) -> io.BytesIO: + """Create animation showing temporal evolution of a single variable. + + This function creates a video animation of a single input variable over time, + with consistent colour scaling across all frames to avoid visual "breathing". + + Args: + variable_name: Name of the variable (e.g., "era5:2t", "osisaf-south:ice_conc"). + data_stream: 3D array of data over time [T, H, W]. + dates: Sequence of dates corresponding to each timestep (length must match T). + plot_spec_base: Base plotting specification (colourmap, hemisphere, etc.). + land_mask: Optional 2D boolean array marking land areas [H, W]. + style: Optional pre-computed VariableStyle for this variable. + styles: Optional dictionary of style configurations for matching. + fps: Frames per second for the output video. + video_format: Output format, either "mp4" or "gif". + save_path: Optional path to save the video file to disk. + + Returns: + BytesIO buffer containing the encoded video data. + + Raises: + InvalidArrayError: If data_stream is not 3D or dates length mismatches. + VideoRenderError: If video encoding fails. + + """ + from . 
import convert # noqa: PLC0415 # Local import to avoid circulars + + # Validate input + if data_stream.ndim != 3: # noqa: PLR2004 + msg = f"Expected 3D data [T,H,W], got shape {data_stream.shape}" + raise InvalidArrayError(msg) + + n_timesteps, height, width = data_stream.shape + if len(dates) != n_timesteps: + msg = f"Number of dates ({len(dates)}) != number of timesteps ({n_timesteps})" + raise InvalidArrayError(msg) + + # Apply land mask if provided + if land_mask is not None: + if land_mask.shape != (height, width): + logger.debug( + "Land mask shape %s doesn't match data shape (%d, %d), skipping mask", + land_mask.shape, + height, + width, + ) + else: + # Mask out land areas by setting them to NaN + data_stream = np.where(land_mask, np.nan, data_stream) + + # Get styling for this variable + if style is None: + style = style_for_variable(variable_name, styles) + + # Create stable normalisation across all frames + # Use global min/max to prevent colour scale "breathing" + data_min = float(np.nanmin(data_stream)) if np.isfinite(data_stream).any() else 0.0 + data_max = float(np.nanmax(data_stream)) if np.isfinite(data_stream).any() else 1.0 + + # Override with style limits if provided + effective_vmin = style.vmin if style.vmin is not None else data_min + effective_vmax = style.vmax if style.vmax is not None else data_max + + # Create normalisation using first frame to establish type + norm, vmin, vmax = create_normalisation( + data_stream[0], + vmin=effective_vmin, + vmax=effective_vmax, + centre=style.two_slope_centre, + ) + + # Build figure with single panel + colourbar + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=plot_spec_base.colourbar_location, + ) + + # Render initial frame + cmap_name = style.cmap or plot_spec_base.colourmap + cmap = colourmap_with_bad(cmap_name, bad_color="lightgrey") + origin = style.origin or "lower" + image = ax.imshow( + data_stream[0], cmap=cmap, norm=norm, origin=origin, interpolation="nearest" + ) + + # Create colourbar (stable across all frames) + orientation = plot_spec_base.colourbar_location + cbar = fig.colorbar(image, ax=ax, cax=cax, orientation=orientation) + is_vertical = orientation == "vertical" + + # Format colourbar ticks based on normalisation type + decimals = style.decimals if style.decimals is not None else 2 + use_scientific = ( + style.use_scientific_notation + if style.use_scientific_notation is not None + else False + ) + if isinstance(norm, TwoSlopeNorm): + format_symmetric_ticks( + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + centre=norm.vcenter, + use_scientific_notation=use_scientific, + ) + else: + format_linear_ticks( + cbar, + vmin=vmin, + vmax=vmax, + decimals=decimals, + is_vertical=is_vertical, + use_scientific_notation=use_scientific, + ) + + # Create title (will be updated each frame) + title_text: Text | None = None + try: + title_text = set_suptitle_with_box( + fig, + _format_title( + variable_name, plot_spec_base.hemisphere, dates[0], style.units + ), + ) + except (ValueError, AttributeError, RuntimeError) as err: + logger.debug("Failed to draw title: %s; continuing without title.", err) + + # Animation update function + def animate(tt: int) -> tuple[()]: + """Update function for each frame of the animation.""" + # Update image data + image.set_data(data_stream[tt]) + + # Update title with current date + if title_text is not None: + title_text.set_text( + _format_title( + variable_name, plot_spec_base.hemisphere, dates[tt], 
style.units + ) + ) + + return () + + # Create animation object + anim = animation.FuncAnimation( + fig, + animate, + frames=n_timesteps, + interval=1000 // fps, + blit=False, + repeat=True, + ) + + # Keep strong reference to prevent garbage collection during save + hold_anim(anim) + + try: + # Save to BytesIO buffer + video_buffer = convert.save_animation(anim, fps=fps, video_format=video_format) + + # Optionally save to disk + if save_path is not None: + save_path.parent.mkdir(parents=True, exist_ok=True) + save_path.write_bytes(video_buffer.getvalue()) + logger.debug("Saved animation to %s", save_path) + # Reset buffer position after writing + video_buffer.seek(0) + + return video_buffer + + finally: + # Clean up: remove from cache and close figure + release_anim(anim) + plt.close(fig) + + +def video_raw_inputs_for_timesteps( # noqa: PLR0913 + *, + channel_arrays_stream: list[np.ndarray], + channel_names: list[str], + dates: Sequence[date | datetime], + plot_spec_base: PlotSpec, + land_mask: np.ndarray | None = None, + styles: dict[str, dict[str, Any]] | None = None, + fps: int = 2, + video_format: Literal["mp4", "gif"] = "gif", + save_dir: Path | None = None, +) -> list[tuple[str, io.BytesIO, Path | None]]: + """Create animations for multiple input variables over time. + + This is a convenience wrapper around `video_raw_input_for_variable()` that + processes multiple variables in a batch, applying consistent styling and + land masking to all variables. + + Args: + channel_arrays_stream: List of 3D arrays, one per variable [T, H, W]. + channel_names: List of variable names corresponding to each array. + dates: Sequence of dates for each timestep (length must match T). + plot_spec_base: Base plotting specification. + land_mask: Optional 2D land mask to apply to all variables. + styles: Optional dictionary of style configurations. + fps: Frames per second for videos. + video_format: Output format ("mp4" or "gif"). + save_dir: Optional directory to save videos to disk. + + Returns: + List of tuples (variable_name, video_buffer, saved_path). + + Raises: + InvalidArrayError: If arrays/names count mismatch or invalid shapes. 
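A hedged usage sketch of the batch helper (shapes, names, and the output directory are hypothetical; this mirrors how the evaluation callback is expected to drive it):

```python
from datetime import date, timedelta
from pathlib import Path

import numpy as np

from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC
from ice_station_zebra.visualisations.plotting_raw_inputs import (
    video_raw_inputs_for_timesteps,
)

rng = np.random.default_rng(0)
dates = [date(2020, 1, 15) + timedelta(days=i) for i in range(4)]
streams = [
    rng.random((4, 48, 48)),            # ice concentration in [0, 1], shape [T, H, W]
    rng.normal(0.0, 5.0, (4, 48, 48)),  # u-wind in m/s, shape [T, H, W]
]
names = ["osisaf-south:ice_conc", "era5:10u"]

results = video_raw_inputs_for_timesteps(
    channel_arrays_stream=streams,
    channel_names=names,
    dates=dates,
    plot_spec_base=DEFAULT_SIC_SPEC,
    styles={"era5:10u": {"cmap": "RdBu_r", "two_slope_centre": 0.0, "units": "m/s"}},
    fps=2,
    video_format="gif",
    save_dir=Path("outputs/raw_input_animations"),  # hypothetical directory
)
for variable_name, video_buffer, saved_path in results:
    print(variable_name, saved_path, len(video_buffer.getbuffer()))
```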
+ + """ + if len(channel_arrays_stream) != len(channel_names): + msg = ( + f"Channel count mismatch: {len(channel_arrays_stream)} arrays, " + f"{len(channel_names)} names" + ) + raise InvalidArrayError(msg) + + results: list[tuple[str, io.BytesIO, Path | None]] = [] + + for data_stream, var_name in zip(channel_arrays_stream, channel_names, strict=True): + logger.debug("Creating animation for variable: %s", var_name) + + # Determine save path if save_dir is provided + save_path: Path | None = None + if save_dir is not None: + # Sanitise variable name for filename + file_base = var_name.replace(":", "__") + suffix = ".gif" if video_format == "gif" else ".mp4" + save_path = save_dir / f"{safe_filename(file_base)}{suffix}" + + try: + # Create animation for this variable + video_buffer = video_raw_input_for_variable( + variable_name=var_name, + data_stream=data_stream, + dates=dates, + plot_spec_base=plot_spec_base, + land_mask=land_mask, + styles=styles, + fps=fps, + video_format=video_format, + save_path=save_path, + ) + + results.append((var_name, video_buffer, save_path)) + + except (InvalidArrayError, ValueError, MemoryError, OSError): + logger.exception( + "Failed to create animation for variable %s, skipping", var_name + ) + continue + + return results diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index 11d1ef6b..f31895b2 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -1,10 +1,12 @@ import warnings +from collections.abc import Iterator from datetime import date, timedelta from pathlib import Path from typing import Any, Protocol, cast import hydra import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np import pytest import torch @@ -12,6 +14,7 @@ from omegaconf import errors as oc_errors from ice_station_zebra.data_loaders import ZebraDataModule +from ice_station_zebra.types import PlotSpec from tests.conftest import make_varying_sic_stream mpl.use("Agg") @@ -25,6 +28,15 @@ ) TEST_DATE = date(2020, 1, 15) +TEST_HEIGHT = 48 +TEST_WIDTH = 48 + + +@pytest.fixture(autouse=True) +def close_all_figures() -> Iterator[None]: + """Automatically close all matplotlib figures after each test to prevent warnings.""" + yield + plt.close("all") @pytest.fixture @@ -259,3 +271,114 @@ def example_checkpoint_path(pytestconfig: pytest.Config) -> Path | None: return ckpt.resolve() return None + + +# --- Raw Inputs Fixtures --- + + +@pytest.fixture +def base_plot_spec() -> PlotSpec: + """Base plotting specification for raw inputs.""" + return PlotSpec( + variable="raw_inputs", + colourmap="viridis", + colourbar_location="vertical", + hemisphere="south", + ) + + +@pytest.fixture +def test_dates_short() -> list[date]: + """Generate a short sequence of test dates for animations (4 days).""" + return [TEST_DATE + timedelta(days=i) for i in range(4)] + + +@pytest.fixture +def land_mask_2d() -> np.ndarray: + """Create a simple circular land mask for testing [H, W].""" + dist = make_central_distance_grid(TEST_HEIGHT, TEST_WIDTH) + radius = min(TEST_HEIGHT, TEST_WIDTH) * 0.25 + return (dist < radius).astype(bool) + + +@pytest.fixture +def era5_temperature_2d() -> np.ndarray: + """Generate synthetic ERA5 2m temperature data (K) [H, W].""" + rng = np.random.default_rng(100) + # Temperature centreed around 273.15K (0°C) with realistic variation + base_temp = 273.15 + rng.normal(0, 10, size=(TEST_HEIGHT, TEST_WIDTH)) + return base_temp.astype(np.float32) + + +@pytest.fixture +def era5_humidity_2d() -> np.ndarray: + """Generate synthetic ERA5 
specific humidity data (kg/kg) [H, W].""" + rng = np.random.default_rng(100) + # Humidity values are very small (0.001 to 0.01) + humidity = rng.uniform(0.0005, 0.015, size=(TEST_HEIGHT, TEST_WIDTH)) + return humidity.astype(np.float32) + + +@pytest.fixture +def era5_wind_u_2d() -> np.ndarray: + """Generate synthetic ERA5 u-wind component (m/s) [H, W].""" + rng = np.random.default_rng(100) + # Wind centreed around 0 with realistic variation + wind = rng.normal(0, 5, size=(TEST_HEIGHT, TEST_WIDTH)) + return wind.astype(np.float32) + + +@pytest.fixture +def osisaf_ice_conc_2d() -> np.ndarray: + """Generate synthetic OSISAF sea ice concentration data (fraction 0-1) [H, W].""" + rng = np.random.default_rng(100) + # Ice concentration between 0 and 1 + ice_conc = rng.uniform(0.0, 1.0, size=(TEST_HEIGHT, TEST_WIDTH)) + return ice_conc.astype(np.float32) + + +@pytest.fixture +def era5_temperature_3d(test_dates_short: list[date]) -> np.ndarray: + """Generate synthetic 3D temperature stream [T, H, W].""" + rng = np.random.default_rng(100) + n_timesteps = len(test_dates_short) + # Temperature evolving over time + data = np.zeros((n_timesteps, TEST_HEIGHT, TEST_WIDTH), dtype=np.float32) + for t in range(n_timesteps): + data[t] = 273.15 + rng.normal(0, 10, size=(TEST_HEIGHT, TEST_WIDTH)) + return data + + +@pytest.fixture +def multi_channel_data() -> tuple[list[np.ndarray], list[str]]: + """Generate multiple channels of raw input data.""" + rng = np.random.default_rng(100) + channels = [ + rng.uniform(270, 280, size=(TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), # temperature + rng.normal(0, 5, size=(TEST_HEIGHT, TEST_WIDTH)).astype(np.float32), # u-wind + rng.normal(0, 5, size=(TEST_HEIGHT, TEST_WIDTH)).astype(np.float32), # v-wind + rng.uniform(0, 1, size=(TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), # ice conc + ] + names = ["era5:2t", "era5:10u", "era5:10v", "osisaf-south:ice_conc"] + return channels, names + + +@pytest.fixture +def variable_styles() -> dict[str, dict[str, Any]]: + """Sample variable styling configuration for raw inputs.""" + return { + "era5:2t": { + "cmap": "RdBu_r", + "two_slope_centre": 273.15, + "units": "K", + "decimals": 1, + }, + "era5:10u": {"cmap": "RdBu_r", "two_slope_centre": 0.0, "units": "m/s"}, + "era5:10v": {"cmap": "RdBu_r", "two_slope_centre": 0.0, "units": "m/s"}, + "era5:q_10": {"cmap": "viridis", "decimals": 4, "units": "kg/kg"}, + "osisaf-south:ice_conc": {"cmap": "Blues_r"}, + } diff --git a/tests/plotting/test_api.py b/tests/plotting/test_api.py index 5e5fea66..b670902b 100644 --- a/tests/plotting/test_api.py +++ b/tests/plotting/test_api.py @@ -88,14 +88,18 @@ def test_plot_maps_emits_warning_badge( @pytest.fixture def fake_save_animation(monkeypatch: pytest.MonkeyPatch) -> Callable[..., io.BytesIO]: - """Monkeypatch _save_animation so video_maps runs fast in tests.""" + """Monkeypatch save_animation so video_maps runs fast in tests.""" def _fake_save( - _anim: object, *, _fps: int, _video_format: str = "gif" + _anim: object, + *, + fps: int = 2, # noqa: ARG001 + video_format: str = "gif", # noqa: ARG001 ) -> io.BytesIO: + # Parameters match real save_animation signature but are unused in fake return io.BytesIO(b"fake-video-data") - monkeypatch.setattr(convert, "_save_animation", _fake_save) + monkeypatch.setattr(convert, "save_animation", _fake_save) return _fake_save @@ -134,10 +138,10 @@ def test_plot_maps_with_land_mask( ground_truth, prediction, date = sic_pair_2d height, width = ground_truth.shape - # Create a simple land mask with land 
in the center + # Create a simple land mask with land in the centre land_mask = np.zeros((height, width), dtype=bool) - center_h, center_w = height // 2, width // 2 - land_mask[center_h - 5 : center_h + 5, center_w - 5 : center_w + 5] = True + centre_h, centre_w = height // 2, width // 2 + land_mask[centre_h - 5 : centre_h + 5, centre_w - 5 : centre_w + 5] = True # Save land mask to temporary file with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file: diff --git a/tests/plotting/test_layout.py b/tests/plotting/test_layout.py index 2ebc823f..50ba3760 100644 --- a/tests/plotting/test_layout.py +++ b/tests/plotting/test_layout.py @@ -9,13 +9,15 @@ import matplotlib.pyplot as plt import pytest -from ice_station_zebra.visualisations.layout import _build_layout, _set_axes_limits -from ice_station_zebra.visualisations.plotting_maps import ( - DEFAULT_SIC_SPEC, - _draw_badge_with_box, - _set_footer_with_box, - _set_suptitle_with_box, +from ice_station_zebra.visualisations.layout import ( + _set_axes_limits, + build_layout, + build_single_panel_figure, + draw_badge_with_box, + set_footer_with_box, + set_suptitle_with_box, ) +from ice_station_zebra.visualisations.plotting_maps import DEFAULT_SIC_SPEC from .test_helper_plot_layout import axis_rectangle, rectangles_overlap @@ -59,7 +61,7 @@ def test_no_axes_overlap( colourbar_strategy=colourbar_strategy, # type: ignore[arg-type] ) - _, axes, colourbar_axes = _build_layout( + _, axes, colourbar_axes = build_layout( plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] ) @@ -96,7 +98,7 @@ def test_axes_have_reasonable_gaps( colourbar_strategy=colourbar_strategy, # type: ignore[arg-type] ) - _, axes, colourbar_axes = _build_layout( + _, axes, colourbar_axes = build_layout( plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] ) @@ -201,15 +203,15 @@ def test_figure_text_boxes_do_not_overlap( include_difference=include_difference, ) - fig, axes, caxes = _build_layout( + fig, axes, caxes = build_layout( plot_spec=spec, height=ground_truth.shape[0], width=ground_truth.shape[1] ) # Add figure-level title, warning badge (synthetic), and footer - title = _set_suptitle_with_box(fig, "Title") + title = set_suptitle_with_box(fig, "Title") ty = title.get_position()[1] - badge = _draw_badge_with_box(fig, 0.5, max(ty - 0.05, 0.0), "Warnings: example") - footer = _set_footer_with_box(fig, "Footer metadata") + badge = draw_badge_with_box(fig, 0.5, max(ty - 0.05, 0.0), "Warnings: example") + footer = set_footer_with_box(fig, "Footer metadata") # Collect rectangles: panels, colourbar axes, and figure texts rectangles = [axis_rectangle(ax) for ax in axes] @@ -223,3 +225,135 @@ def test_figure_text_boxes_do_not_overlap( assert not rectangles_overlap(rect_a, rect_b), ( f"Found overlap between rectangles {rect_a} and {rect_b}" ) + + +# --- Single Panel Layout Tests --- + + +@pytest.mark.parametrize("colourbar_location", ["horizontal", "vertical"]) +def test_single_panel_no_overlap( + era5_temperature_2d: np.ndarray, + *, + colourbar_location: str, +) -> None: + """Single panel layout: main panel and colourbar must not overlap.""" + height, width = era5_temperature_2d.shape + + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=colourbar_location, # type: ignore[arg-type] + ) + + # Get rectangles for panel and colourbar + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + + # They should not overlap + assert not rectangles_overlap(panel_rect, cbar_rect), ( 
+ f"Panel and colourbar overlap: panel={panel_rect}, cbar={cbar_rect}" + ) + + plt.close(fig) + + +@pytest.mark.parametrize("colourbar_location", ["horizontal", "vertical"]) +def test_single_panel_has_reasonable_gap( + era5_temperature_2d: np.ndarray, + *, + colourbar_location: str, +) -> None: + """Single panel layout: require a minimum gap between panel and colourbar.""" + height, width = era5_temperature_2d.shape + + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location=colourbar_location, # type: ignore[arg-type] + ) + + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + + pl, pb, pr, pt = panel_rect + cl, cb, cr, ct = cbar_rect + + if colourbar_location == "vertical": + # Vertical colorbar should be to the right of panel + gap = cl - pr + assert gap >= 0.005, ( + f"Expected ≥0.5% gap between panel and vertical colorbar, got {gap:.5f}" + ) + else: # horizontal + # Horizontal colorbar should be below panel + gap = pb - ct + assert gap >= 0.005, ( + f"Expected ≥0.5% gap between panel and horizontal colorbar, got {gap:.5f}" + ) + + plt.close(fig) + + +def test_single_panel_with_text_annotations( + era5_temperature_2d: np.ndarray, +) -> None: + """Single panel: title and annotations should not overlap panel or colourbar.""" + height, width = era5_temperature_2d.shape + + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location="vertical", + ) + + # Add text annotations + title = set_suptitle_with_box(fig, "Test Title") + footer = set_footer_with_box(fig, "Test Footer") + + # Collect all rectangles + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + title_rect = _text_rectangle(fig, title) + footer_rect = _text_rectangle(fig, footer) + + rectangles = [panel_rect, cbar_rect, title_rect, footer_rect] + + # No overlaps + for rect_a, rect_b in combinations(rectangles, 2): + assert not rectangles_overlap(rect_a, rect_b), ( + f"Found overlap between {rect_a} and {rect_b}" + ) + + plt.close(fig) + + +@pytest.mark.parametrize( + ("height", "width"), + [ + (48, 48), # Square + (181, 720), # Wide (ERA5-like) + (432, 432), # Square (OSISAF-like) + (100, 200), # Wide + (200, 100), # Tall + ], +) +def test_single_panel_various_aspect_ratios( + height: int, + width: int, +) -> None: + """Single panel layout should handle various aspect ratios without overlap.""" + fig, ax, cax = build_single_panel_figure( + height=height, + width=width, + colourbar_location="vertical", + ) + + panel_rect = axis_rectangle(ax) + cbar_rect = axis_rectangle(cax) + + # No overlap + assert not rectangles_overlap(panel_rect, cbar_rect), ( + f"Overlap for {height}x{width}: panel={panel_rect}, cbar={cbar_rect}" + ) + + plt.close(fig) diff --git a/tests/plotting/test_raw_inputs.py b/tests/plotting/test_raw_inputs.py new file mode 100644 index 00000000..eefcd2a0 --- /dev/null +++ b/tests/plotting/test_raw_inputs.py @@ -0,0 +1,569 @@ +"""Tests for raw input plotting functionality. + +This module tests both static plots and animations of raw input variables, +covering ERA5 weather data, OSISAF sea ice concentration, and various +styling configurations. 
+""" + +from __future__ import annotations + +import io +from datetime import date, timedelta +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import pytest +from PIL import Image + +from ice_station_zebra.exceptions import InvalidArrayError +from ice_station_zebra.types import PlotSpec +from ice_station_zebra.visualisations.plotting_core import style_for_variable +from ice_station_zebra.visualisations.plotting_raw_inputs import ( + plot_raw_inputs_for_timestep, + video_raw_input_for_variable, + video_raw_inputs_for_timesteps, +) + +# Import test constants from conftest +from .conftest import TEST_DATE, TEST_HEIGHT, TEST_WIDTH + +if TYPE_CHECKING: + from pathlib import Path + +# --- Tests for Static Plotting --- + + +def test_plot_single_channel_basic( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test basic single channel plotting.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + assert len(results) == 1 + name, pil_image, saved_path = results[0] + assert name == "era5:2t" + assert isinstance(pil_image, Image.Image) + assert saved_path is None # No save_dir provided + + +def test_plot_with_land_mask( + era5_temperature_2d: np.ndarray, + land_mask_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test plotting with land mask applied.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + land_mask=land_mask_2d, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == "era5:2t" + assert isinstance(pil_image, Image.Image) + + +def test_plot_with_custom_styles( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test plotting with custom variable styling.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=variable_styles, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == "era5:2t" + assert isinstance(pil_image, Image.Image) + + +def test_plot_multiple_channels( + multi_channel_data: tuple[list[np.ndarray], list[str]], + base_plot_spec: PlotSpec, +) -> None: + """Test plotting multiple channels at once.""" + channel_arrays, channel_names = multi_channel_data + + results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + assert len(results) == len(channel_names) + for (name, pil_image, _), expected_name in zip(results, channel_names, strict=True): + assert name == expected_name + assert isinstance(pil_image, Image.Image) + + +def test_plot_to_disk( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, + tmp_path: Path, +) -> None: + """Test saving plots to disk.""" + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + save_dir=tmp_path, + ) + + assert len(results) == 1 + _, _, saved_path = results[0] + assert saved_path is not None + assert saved_path.exists() + assert saved_path.suffix == ".png" + + +@pytest.mark.parametrize("colourbar_location", ["vertical", "horizontal"]) +def test_plot_colourbar_locations( + 
era5_temperature_2d: np.ndarray, + colourbar_location: Literal["vertical", "horizontal"], +) -> None: + """Test plotting with different colorbar orientations.""" + plot_spec = PlotSpec( + variable="raw_inputs", + colourmap="viridis", + colourbar_location=colourbar_location, + ) + + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=plot_spec, + ) + + assert len(results) == 1 + + +@pytest.mark.parametrize( + ("var_name", "fixture_name"), + [ + ("era5:2t", "era5_temperature_2d"), + ("era5:q_10", "era5_humidity_2d"), + ("era5:10u", "era5_wind_u_2d"), + ("osisaf-south:ice_conc", "osisaf_ice_conc_2d"), + ], +) +def test_plot_different_variables( + var_name: str, + fixture_name: str, + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], + request: pytest.FixtureRequest, +) -> None: + """Test plotting different types of variables with appropriate styling.""" + data = request.getfixturevalue(fixture_name) + + results = plot_raw_inputs_for_timestep( + channel_arrays=[data], + channel_names=[var_name], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=variable_styles, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == var_name + assert isinstance(pil_image, Image.Image) + + +# --- Tests for Style Resolution --- + + +def test_style_for_variable_exact_match( + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test exact variable name matching in styling.""" + style = style_for_variable("era5:2t", variable_styles) + + assert style.cmap == "RdBu_r" + assert style.two_slope_centre == 273.15 + assert style.units == "K" + assert style.decimals == 1 + + +def test_style_for_variable_wildcard_match( + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test wildcard pattern matching in styling.""" + # Add wildcard pattern + styles_with_wildcard = { + **variable_styles, + "era5:q_*": {"cmap": "viridis", "decimals": 4}, + } + + style = style_for_variable("era5:q_500", styles_with_wildcard) + + assert style.cmap == "viridis" + assert style.decimals == 4 + + +def test_style_for_variable_scientific_notation( + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test scientific notation option in styling.""" + # Add style with scientific notation + styles_with_scientific = { + **variable_styles, + "era5:q_10": { + "cmap": "viridis", + "decimals": 2, + "units": "kg/kg", + "use_scientific_notation": True, + }, + } + + style = style_for_variable("era5:q_10", styles_with_scientific) + + assert style.cmap == "viridis" + assert style.decimals == 2 + assert style.units == "kg/kg" + assert style.use_scientific_notation is True + + +def test_plot_with_scientific_notation( + era5_humidity_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test plotting with scientific notation enabled.""" + styles_with_scientific = { + "era5:q_10": { + "cmap": "viridis", + "decimals": 2, + "units": "kg/kg", + "use_scientific_notation": True, + }, + } + + results = plot_raw_inputs_for_timestep( + channel_arrays=[era5_humidity_2d], + channel_names=["era5:q_10"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=styles_with_scientific, + ) + + assert len(results) == 1 + name, pil_image, _ = results[0] + assert name == "era5:q_10" + assert isinstance(pil_image, Image.Image) + + +def test_style_for_variable_no_match() -> None: + """Test default styling when no match found.""" + style = style_for_variable("unknown:variable", {}) + + # Should use 
defaults + assert style.cmap is None + assert style.decimals is None + + +# --- Tests for Animations --- + + +def test_video_single_variable_basic( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, +) -> None: + """Test basic video creation for a single variable.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + assert len(video_buffer.read()) > 1000 # Reasonable GIF size + + +def test_video_with_land_mask( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + land_mask_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test video creation with land mask.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + land_mask=land_mask_2d, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + assert len(video_buffer.read()) > 1000 + + +def test_video_with_custom_style( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], +) -> None: + """Test video creation with custom styling.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + styles=variable_styles, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + assert len(video_buffer.read()) > 1000 + + +def test_video_save_to_disk( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, + tmp_path: Path, +) -> None: + """Test saving video to disk.""" + save_path = tmp_path / "test_animation.gif" + + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + save_path=save_path, + fps=2, + video_format="gif", + ) + + assert isinstance(video_buffer, io.BytesIO) + assert save_path.exists() + assert save_path.stat().st_size > 1000 + + +@pytest.mark.parametrize("video_format", ["gif", "mp4"]) +def test_video_formats( + era5_temperature_3d: np.ndarray, + test_dates_short: list[date], + base_plot_spec: PlotSpec, + video_format: Literal["gif", "mp4"], +) -> None: + """Test creating videos in different formats.""" + video_buffer = video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + fps=2, + video_format=video_format, + ) + + assert isinstance(video_buffer, io.BytesIO) + video_buffer.seek(0) + content = video_buffer.read() + assert len(content) > 1000 + + # Check file signature + if video_format == "gif": + assert content[:6] == b"GIF89a" # GIF header + elif video_format == "mp4": + # MP4 typically has ftyp box early + assert b"ftyp" in content[:100] + + +def test_video_multiple_variables( + test_dates_short: list[date], + base_plot_spec: PlotSpec, +) -> None: + """Test batch video creation for multiple variables.""" + rng = np.random.default_rng(100) + n_timesteps = len(test_dates_short) + + # Create data streams for multiple variables + channel_arrays_stream = [ + rng.uniform(270, 280, 
size=(n_timesteps, TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), + rng.normal(0, 5, size=(n_timesteps, TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), + rng.uniform(0, 1, size=(n_timesteps, TEST_HEIGHT, TEST_WIDTH)).astype( + np.float32 + ), + ] + channel_names = ["era5:2t", "era5:10u", "osisaf-south:ice_conc"] + + results = video_raw_inputs_for_timesteps( + channel_arrays_stream=channel_arrays_stream, + channel_names=channel_names, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + fps=2, + video_format="gif", + ) + + assert len(results) == 3 + for (name, video_buffer, saved_path), expected_name in zip( + results, channel_names, strict=True + ): + assert name == expected_name + assert isinstance(video_buffer, io.BytesIO) + assert saved_path is None # No save_dir provided + + +# --- Error Handling Tests --- + + +def test_plot_mismatched_arrays_and_names( + era5_temperature_2d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test error when arrays and names count mismatch.""" + with pytest.raises(InvalidArrayError, match="Channels count mismatch"): + plot_raw_inputs_for_timestep( + channel_arrays=[era5_temperature_2d, era5_temperature_2d], + channel_names=["era5:2t"], # Only one name for two arrays + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + +def test_plot_wrong_dimension( + base_plot_spec: PlotSpec, +) -> None: + """Test error when input array is not 2D.""" + rng = np.random.default_rng(42) + wrong_dim_array = rng.random((5, 5, 5)).astype(np.float32) # 3D instead of 2D + + with pytest.raises(InvalidArrayError, match="Expected 2D"): + plot_raw_inputs_for_timestep( + channel_arrays=[wrong_dim_array], + channel_names=["era5:2t"], + when=TEST_DATE, + plot_spec_base=base_plot_spec, + ) + + +def test_video_wrong_dimension( + test_dates_short: list[date], + base_plot_spec: PlotSpec, +) -> None: + """Test error when video data is not 3D.""" + rng = np.random.default_rng(42) + wrong_dim_array = rng.random((5, 5)).astype(np.float32) # 2D instead of 3D + + with pytest.raises(InvalidArrayError, match="Expected 3D"): + video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=wrong_dim_array, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + ) + + +def test_video_mismatched_dates( + era5_temperature_3d: np.ndarray, + base_plot_spec: PlotSpec, +) -> None: + """Test error when number of dates doesn't match timesteps.""" + wrong_dates = [ + TEST_DATE, + TEST_DATE + timedelta(days=1), + ] # Only 2 dates for 4 timesteps + + with pytest.raises( + InvalidArrayError, match="Number of dates.*!= number of timesteps" + ): + video_raw_input_for_variable( + variable_name="era5:2t", + data_stream=era5_temperature_3d, + dates=wrong_dates, + plot_spec_base=base_plot_spec, + ) + + +# --- Integration Tests --- + + +def test_full_workflow_static_and_video( + multi_channel_data: tuple[list[np.ndarray], list[str]], + test_dates_short: list[date], + base_plot_spec: PlotSpec, + variable_styles: dict[str, dict[str, Any]], + tmp_path: Path, +) -> None: + """Test complete workflow: static plots + videos for multiple variables.""" + channel_arrays, channel_names = multi_channel_data + + # 1. 
Create static plots + static_results = plot_raw_inputs_for_timestep( + channel_arrays=channel_arrays, + channel_names=channel_names, + when=TEST_DATE, + plot_spec_base=base_plot_spec, + styles=variable_styles, + save_dir=tmp_path / "static", + ) + + assert len(static_results) == len(channel_names) + for _, _, saved_path in static_results: + assert saved_path is not None + assert saved_path.exists() + + # 2. Create 3D data streams + rng = np.random.default_rng(100) + n_timesteps = len(test_dates_short) + channel_arrays_stream = [ + rng.uniform(arr.min(), arr.max(), size=(n_timesteps, *arr.shape)).astype( + np.float32 + ) + for arr in channel_arrays + ] + + # 3. Create videos + video_results = video_raw_inputs_for_timesteps( + channel_arrays_stream=channel_arrays_stream, + channel_names=channel_names, + dates=test_dates_short, + plot_spec_base=base_plot_spec, + styles=variable_styles, + fps=2, + video_format="gif", + save_dir=tmp_path / "videos", + ) + + assert len(video_results) == len(channel_names) + for _, video_buffer, saved_path in video_results: + assert isinstance(video_buffer, io.BytesIO) + assert saved_path is not None + assert saved_path.exists() + assert saved_path.suffix == ".gif"
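
A minimal usage sketch of the new raw-inputs plotting helpers, following the call pattern of `test_full_workflow_static_and_video`: the arrays, dates and output directory below are synthetic placeholders, while the import paths, argument names and return shapes come from the code added above.

```python
"""Sketch: static plots plus animations via the raw-inputs helpers (synthetic data)."""

from datetime import date, timedelta
from pathlib import Path

import numpy as np

from ice_station_zebra.types import PlotSpec
from ice_station_zebra.visualisations.plotting_raw_inputs import (
    plot_raw_inputs_for_timestep,
    video_raw_inputs_for_timesteps,
)

rng = np.random.default_rng(0)
dates = [date(2020, 1, 15) + timedelta(days=i) for i in range(4)]
channel_names = ["era5:2t", "osisaf-south:ice_conc"]

# One 2D field per channel for the static plot [H, W] ...
channel_arrays = [
    (273.15 + rng.normal(0, 10, size=(48, 48))).astype(np.float32),
    rng.uniform(0.0, 1.0, size=(48, 48)).astype(np.float32),
]
# ... and one 3D stream per channel for the animation [T, H, W]
channel_arrays_stream = [
    np.stack([arr + rng.normal(0, 0.1, size=arr.shape) for _ in dates]).astype(
        np.float32
    )
    for arr in channel_arrays
]

plot_spec = PlotSpec(
    variable="raw_inputs",
    colourmap="viridis",
    colourbar_location="vertical",
    hemisphere="south",
)
out_dir = Path("outputs/raw_inputs")  # placeholder output directory

# Static plots: one (name, PIL.Image, saved_path) tuple per channel
static_results = plot_raw_inputs_for_timestep(
    channel_arrays=channel_arrays,
    channel_names=channel_names,
    when=dates[0],
    plot_spec_base=plot_spec,
    save_dir=out_dir / "static",
)

# Animations: one (name, io.BytesIO, saved_path) tuple per channel
video_results = video_raw_inputs_for_timesteps(
    channel_arrays_stream=channel_arrays_stream,
    channel_names=channel_names,
    dates=dates,
    plot_spec_base=plot_spec,
    fps=2,
    video_format="gif",
    save_dir=out_dir / "videos",
)
```

The returned buffers can also be inspected or logged directly rather than read back from disk, e.g. checking a GIF buffer for the `GIF89a` header as the format tests above do.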