diff --git a/README.md b/README.md index 7a2d30d..31c6cd0 100755 --- a/README.md +++ b/README.md @@ -85,6 +85,13 @@ This package provides tools to read, process, and analyze several key solar and - SWIFT: `SWSWIFTEnsemble` - Combined: `read_solar_wind_from_multiple_models` +- **Plasmasphere Density Predictions**: + Reader utilities for PAGER plasmasphere density grids and model combined inputs. + - **Sources & Classes:** + - Density predictions: `PlasmaspherePredictionReader` + - Combined inputs: `PlasmasphereCombinedInputsReader` + - Density cube container: `PlasmasphereDensityCube` + Each index can be accessed via these dedicated reader classes, which handle downloading and read methods. See the code in `swvo/io` or API documentation for details on each index's implementation. ## Installation diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index c6c7a01..670934a 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -229,14 +229,13 @@ def __getattr__(self, name: str) -> NDArray[np.float64]: raise AttributeError(msg) def load(self, name_or_var: str | VariableEnum) -> None: - """ Load data into memory """ + """Load data into memory""" if isinstance(name_or_var, VariableEnum): getattr(self, name_or_var.var_name) else: getattr(self, name_or_var) - def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[str, Any]]: levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} sat_variable = None @@ -245,7 +244,7 @@ def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[st sat_variable = var break else: - dist = distance.levenshtein(name, var.var_name) # ty:ignore[possibly-missing-attribute] + dist = distance.levenshtein(name, var.var_name) if name.lower() in var.var_name.lower(): dist = 1 diff --git a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py index 896f2e7..7fa9f49 100644 --- 
a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py +++ b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py @@ -82,7 +82,7 @@ def create_RBSP_line_data( InstrumentEnum.MAGEIS, InstrumentEnum.REPT, ] - satellites = satellites or [SatelliteEnum.RBSPA, SatelliteEnum.RBSPB] + satellites = satellites or [SatelliteEnum.RBSPA, SatelliteEnum.RBSPB] # ty :ignore[invalid-assignment] # pass and check args if isinstance(data_server_path, str): @@ -92,7 +92,7 @@ def create_RBSP_line_data( if not isinstance(target_en, Iterable): target_en = [target_en] if not isinstance(satellites, Iterable) or isinstance(satellites, str): - satellites = [satellites] + satellites = [satellites] # ty :ignore[invalid-assignment] if isinstance(target_type, str): target_type = TargetType[target_type] @@ -102,7 +102,7 @@ def create_RBSP_line_data( result_arr = [] list_instruments_used = [] - for satellite in satellites: + for satellite in satellites: # ty :ignore[not-iterable] rbm_data: list[RBMDataSet] = [] for i, instrument in enumerate(instruments): diff --git a/swvo/io/RBMDataSet/utils.py b/swvo/io/RBMDataSet/utils.py index 4fb8585..540f477 100644 --- a/swvo/io/RBMDataSet/utils.py +++ b/swvo/io/RBMDataSet/utils.py @@ -31,12 +31,10 @@ def join_var(var1: NDArray[np.generic], var2: NDArray[np.generic]) -> NDArray[np def get_file_path_any_format(folder_path: Path, file_stem: str, preferred_ext: str) -> Path | None: """Get the file path for a given file stem and preferred extension.""" pattern = re.compile(fnmatch.translate(file_stem + ".*"), re.IGNORECASE) - - if not folder_path.exists(): - return None - - all_files = [p for p in folder_path.iterdir() if pattern.match(p.name)] - + try: + all_files = [p for p in folder_path.iterdir() if pattern.match(p.name)] + except FileNotFoundError: + all_files = [] if len(all_files) == 0: warnings.warn(f"File not found: {folder_path / (file_stem + '.*')}", stacklevel=2) @@ -120,10 +118,9 @@ def matlab2python(datenum: float | Iterable[float]) -> 
Iterable[datetime] | date datenum = pd.to_datetime(datenum - 719529, unit="D", origin=pd.Timestamp("1970-01-01")).to_pydatetime() # ty:ignore[unresolved-attribute] if isinstance(datenum, Iterable): - datenum = enforce_utc_timezone(list(datenum)) # ty:ignore[invalid-assignment] + datenum = enforce_utc_timezone(list(datenum)) # ty:ignore[no-matching-overload] datenum = [ # ty:ignore[invalid-assignment] - round_seconds(x) # ty:ignore[invalid-argument-type] - for x in datenum # ty:ignore[not-iterable] + round_seconds(x) for x in datenum ] else: datenum = round_seconds(enforce_utc_timezone(datenum)) # ty:ignore[invalid-assignment] diff --git a/swvo/io/omni/omni_high_res.py b/swvo/io/omni/omni_high_res.py index 3fc0afc..f0444ad 100644 --- a/swvo/io/omni/omni_high_res.py +++ b/swvo/io/omni/omni_high_res.py @@ -413,11 +413,11 @@ def _get_data_from_omni(self, start: datetime, end: datetime, cadence: int = 1) if cadence == 1: params = {"res": "min", "spacecraft": "omni_min"} payload.update(params) - payload.update(common_vars) + payload.update(common_vars) # ty: ignore[no-matching-overload] elif cadence == 5: params = {"res": "5min", "spacecraft": "omni_5min"} payload.update(params) - payload.update(common_vars) + payload.update(common_vars) # ty: ignore[no-matching-overload] else: msg = f"Invalid cadence: {cadence}. Only 1 or 5 minutes are supported." 
diff --git a/swvo/io/plasmasphere/__init__.py b/swvo/io/plasmasphere/__init__.py index ac00e11..f1b1861 100644 --- a/swvo/io/plasmasphere/__init__.py +++ b/swvo/io/plasmasphere/__init__.py @@ -1,5 +1,10 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences +# SPDX-FileCopyrightText: 2026 GFZ Helmholtz Centre for Geosciences +# SPDX-FileContributor: Sahil Jhawar # # SPDX-License-Identifier: Apache-2.0 -from swvo.io.plasmasphere import read_plasmasphere as read_plasmasphere +from swvo.io.plasmasphere.read_plasmasphere import PlasmasphereDensityCube as PlasmasphereDensityCube +from swvo.io.plasmasphere.read_plasmasphere import PlasmaspherePredictionReader as PlasmaspherePredictionReader +from swvo.io.plasmasphere.read_plasmasphere_combined_inputs import ( + PlasmasphereCombinedInputsReader as PlasmasphereCombinedInputsReader, +) diff --git a/swvo/io/plasmasphere/read_plasmasphere.py b/swvo/io/plasmasphere/read_plasmasphere.py index a4a17d8..0d89f37 100644 --- a/swvo/io/plasmasphere/read_plasmasphere.py +++ b/swvo/io/plasmasphere/read_plasmasphere.py @@ -1,18 +1,132 @@ # SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences +# SPDX-FileContributor: Stefano Bianco +# SPDX-FileContributor: Sahil Jhawar # # SPDX-License-Identifier: Apache-2.0 import logging import os +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Optional +import numpy as np import pandas as pd logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class PlasmasphereDensityCube: + """A structured container for plasmaspheric electron density data. + + Attributes + ------- + time : np.ndarray[datetime] + Array of Python datetime values. + l : np.ndarray + Array of L-values. + mlt : np.ndarray + Array of MLT-values. + l_grid : np.ndarray + Grid of L-values. L x MLT shape. + mlt_grid : np.ndarray + Grid of MLT-values. L x MLT shape. 
+ density_grid : list[np.ndarray] + List of arrays, one per entry in density_column, each of shape n_time x n_L x n_MLT. + density_column : list[str] + Name(s) of the column(s) containing electron density data. + """ + + time: np.ndarray[datetime] # ty: ignore[invalid-type-arguments] + l: np.ndarray # noqa: E741 + mlt: np.ndarray + l_grid: np.ndarray + mlt_grid: np.ndarray + density_grid: list[np.ndarray] + density_column: list[str] + + def __str__(self) -> str: + """Readable summary for logging and printing.""" + num_times = len(self.time) + + l_range = f"[{self.l.min():.2f}, {self.l.max():.2f}]" + mlt_range = f"[{self.mlt.min():.2f}, {self.mlt.max():.2f}]" + + summary = [ + "--- Plasmasphere Density Cube ---", + f"Temporal Span : {num_times} steps ({self.time[0]} to {self.time[-1]})", + f"Spatial L-Bins: {len(self.l)} {l_range}", + f"Spatial MLT-Bins: {len(self.mlt)} {mlt_range}", + f"Density Grid Geometry per Time Step : {self.density_grid[0].shape if isinstance(self.density_grid, list) else self.density_grid.shape} (Time x L x MLT)", + f"Data Columns : {self.density_column}", + "----------------------------------", + ] + return "\n".join(summary) + + def __post_init__(self) -> None: + if len(self.density_grid) != len(self.density_column): + msg = f"Length of density_grid ({len(self.density_grid)}) must match length of density_column ({len(self.density_column)})." 
+ logger.error(msg) + raise ValueError(msg) + + def __eq__(self, other: object) -> bool: + return isinstance(other, PlasmasphereDensityCube) and not self.diff(other) + + def diff(self, other: object) -> list[str]: + issues = [] + if not isinstance(other, PlasmasphereDensityCube): + issues.append("type mismatch") + return issues + + if not np.array_equal(self.time, other.time): + issues.append("time mismatch") + if not np.array_equal(self.l, other.l): + issues.append("l mismatch") + if not np.array_equal(self.mlt, other.mlt): + issues.append("mlt mismatch") + if not np.array_equal(self.l_grid, other.l_grid): + issues.append("l_grid mismatch") + if not np.array_equal(self.mlt_grid, other.mlt_grid): + issues.append("mlt_grid mismatch") + if not all(np.array_equal(a, b) for a, b in zip(self.density_grid, other.density_grid)): + issues.append("density_grid mismatch") + if self.density_column != other.density_column: + issues.append("density_column mismatch") + + return issues + + def get_density_at_time(self, time: datetime) -> list[np.ndarray]: + """Extract density grid for a specific time. + + Parameters + ---------- + time : datetime + The specific time for which to extract the density grid. + + Returns + ------- + list[np.ndarray] + List of 2-D density grids (n_L x n_MLT), one per density column, at the specified time. + + Raises + ------ + IndexError + If the specified time is not found in the density cube. + """ + if time not in self.time: + logger.error(f"Requested time {time} not found in density cube.") + raise IndexError(f"Requested time {time} not found in density cube.") + + time_index = np.where(self.time == time)[0][0] + + if isinstance(self.density_grid, list): + return [grid[time_index] for grid in self.density_grid] + else: + return [self.density_grid[time_index]] + + class PlasmaspherePredictionReader: """Reads one of the available PAGER plasmasphere density prediction. 
@@ -48,6 +162,11 @@ def __init__(self, data_dir: Optional[Path] = None) -> None: logger.error(msg) raise FileNotFoundError(msg) + def _parse_none_date(self, date: datetime | None) -> datetime: + if date is None: + return datetime.now(timezone.utc).replace(microsecond=0, minute=0, second=0) + return date.replace(minute=0, second=0, microsecond=0) + def read(self, requested_date: datetime | None = None) -> pd.DataFrame | None: """ Reads one of the available PAGER plasmasphere density prediction. @@ -68,21 +187,169 @@ def read(self, requested_date: datetime | None = None) -> pd.DataFrame | None: pandas.DataFrame with L, MLT, density and date as columns """ - if requested_date is None: - requested_date = datetime.now(timezone.utc).replace(microsecond=0, minute=0, second=0) - - requested_date = requested_date.replace(minute=0, second=0, microsecond=0) + requested_date = self._parse_none_date(requested_date) - file_name = f"plasmasphere_density_{requested_date.year}{str(requested_date.month).zfill(2)}{str(requested_date.day).zfill(2)}T{str(requested_date.hour).zfill(2)}00.csv" + file_name = f"plasmasphere_density_{requested_date.strftime('%Y%m%dT%H00')}.csv" file_path = os.path.join(self.data_dir, file_name) logger.info(f"Looking for file {file_path} for date {requested_date}") if not os.path.isfile(file_path): - msg = f"No suitable files ({file_path}) found in the folder {self.data_dir} for the requested date {requested_date}" - logger.warning(msg) + msg = f"No suitable files ({file_path}) found for the requested date {requested_date}. Returning None." 
+ logger.error(msg) return None + logger.info(f"Reading plasmasphere density data from {file_path}") + data = pd.read_csv(file_path, parse_dates=["date"]) data["t"] = data["date"] data.drop(labels=["date"], axis=1, inplace=True) return data + + def _validate_data(self, data: pd.DataFrame) -> None: + if not isinstance(data, pd.DataFrame): + msg = f"data must be an instance of a pandas dataframe, instead it is of type {type(data)}" + logger.error(msg) + raise TypeError(msg) + + required_columns = ["L", "MLT", "t"] + for column in required_columns: + if column not in data.columns: + msg = f"column {column} is missing" + logger.error(msg) + raise ValueError(msg) + + if data.empty: + msg = "data dataframe is empty" + logger.error(msg) + raise ValueError(msg) + + if not pd.api.types.is_datetime64_any_dtype(data["t"]): + msg = "values of 't' column must be datetime objects" + logger.error(msg) + raise TypeError(msg) + + def _get_density_columns(self, data: pd.DataFrame) -> list[str]: + density_columns = [column for column in data.columns if "predicted_densities" in column] + if not density_columns: + msg = "no columns matching 'predicted_densities' were found" + logger.error(msg) + raise ValueError(msg) + return density_columns + + def _resolve_density_column(self, data: pd.DataFrame, density_column: str | None) -> str: + density_columns = self._get_density_columns(data) + if density_column is None: + return density_columns[0] + if density_column not in density_columns: + msg = f"density_column '{density_column}' is not valid. 
Available columns: {density_columns}" + logger.error(msg) + raise ValueError(msg) + return density_column + + def _legacy_reshape_2d(self, df_date: pd.DataFrame, density_column: str) -> tuple: + l_values = df_date["L"].to_numpy() + mlt_values = df_date["MLT"].to_numpy() + density_values = df_date[density_column].to_numpy(dtype=float) + + l_axis = np.unique(l_values) + mlt_axis = np.unique(mlt_values) + + expected_points = len(l_axis) * len(mlt_axis) + if len(df_date) != expected_points: + msg = "data for a single timestamp does not form a complete L-MLT grid. Expected n_L * n_MLT rows." + logger.error(msg) + raise ValueError(msg) + + l_grid = np.reshape(l_values, (len(l_axis), len(mlt_axis)), order="F") + mlt_grid = np.reshape(mlt_values, (len(l_axis), len(mlt_axis)), order="F") + density_2d = np.reshape(density_values, (len(l_axis), len(mlt_axis)), order="F") + + return l_axis, mlt_axis, l_grid, mlt_grid, density_2d + + def build_density_cube( + self, + requested_date: datetime | None = None, + density_column: str | None = None, + ) -> Optional[PlasmasphereDensityCube]: + """ + Build density tensor with shape time x L x MLT. + + Parameters + ---------- + requested_date : datetime.datetime or None + Date of plasma density prediction that we want to read up to hour precision. + + Returns + ------- + PlasmasphereDensityCube or None + `density_grid` is always a list of arrays of shape + (n_time, n_L, n_MLT): a single array if `density_column` is provided, + otherwise one array per column matching 'predicted_densities'. + + If no data is available for the requested date, returns None. 
+ """ + requested_date = self._parse_none_date(requested_date) + data = self.read(requested_date=requested_date) + if data is None: + return None + self._validate_data(data) + + if density_column is None: + resolved_density_columns = self._get_density_columns(data) + else: + resolved_density_columns = [self._resolve_density_column(data, density_column)] + dates = np.sort(data["t"].unique()) + dates = pd.to_datetime(dates) + dates_to_return = np.array([dt.to_pydatetime() for dt in dates.to_list()]) + density_slices_by_column = {column: [] for column in resolved_density_columns} + + l_axis_ref = None + mlt_axis_ref = None + l_grid_ref = None + mlt_grid_ref = None + + for date in dates: + df_date = data[pd.to_datetime(data["t"]) == date] + for column in resolved_density_columns: + l_axis, mlt_axis, l_grid, mlt_grid, density_2d = self._legacy_reshape_2d(df_date, column) + + if l_axis_ref is None: + l_axis_ref = l_axis + mlt_axis_ref = mlt_axis + l_grid_ref = l_grid + mlt_grid_ref = mlt_grid + else: + assert mlt_axis_ref is not None + if not np.array_equal(l_axis_ref, l_axis) or not np.array_equal(mlt_axis_ref, mlt_axis): + msg = "Inconsistent L/MLT axes across timestamps." + logger.error(msg) + raise ValueError(msg) + + density_slices_by_column[column].append(density_2d) + + if l_axis_ref is None or mlt_axis_ref is None or l_grid_ref is None or mlt_grid_ref is None: + msg = "Unable to build density cube axes from input data." 
+ logger.error(msg) + raise RuntimeError(msg) + + if len(resolved_density_columns) == 1: + resolved_density_column: list[str] = [resolved_density_columns[0]] + density_grid = [ + np.stack( + density_slices_by_column[resolved_density_columns[0]], + axis=0, + ) + ] + else: + resolved_density_column = resolved_density_columns + density_grid = [np.stack(density_slices_by_column[column], axis=0) for column in resolved_density_columns] + + return PlasmasphereDensityCube( + time=dates_to_return, + l=l_axis_ref, + mlt=mlt_axis_ref, + l_grid=l_grid_ref, + mlt_grid=mlt_grid_ref, + density_grid=density_grid, + density_column=resolved_density_column, + ) diff --git a/tests/io/plasmasphere/test_read_plasmasphere.py b/tests/io/plasmasphere/test_read_plasmasphere.py new file mode 100644 index 0000000..322f2a6 --- /dev/null +++ b/tests/io/plasmasphere/test_read_plasmasphere.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: 2026 GFZ Helmholtz Centre for Geosciences +# SPDX-FileContributor: Sahil Jhawar +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest + +from swvo.io.plasmasphere import ( + PlasmasphereDensityCube, + PlasmaspherePredictionReader, +) + + +def _build_sample_density_dataframe() -> pd.DataFrame: + rows = [] + times = [datetime(2026, 1, 1, h, 0, tzinfo=timezone.utc) for h in range(24)] + for ts in times: + for mlt in [0.0, 6.0, 12.0]: + for l_value in [1.5, 4, 6.5]: + rows.append( + { + "date": ts, + "L": l_value, + "MLT": mlt, + "predicted_densities_ensemble_member_0": float(l_value + mlt), + "predicted_densities_ensemble_member_1": float(2.0 * l_value + mlt), + } + ) + return pd.DataFrame(rows) + + +class TestPlasmaspherePredictionReader: + @pytest.fixture + def data_dir(self, tmp_path: Path) -> Path: + return tmp_path / "plasmasphere" + + @pytest.fixture + def sample_file(self, data_dir: Path) -> Path: + 
data_dir.mkdir(parents=True, exist_ok=True) + sample_date = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc) + file_name = f"plasmasphere_density_{sample_date.strftime('%Y%m%dT%H00')}.csv" + file_path = data_dir / file_name + _build_sample_density_dataframe().to_csv(file_path, index=False) + return file_path + + def test_initialization_with_env_var(self, data_dir: Path): + data_dir.mkdir(parents=True, exist_ok=True) + with patch.dict(os.environ, {PlasmaspherePredictionReader.ENV_VAR_NAME: str(data_dir)}): + reader = PlasmaspherePredictionReader() + assert reader.data_dir == data_dir + + def test_initialization_without_env_var(self): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError): + PlasmaspherePredictionReader() + + def test_read_existing_file(self, sample_file: Path): + reader = PlasmaspherePredictionReader(data_dir=sample_file.parent) + result = reader.read(requested_date=datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)) + assert result is not None + assert "t" in result.columns + assert len(result) == 216 # 24 hours * 3 MLT bins * 3 L bins + + def test_read_missing_file_returns_none(self, data_dir: Path): + data_dir.mkdir(parents=True, exist_ok=True) + reader = PlasmaspherePredictionReader(data_dir=data_dir) + result = reader.read(requested_date=datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)) + assert result is None + + def test_build_density_cube_single_column(self, sample_file: Path): + reader = PlasmaspherePredictionReader(data_dir=sample_file.parent) + cube = reader.build_density_cube( + requested_date=datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc), + density_column="predicted_densities_ensemble_member_1", + ) + + assert cube is not None + assert isinstance(cube, PlasmasphereDensityCube) + assert cube.density_column == ["predicted_densities_ensemble_member_1"] + assert len(cube.density_grid) == 1 + assert cube.density_grid[0].shape == (24, 3, 3) + + def test_build_density_cube_multi_column(self, sample_file: Path): + reader = 
PlasmaspherePredictionReader(data_dir=sample_file.parent) + cube = reader.build_density_cube(requested_date=datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)) + + assert cube is not None + assert isinstance(cube.density_column, list) + assert sorted(cube.density_column) == [ + "predicted_densities_ensemble_member_0", + "predicted_densities_ensemble_member_1", + ] + assert len(cube.density_grid) == 2 + assert all(grid.shape == (24, 3, 3) for grid in cube.density_grid) + + def test_get_density_at_time(self, sample_file: Path): + reader = PlasmaspherePredictionReader(data_dir=sample_file.parent) + cube = reader.build_density_cube( + requested_date=datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc), + density_column="predicted_densities_ensemble_member_1", + ) + + assert cube is not None + density_slice_at_time_t = cube.get_density_at_time(datetime(2026, 1, 1, 1, 0, tzinfo=timezone.utc)) + + assert isinstance(density_slice_at_time_t, list) + assert density_slice_at_time_t[0].shape == (3, 3) + + def test_get_density_at_time_missing_raises(self, sample_file: Path): + reader = PlasmaspherePredictionReader(data_dir=sample_file.parent) + cube = reader.build_density_cube( + requested_date=datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc), + density_column="predicted_densities_ensemble_member_1", + ) + + assert cube is not None + with pytest.raises(IndexError): + cube.get_density_at_time(datetime(2026, 1, 2, 23, 0, tzinfo=timezone.utc)) + + +def test_density_cube_equality(): + t = np.array([datetime(2026, 1, 1, 0, 0)]) + l = np.array([2.0, 3.0]) # noqa: E741 + mlt = np.array([0.0, 12.0]) + l_grid = np.array([[2.0, 2.0], [3.0, 3.0]]) + mlt_grid = np.array([[0.0, 12.0], [0.0, 12.0]]) + density_grid = [np.ones((1, 2, 2))] + + cube_a = PlasmasphereDensityCube( + time=t, + l=l, + mlt=mlt, + l_grid=l_grid, + mlt_grid=mlt_grid, + density_grid=density_grid, + density_column=["predicted_densities_ensemble_member_1"], + ) + cube_b = PlasmasphereDensityCube( + time=t.copy(), + l=l.copy(), + 
mlt=mlt.copy(), + l_grid=l_grid.copy(), + mlt_grid=mlt_grid.copy(), + density_grid=[density_grid[0].copy()], + density_column=["predicted_densities_ensemble_member_1"], + ) + + assert cube_a == cube_b + + +def test_density_cube_not_equal(): + t = np.array([datetime(2026, 1, 1, 0, 0)]) + l = np.array([2.0, 3.0]) # noqa: E741 + mlt = np.array([0.0, 12.0]) + l_grid = np.array([[2.0, 2.0], [3.0, 3.0]]) + mlt_grid = np.array([[0.0, 12.0], [0.0, 12.0]]) + density_grid = [np.ones((1, 2, 2))] + + cube_a = PlasmasphereDensityCube( + time=t, + l=l, + mlt=mlt, + l_grid=l_grid, + mlt_grid=mlt_grid, + density_grid=density_grid, + density_column=["predicted_densities_ensemble_member_1"], + ) + cube_b = PlasmasphereDensityCube( + time=t.copy(), + l=l.copy(), + mlt=mlt.copy(), + l_grid=l_grid.copy(), + mlt_grid=mlt_grid.copy(), + density_grid=[density_grid[0].copy()], + density_column=["predicted_densities_ensemble_member_0"], + ) + + assert cube_a.diff(cube_b)[0] == "density_column mismatch" + assert cube_a != cube_b + + +def test_density_cube_not_equal_different_density_grid_length(): + t = np.array([datetime(2026, 1, 1, 0, 0)]) + l = np.array([2.0, 3.0]) # noqa: E741 + mlt = np.array([0.0, 12.0]) + l_grid = np.array([[2.0, 2.0], [3.0, 3.0]]) + mlt_grid = np.array([[0.0, 12.0], [0.0, 12.0]]) + density_grid = [np.ones((1, 2, 2))] + + with pytest.raises(ValueError): + _ = PlasmasphereDensityCube( + time=t, + l=l, + mlt=mlt, + l_grid=l_grid, + mlt_grid=mlt_grid, + density_grid=density_grid, + density_column=["predicted_densities_ensemble_member_0", "predicted_densities_ensemble_member_1"], + )