diff --git a/requirements.txt b/requirements.txt index 01cd3f9a..aa87c46d 100755 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ requests wget distance tqdm -pytest-mock \ No newline at end of file +pytest-mock +netcdf4 \ No newline at end of file diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 35be7b3c..c72e5e7a 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -8,7 +8,7 @@ import typing from datetime import timedelta, timezone from pathlib import Path -from typing import Any +from typing import Any, Literal import distance import numpy as np @@ -19,7 +19,9 @@ FileCadenceEnum, FolderTypeEnum, InstrumentEnum, + InstrumentLike, MfmEnum, + MfmLike, SatelliteEnum, SatelliteLike, Variable, @@ -92,9 +94,9 @@ def __init__( end_time: dt.datetime, folder_path: Path, satellite: SatelliteLike, - instrument: InstrumentEnum, - mfm: MfmEnum, - preferred_extension: str = "pickle", + instrument: InstrumentLike, + mfm: MfmLike, + preferred_extension: Literal["mat", "pickle"] = "pickle", *, verbose: bool = True, ) -> None: @@ -110,8 +112,15 @@ def __init__( if isinstance(satellite, str): satellite = SatelliteEnum[satellite.upper()] self._satellite = satellite + + if isinstance(instrument, str): + instrument = InstrumentEnum[instrument.upper()] self._instrument = instrument + + if isinstance(mfm, str): + mfm = MfmEnum[mfm.upper()] self._mfm = mfm + self._folder_path = Path(folder_path) self._preferred_ext = preferred_extension @@ -130,7 +139,7 @@ def __str__(self): return self.__repr__() def __dir__(self): - return super().__dir__() + [var.var_name for var in VariableEnum] + return list(super().__dir__()) + [var.var_name for var in VariableEnum] def __getattr__(self, name: str): # check if a sat variable is requested @@ -252,10 +261,10 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: next_month = start_month + relativedelta(months=1, days=-1) date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d") - file_name_no_format = self._file_name_stem + date_str + "_" + var.data_server_file_prefix + file_name_no_format = self._file_name_stem + date_str + "_" + var.mat_file_prefix - if var.data_server_has_B: - file_name_no_format += "_" + self._mfm.value + if var.mat_has_B: + file_name_no_format += "_n4_4_" + self._mfm.value file_name_no_format += "_ver4" else: @@ -264,7 +273,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: full_file_path = get_file_path_any_format(self._file_path_stem, file_name_no_format, self._preferred_ext) if full_file_path is None: - print(f"File not found {full_file_path}") + print(f"File not found: {self._file_path_stem}, {file_name_no_format}") continue if self._verbose: @@ -318,5 +327,33 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: setattr(self, var_name, loaded_var_arrs[var_name]) + def __eq__(self, other: RBMDataSet) -> bool: + if self._satellite != other._satellite: + return False + if self._instrument != other._instrument: + return False + if self._mfm != other._mfm: + return False + + for var in VariableEnum: + self_var = getattr(self, var.var_name) + other_var = getattr(other, var.var_name) + + if isinstance(self_var, list) and isinstance(other_var, list): + if len(self_var) != len(other_var): + return False + for dt1, dt2 in zip(self_var, other_var): + if dt1 != dt2: + return False + elif isinstance(self_var, np.ndarray) and isinstance(other_var, np.ndarray): + if self_var.shape != other_var.shape: + return False + if not np.allclose(self_var, other_var, equal_nan=True): + return False + elif self_var != other_var: + return False + + return True + from .bin_and_interpolate_to_model_grid import bin_and_interpolate_to_model_grid from .interp_functions import interp_flux diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py new file mode 100644 index 00000000..ee5e2117 --- /dev/null +++ b/swvo/io/RBMDataSet/RBMNcDataSet.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences +# +# SPDX-License-Identifier: Apache-2.0 + +import datetime as dt +import typing +from pathlib import Path +from typing import Any + +import netCDF4 +import numpy as np +from dateutil.relativedelta import relativedelta +from numpy.typing import NDArray + +from swvo.io.RBMDataSet import ( + FolderTypeEnum, + InstrumentLike, + MfmLike, + RBMDataSet, + SatelliteLike, + Variable, + VariableEnum, +) +from swvo.io.RBMDataSet.custom_enums import MfmEnumLiteral, VariableLiteral +from swvo.io.RBMDataSet.utils import join_var + + +def _read_all_datasets_netcdf(file_path: str | Path) -> dict[str, Any]: + """Reads all datasets (variables) from a NetCDF file, including those in groups. + + This function recursively traverses all groups and variables in a NetCDF-4 + file and stores their data in a dictionary. The key for each dataset is its + full hierarchical path. + + Args: + file_path (str | Path): The path to the NetCDF file. + + Returns: + Dict[str, Any]: A dictionary where keys are the full variable paths + and values are the corresponding NumPy arrays. + """ + datasets: dict[str, Any] = {} + file_path = Path(file_path) + + def _read_all_recursively(group: netCDF4.Group | netCDF4.Dataset, path: str = ""): + for var_name, var_obj in group.variables.items(): + full_path = f"{path}/{var_name}" if path else var_name + datasets[full_path] = var_obj[:] + + for group_name, group_obj in group.groups.items(): + new_path = f"{path}/{group_name}" if path else group_name + _read_all_recursively(group_obj, new_path) + + if not file_path.exists(): + print(f"File not found: {file_path}") + return {} + + with netCDF4.Dataset(file_path, "r") as nc_file: + _read_all_recursively(nc_file) + + return datasets + + +class RBMNcDataSet(RBMDataSet): + """Class for handling RBM NetCDF data files.""" + + datetime: list[dt.datetime] + time: NDArray[np.float64] + energy_channels: NDArray[np.float64] + alpha_local: NDArray[np.float64] + alpha_eq_model: NDArray[np.float64] + alpha_eq_real: NDArray[np.float64] + InvMu: NDArray[np.float64] + InvMu_real: NDArray[np.float64] + InvK: NDArray[np.float64] + InvV: NDArray[np.float64] + Lstar: NDArray[np.float64] + Flux: NDArray[np.float64] + PSD: NDArray[np.float64] + MLT: NDArray[np.float64] + B_SM: NDArray[np.float64] + B_total: NDArray[np.float64] + B_sat: NDArray[np.float64] + xGEO: NDArray[np.float64] + P: NDArray[np.float64] + R0: NDArray[np.float64] + density: NDArray[np.float64] + + def __init__( + self, + start_time: dt.datetime, + end_time: dt.datetime, + folder_path: Path, + satellite: SatelliteLike, + instrument: InstrumentLike, + mfm: MfmLike, + *, + verbose: bool = True, + ) -> None: + super().__init__( + start_time, + end_time, + folder_path, + satellite, + instrument, + mfm=mfm, + verbose=verbose, + ) + + def _create_file_path_stem(self) -> Path: + # implement special cases here + # if self._satellite == SatelliteEnum.THEMIS: + # pass + if self._folder_type == FolderTypeEnum.DataServer: + return self._folder_path / self._satellite.mission / self._satellite.sat_name + + if self._folder_type == FolderTypeEnum.SingleFolder: + return self._folder_path + + msg = "Encountered invalid FolderTypeEnum!" + raise ValueError(msg) + + def _load_variable(self, var: Variable | VariableEnum) -> None: + loaded_var_arrs: dict[str, NDArray[np.number]] = {} + var_names_stored: list[str] = [] + + # computed values + if isinstance(var, VariableEnum) and var == VariableEnum.INV_V: + inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) + + self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 + return + + if isinstance(var, VariableEnum) and var == VariableEnum.P: + self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) + return + + for date in self._date_of_files: + if self._folder_type == FolderTypeEnum.DataServer: + start_month = date.replace(day=1) + next_month = start_month + relativedelta(months=1, days=-1) + date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d") + + file_name = self._file_name_stem + date_str + "_" + self._mfm.value + ".nc" + else: + raise NotImplementedError + + datasets = _read_all_datasets_netcdf(self._file_path_stem / file_name) + + if datasets == {}: + continue + + # also store python datetimes for binning + datetimes = typing.cast( + NDArray[np.object_], + np.asarray( + [dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in datasets["time"]] + ), + ) # type: ignore + datasets["datetime"] = datetimes + + # limit in time + correct_time_idx = (datetimes >= self._start_time) & (datetimes <= self._end_time) + + for key, var_arr in datasets.items(): + if ((not isinstance(var_arr, np.ndarray)) or (not np.issubdtype(var_arr.dtype, np.number))) and ( + key != "datetime" + ): + # var represents some strings or metadata objects; don't read them + continue + var_arr = typing.cast("NDArray[np.number]", var_arr) + + # check if var is time dependent + if var_arr.shape[0] == correct_time_idx.shape[0]: + var_arr_trimmed = var_arr[correct_time_idx.reshape(-1), ...] + + joined_value = ( + join_var(loaded_var_arrs[key], var_arr_trimmed) if key in loaded_var_arrs else var_arr_trimmed + ) + else: + joined_value = var_arr + + loaded_var_arrs[key] = joined_value + + if key not in var_names_stored: + var_names_stored.append(key) + + # not a single file was found + if var.var_name not in var_names_stored: + setattr(self, var.var_name, np.asarray([])) + + for var_name in var_names_stored: + if var_name == "datetime": + loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # type: ignore + + rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.value) + + if rbm_var_name is not None: + setattr(self, rbm_var_name, loaded_var_arrs[var_name]) + + @classmethod + def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral | None: + match var_name: + case "time": + return "time" + case "datetime": + return "datetime" + case "flux/FEDU": + return "Flux" + case "flux/alpha_eq": + return "alpha_eq_model" + case "flux/energy": + return "energy_channels" + case "flux/alpha_local": + return "alpha_local" + case "position/xGEO": + return "xGEO" + case _ if var_name == f"position/{mag_field}/MLT": + return "MLT" + case _ if var_name == f"position/{mag_field}/R0": + return "R0" + case _ if var_name == f"position/{mag_field}/Lstar": + return "Lstar" + case _ if var_name == f"position/{mag_field}/Lm": + return "Lm" + case _ if var_name == f"mag_field/{mag_field}/B_local": + return "B_total" + case "psd/PSD": + return "PSD" + case _ if var_name == f"psd/{mag_field}/inv_mu": + return "InvMu" + case _ if var_name == f"psd/{mag_field}/inv_K": + return "InvK" + case "density/density_local": + return "density" + case _: + return None diff --git a/swvo/io/RBMDataSet/__init__.py b/swvo/io/RBMDataSet/__init__.py index b410d903..f77d0642 100644 --- a/swvo/io/RBMDataSet/__init__.py +++ b/swvo/io/RBMDataSet/__init__.py @@ -12,7 +12,9 @@ SatelliteLike as SatelliteLike, SatelliteEnum as SatelliteEnum, InstrumentEnum as InstrumentEnum, + InstrumentLike as InstrumentLike, MfmEnum as MfmEnum, + MfmLike as MfmLike, ElPasoMFMEnum as ElPasoMFMEnum, SatelliteLiteral as SatelliteLiteral, ) @@ -20,4 +22,5 @@ from swvo.io.RBMDataSet.interp_functions import TargetType as TargetType from swvo.io.RBMDataSet.scripts.create_RBSP_line_data import create_RBSP_line_data as create_RBSP_line_data from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet as RBMDataSet +from swvo.io.RBMDataSet.RBMNcDataSet import RBMNcDataSet as RBMNcDataSet from swvo.io.RBMDataSet.RBMDataSetElPaso import RBMDataSetElPaso as RBMDataSetElPaso diff --git a/swvo/io/RBMDataSet/custom_enums.py b/swvo/io/RBMDataSet/custom_enums.py index 7500de1e..edb24cda 100644 --- a/swvo/io/RBMDataSet/custom_enums.py +++ b/swvo/io/RBMDataSet/custom_enums.py @@ -27,8 +27,8 @@ class FileCadenceEnum(Enum): @dataclass(frozen=True) class Variable: var_name: str - data_server_file_prefix: str - data_server_has_B: bool + mat_file_prefix: str + mat_has_B: bool without_B: bool = False @@ -140,6 +140,7 @@ class InstrumentEnum(Enum): HOPE = "hope" MAGEIS = "mageis" REPT = "rept" + ECT_COMBINED = "ect_combined" # GOES MAGEDandEPEAD = "MAGEDandEPEAD" @@ -157,9 +158,32 @@ class InstrumentEnum(Enum): TED = "TED-electron" +InstrumentLiteral = Literal[ + "hope", + "mageis", + "rept", + "ect_combined", + "MAGEDandEPEAD", + "MAGED", + "XEP", + "mepe", + "PWE-density", + "orbit", + "TED-electron", +] +InstrumentLike: TypeAlias = InstrumentLiteral | InstrumentEnum + + class MfmEnum(Enum): - T89 = "n4_4_T89" - T04s = "n4_4_T04s" + T89 = "T89" + T04s = "T04s" + T96 = "T96" + TS04 = "T04s" + OP77 = "OP77" + + +MfmEnumLiteral = Literal["T89", "T04s", "TS04", "T96", "OP77"] +MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum class ElPasoMFMEnum(Enum): diff --git a/swvo/io/RBMDataSet/utils.py b/swvo/io/RBMDataSet/utils.py index 35605a6a..d1c108d5 100644 --- a/swvo/io/RBMDataSet/utils.py +++ b/swvo/io/RBMDataSet/utils.py @@ -34,8 +34,8 @@ def get_file_path_any_format(folder_path: Path, file_stem: str, preferred_ext: s if len(all_files) >= 1: extensions_found = [file.suffix[1:] for file in all_files] - if preferred_ext in extensions_found: - if len(all_files) > 1: + if len(all_files) > 1: + if preferred_ext in extensions_found: warnings.warn( ( f"Several files found for {folder_path / (file_stem + '.*')} with extensions: {extensions_found}. " @@ -45,13 +45,20 @@ def get_file_path_any_format(folder_path: Path, file_stem: str, preferred_ext: s ) return folder_path / (file_stem + "." + preferred_ext) - return all_files[0] - else: - warnings.warn( - f"File not found: {folder_path / (file_stem + '.' + preferred_ext)}", - stacklevel=2, + + msg = ( + f"Several files found for {folder_path / (file_stem + '.*')} with extensions: {extensions_found}. " + f"However, the preferred extension ({preferred_ext}) is not available!" ) - return None + raise ValueError(msg) + + if len(all_files) == 1: + return all_files[0] + + warnings.warn( + f"File not found: {folder_path / (file_stem + '.' + preferred_ext)}", + stacklevel=2, + ) return None diff --git a/swvo/io/dst/read_dst_from_multiple_models.py b/swvo/io/dst/read_dst_from_multiple_models.py index cfdbf2b2..d4652071 100644 --- a/swvo/io/dst/read_dst_from_multiple_models.py +++ b/swvo/io/dst/read_dst_from_multiple_models.py @@ -6,6 +6,7 @@ import logging import warnings +from collections.abc import Sequence from datetime import datetime, timezone import numpy as np @@ -23,7 +24,7 @@ def read_dst_from_multiple_models( start_time: datetime, end_time: datetime, - model_order: list[DSTModel] | None = None, + model_order: Sequence[DSTModel] | None = None, historical_data_cutoff_time: datetime | None = None, *, synthetic_now_time: datetime | None = None, # deprecated @@ -43,7 +44,7 @@ def read_dst_from_multiple_models( Start time of the data request. end_time : datetime End time of the data request. - model_order : list or None, optional + model_order : Sequence or None, optional Order in which data will be read from the models. Defaults to [OMNI, WDC]. historical_data_cutoff_time : datetime or None, optional Time representing "now". After this time, no data will be taken from diff --git a/swvo/io/f10_7/read_f107_from_multiple_models.py b/swvo/io/f10_7/read_f107_from_multiple_models.py index 1da266f7..100f561f 100644 --- a/swvo/io/f10_7/read_f107_from_multiple_models.py +++ b/swvo/io/f10_7/read_f107_from_multiple_models.py @@ -6,6 +6,7 @@ import logging import warnings +from collections.abc import Sequence from datetime import datetime, timezone import numpy as np @@ -23,7 +24,7 @@ def read_f107_from_multiple_models( start_time: datetime, end_time: datetime, - model_order: list[F107Model] | None = None, + model_order: Sequence[F107Model] | None = None, historical_data_cutoff_time: datetime | None = None, *, synthetic_now_time: datetime | None = None, # deprecated @@ -43,7 +44,7 @@ def read_f107_from_multiple_models( Start time of the data request. end_time : datetime End time of the data request. - model_order : list or None, optional + model_order : Sequence or None, optional Order in which data will be read from the models. Defaults to [OMNI, SWPC]. historical_data_cutoff_time : datetime or None, optional Time representing "now". After this time, no data will be taken from diff --git a/swvo/io/hp/read_hp_from_multiple_models.py b/swvo/io/hp/read_hp_from_multiple_models.py index 1e8c3cc0..ddb99740 100644 --- a/swvo/io/hp/read_hp_from_multiple_models.py +++ b/swvo/io/hp/read_hp_from_multiple_models.py @@ -8,6 +8,7 @@ import logging import warnings +from collections.abc import Sequence from datetime import datetime, timedelta, timezone from typing import Literal @@ -26,7 +27,7 @@ def read_hp_from_multiple_models( # noqa: PLR0913 start_time: datetime, end_time: datetime, - model_order: list[HpModel] | None = None, + model_order: Sequence[HpModel] | None = None, hp_index: str = "hp30", reduce_ensemble: Literal["mean", "median"] | None = None, historical_data_cutoff_time: datetime | None = None, @@ -48,10 +49,10 @@ def read_hp_from_multiple_models( # noqa: PLR0913 Start time of the data request. end_time : datetime End time of the data request. - model_order : list, optional + model_order : Sequence, optional Order in which data will be read from the models, defaults to [OMNI, Niemegk, Ensemble, SWPC]. - reduce_ensemble : {"mean"}, optional - The method to reduce ensembles to a single time series, defaults to None. + reduce_ensemble : {"mean", "median"} or None, optional + The method to reduce ensembles to a single time series ("mean" or "median"), defaults to None. historical_data_cutoff_time : datetime, optional Time, which represents "now". After this time, no data will be taken from historical models (OMNI, Niemegk), defaults to None. download : bool, optional diff --git a/swvo/io/kp/__init__.py b/swvo/io/kp/__init__.py index 47a4bea8..9549a153 100755 --- a/swvo/io/kp/__init__.py +++ b/swvo/io/kp/__init__.py @@ -9,3 +9,4 @@ # This has to be imported after the models to avoid a circular import from swvo.io.kp.read_kp_from_multiple_models import read_kp_from_multiple_models as read_kp_from_multiple_models # noqa: I001 +from swvo.io.kp.read_kp_from_multiple_models import KpModel as KpModel diff --git a/swvo/io/kp/read_kp_from_multiple_models.py b/swvo/io/kp/read_kp_from_multiple_models.py index dad99fd0..0649fff6 100644 --- a/swvo/io/kp/read_kp_from_multiple_models.py +++ b/swvo/io/kp/read_kp_from_multiple_models.py @@ -8,6 +8,7 @@ import logging import warnings +from collections.abc import Sequence from datetime import datetime, timedelta, timezone from typing import Literal @@ -26,14 +27,14 @@ def read_kp_from_multiple_models( # noqa: PLR0913 start_time: datetime, end_time: datetime, - model_order: list[KpModel] | None = None, + model_order: Sequence[KpModel] | None = None, reduce_ensemble: Literal["mean", "median"] | None = None, historical_data_cutoff_time: datetime | None = None, *, synthetic_now_time: datetime | None = None, # deprecated download: bool = False, recurrence: bool = False, - rec_model_order: list[KpOMNI | KpNiemegk] = None, + rec_model_order: Sequence[KpOMNI | KpNiemegk] | None = None, ) -> pd.DataFrame | list[pd.DataFrame]: """Read Kp data from multiple models. @@ -48,10 +49,10 @@ def read_kp_from_multiple_models( # noqa: PLR0913 The start time of the data request. end_time : datetime The end time of the data request. - model_order : list or None, optional + model_order : Sequence or None, optional The order in which data will be read from the models. Defaults to [OMNI, Niemegk, Ensemble, SWPC]. - reduce_ensemble : {"mean", None}, optional - The method to reduce ensembles to a single time series. Defaults to None. + reduce_ensemble : {"mean", "median"} or None, optional + The method to reduce ensembles to a single time series. Can be "mean", "median", or None. Defaults to None. historical_data_cutoff_time : datetime or None, optional Represents "now". After this time, no data will be taken from historical models (OMNI, Niemegk). Defaults to None. @@ -61,7 +62,7 @@ def read_kp_from_multiple_models( # noqa: PLR0913 recurrence : bool, optional If True, fill missing values using 27-day recurrence from historical models (OMNI, Niemegk). Defaults to False. - rec_model_order : list[KpOMNI | KpNiemegk], optional + rec_model_order : Sequence[KpOMNI | KpNiemegk], optional The order in which historical models will be used for 27-day recurrence filling. Defaults to [OMNI, Niemegk]. diff --git a/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py b/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py index e82be6ec..233721bb 100644 --- a/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py +++ b/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py @@ -6,6 +6,7 @@ import logging import warnings +from collections.abc import Sequence from datetime import datetime, timedelta, timezone from typing import Literal @@ -24,7 +25,7 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913 start_time: datetime, end_time: datetime, - model_order: list[SWModel] | None = None, + model_order: Sequence[SWModel] | None = None, reduce_ensemble: str | None = None, historical_data_cutoff_time: datetime | None = None, *, diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index f7332e70..118cbefb 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -80,17 +80,17 @@ def test_init_datetime_timezone(mock_module_string): assert dataset._end_time.tzinfo == timezone.utc -def test_get_satellite_name(mock_dataset): +def test_get_satellite_name(mock_dataset: RBMDataSet): """Test get_satellite_name method.""" assert mock_dataset.get_satellite_name() == "rbspa" -def test_get_satellite_and_instrument_name(mock_dataset): +def test_get_satellite_and_instrument_name(mock_dataset: RBMDataSet): """Test get_satellite_and_instrument_name method.""" assert mock_dataset.get_satellite_and_instrument_name() == "rbspa_mageis" -def test_get_print_name(mock_dataset): +def test_get_print_name(mock_dataset: RBMDataSet): """Test get_print_name method.""" assert mock_dataset.get_print_name() == "rbspa mageis" @@ -113,7 +113,7 @@ def test_satellite_string_input(mock_module_string): assert dataset._satellite == SatelliteEnum.RBSPA -def test_getattr_with_valid_variable(mock_dataset): +def test_getattr_with_valid_variable(mock_dataset: RBMDataSet): """Test __getattr__ with a valid variable.""" with mock.patch.object(mock_dataset, "_load_variable") as _: mock_dataset.Flux = np.array([1.0, 2.0, 3.0]) @@ -122,13 +122,13 @@ def test_getattr_with_valid_variable(mock_dataset): assert (result == np.array([1.0, 2.0, 3.0])).all() -def test_getattr_with_invalid_variable(mock_dataset): +def test_getattr_with_invalid_variable(mock_dataset: RBMDataSet): """Test __getattr__ with an invalid variable.""" with pytest.raises(AttributeError): _ = mock_dataset.NonExistentAttribute -def test_getattr_with_similar_variable(mock_dataset): +def test_getattr_with_similar_variable(mock_dataset: RBMDataSet): """Test __getattr__ suggests similar variable name.""" with pytest.raises(AttributeError) as e: _ = mock_dataset.Flx @@ -136,7 +136,7 @@ def test_getattr_with_similar_variable(mock_dataset): assert "Maybe you meant Flux?" in str(e.value) -def test_computed_invv_variable(mock_dataset): +def test_computed_invv_variable(mock_dataset: RBMDataSet): """Test computed InvV variable.""" mock_dataset.InvK = np.array([[1.0, 2.0]]) @@ -151,7 +151,7 @@ def test_computed_invv_variable(mock_dataset): np.testing.assert_array_equal(mock_dataset.InvV, expected) -def test_computed_p_variable(mock_dataset): +def test_computed_p_variable(mock_dataset: RBMDataSet): """Test computed P variable.""" mock_dataset.MLT = np.array([0.0, 6.0, 12.0, 18.0]) @@ -196,7 +196,7 @@ def test_all_instruments_work(instrument, mock_module_string): assert dataset._instrument == instrument -def test_create_date_list_monthly(mock_dataset): +def test_create_date_list_monthly(mock_dataset: RBMDataSet): """Test monthly cadence date generation.""" mock_dataset.set_file_cadence(FileCadenceEnum.Monthly) date_list = mock_dataset._create_date_list() @@ -204,7 +204,7 @@ def test_create_date_list_monthly(mock_dataset): assert all(date.tzinfo == timezone.utc for date in date_list) -def test_create_date_list_daily(mock_dataset): +def test_create_date_list_daily(mock_dataset: RBMDataSet): """Test daily cadence date generation.""" mock_dataset.set_file_cadence(FileCadenceEnum.Daily) date_list = mock_dataset._create_date_list() @@ -212,32 +212,32 @@ def test_create_date_list_daily(mock_dataset): assert all(date.tzinfo == timezone.utc for date in date_list) -def test_file_name_stem_generation(mock_dataset): +def test_file_name_stem_generation(mock_dataset: RBMDataSet): """Test that file name stem is generated correctly.""" assert mock_dataset._create_file_name_stem() == "rbspa_mageis_" -def test_file_path_stem_dataserver(mock_dataset): +def test_file_path_stem_dataserver(mock_dataset: RBMDataSet): """Test correct file path stem for DataServer folder type.""" expected_path = Path("/mock/path/RBSP/rbspa/Processed_Mat_Files") assert mock_dataset._create_file_path_stem() == expected_path -def test_invalid_cadence_raises(mock_dataset): +def test_invalid_cadence_raises(mock_dataset: RBMDataSet): """Invalid cadence should raise ValueError.""" mock_dataset._file_cadence = None with pytest.raises(ValueError): mock_dataset._create_date_list() -def test_invalid_folder_type_raises(mock_dataset): +def test_invalid_folder_type_raises(mock_dataset: RBMDataSet): """Invalid folder type should raise ValueError.""" mock_dataset._folder_type = None with pytest.raises(ValueError): mock_dataset._create_file_path_stem() -def test_get_var_method(mock_dataset): +def test_get_var_method(mock_dataset: RBMDataSet): """Test get_var returns correct variable.""" mock_dataset.Flux = np.array([4.0, 5.0]) result = mock_dataset.get_var(VariableEnum.FLUX) @@ -266,7 +266,7 @@ def test_load_variable_real_file(): assert isinstance(dataset.alpha_local, np.ndarray), "'alpha_local' should be a NumPy array." -def test_all_variables_in_dir(mock_dataset): +def test_all_variables_in_dir(mock_dataset: RBMDataSet): vars = [ "datetime", "time", diff --git a/tests/io/RBMDataSet/test_RBMNcDataset.py b/tests/io/RBMDataSet/test_RBMNcDataset.py new file mode 100644 index 00000000..8850f488 --- /dev/null +++ b/tests/io/RBMDataSet/test_RBMNcDataset.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences +# +# SPDX-License-Identifier: Apache-2.0 + +import datetime as dt +from datetime import timezone +from pathlib import Path +from unittest import mock + +import numpy as np +import pytest + +from swvo.io.RBMDataSet import ( + FileCadenceEnum, + InstrumentEnum, + MfmEnum, + RBMNcDataSet, + SatelliteEnum, + VariableEnum, +) + + +@pytest.fixture +def mock_module_string(): + return "swvo.io.RBMDataSet.RBMDataSet.RBMDataSet" + + +@pytest.fixture +def mock_dataset(mocker) -> RBMNcDataSet: + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + + mocker.patch( + "swvo.io.RBMDataSet.RBMNcDataSet._read_all_datasets_netcdf", + return_value={ + "time": np.array([dt.datetime(2023, 1, 15).timestamp()]), + "datetime": np.array([dt.datetime(2023, 1, 15, tzinfo=timezone.utc)]), + "flux/energy": np.array([100, 200, 300]), + "flux/alpha_local": np.array([0.1, 0.2, 0.3]), + "flux/FEDU": np.array([[1.0, 2.0, 3.0]]), + }, + ) + + dataset = RBMNcDataSet( + start_time=start_time, + end_time=end_time, + folder_path=Path("/mock/path"), + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + verbose=False, + ) + + return dataset + + +def test_init_datetime_timezone(mock_module_string): + """Test timezone handling for input datetimes.""" + + start_time = dt.datetime(2023, 1, 1) + end_time = dt.datetime(2023, 1, 31) + + with ( + mock.patch(f"{mock_module_string}._create_date_list"), + mock.patch(f"{mock_module_string}._create_file_path_stem"), + mock.patch(f"{mock_module_string}._create_file_name_stem"), + ): + dataset = RBMNcDataSet( + start_time=start_time, + end_time=end_time, + folder_path=Path("/mock/path"), + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + ) + + assert dataset._start_time.tzinfo == timezone.utc + assert dataset._end_time.tzinfo == timezone.utc + + +def test_get_satellite_name(mock_dataset: RBMNcDataSet): + """Test get_satellite_name method.""" + assert mock_dataset.get_satellite_name() == "rbspa" + + +def test_get_satellite_and_instrument_name(mock_dataset: RBMNcDataSet): + """Test get_satellite_and_instrument_name method.""" + assert mock_dataset.get_satellite_and_instrument_name() == "rbspa_mageis" + + +def test_get_print_name(mock_dataset: RBMNcDataSet): + """Test get_print_name method.""" + assert mock_dataset.get_print_name() == "rbspa mageis" + + +def test_satellite_string_input(mock_module_string): + """Test that satellite can be provided as string.""" + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset = RBMNcDataSet( + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), + satellite="RBSPA", + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + ) + + assert dataset._satellite == SatelliteEnum.RBSPA + + +def test_getattr_with_valid_variable(mock_dataset: RBMNcDataSet): + """Test __getattr__ with a valid variable.""" + with mock.patch.object(mock_dataset, "_load_variable") as _: + mock_dataset.Flux = np.array([1.0, 2.0, 3.0]) + result = mock_dataset.Flux + assert isinstance(result, np.ndarray) + assert (result == np.array([1.0, 2.0, 3.0])).all() + + +def test_getattr_with_invalid_variable(mock_dataset: RBMNcDataSet): + """Test __getattr__ with an invalid variable.""" + with pytest.raises(AttributeError): + _ = mock_dataset.NonExistentAttribute + + +def test_getattr_with_similar_variable(mock_dataset: RBMNcDataSet): + """Test __getattr__ suggests similar variable name.""" + with pytest.raises(AttributeError) as e: + _ = mock_dataset.Flx + + assert "Maybe you meant Flux?" in str(e.value) + + +def test_computed_invv_variable(mock_dataset: RBMNcDataSet): + """Test computed InvV variable.""" + + mock_dataset.InvK = np.array([[1.0, 2.0]]) + mock_dataset.InvMu = np.array([[0.1, 0.2], [0.3, 0.4]]) + + mock_dataset._load_variable(VariableEnum.INV_V) + + expected = ( + mock_dataset.InvMu + * (np.repeat(mock_dataset.InvK[:, np.newaxis, :], mock_dataset.InvMu.shape[1], axis=1) + 0.5) ** 2 + ) + np.testing.assert_array_equal(mock_dataset.InvV, expected) + + +def test_computed_p_variable(mock_dataset: RBMNcDataSet): + """Test computed P variable.""" + + mock_dataset.MLT = np.array([0.0, 6.0, 12.0, 18.0]) + + mock_dataset._load_variable(VariableEnum.P) + + expected = ((mock_dataset.MLT + 12) / 12 * np.pi) % (2 * np.pi) + np.testing.assert_array_equal(mock_dataset.P, expected) + + +@pytest.mark.parametrize("satellite", list(SatelliteEnum)) +def test_all_satellites_work(satellite, mock_module_string): + """Ensure all SatelliteEnum values initialize without error.""" + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset = RBMNcDataSet( + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), + satellite=satellite, + instrument=InstrumentEnum.HOPE, + mfm=MfmEnum.T89, + ) + assert dataset._satellite == satellite + + +@pytest.mark.parametrize("instrument", list(InstrumentEnum)) +def test_all_instruments_work(instrument, mock_module_string): + """Ensure all InstrumentEnum values initialize without error.""" + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset = RBMNcDataSet( + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), + satellite=SatelliteEnum.RBSPA, + instrument=instrument, + mfm=MfmEnum.T89, + ) + assert dataset._instrument == instrument + + +def test_create_date_list_monthly(mock_dataset: RBMNcDataSet): + """Test monthly cadence date generation.""" + mock_dataset.set_file_cadence(FileCadenceEnum.Monthly) + date_list = mock_dataset._create_date_list() + assert date_list[0].month == 1 + assert all(date.tzinfo == timezone.utc for date in date_list) + + +def test_create_date_list_daily(mock_dataset: RBMNcDataSet): + """Test daily cadence date generation.""" + mock_dataset.set_file_cadence(FileCadenceEnum.Daily) + date_list = mock_dataset._create_date_list() + assert len(date_list) > 20 + assert all(date.tzinfo == timezone.utc for date in date_list) + + +def test_file_name_stem_generation(mock_dataset: RBMNcDataSet): + """Test that file name stem is generated correctly.""" + assert mock_dataset._create_file_name_stem() == "rbspa_mageis_" + + +def test_file_path_stem_dataserver(mock_dataset: RBMNcDataSet): + """Test correct file path stem for DataServer folder type.""" + expected_path = Path("/mock/path/RBSP/rbspa/") + assert mock_dataset._create_file_path_stem() == expected_path + + +def test_invalid_cadence_raises(mock_dataset: RBMNcDataSet): + """Invalid cadence should raise ValueError.""" + mock_dataset._file_cadence = None + with pytest.raises(ValueError): + mock_dataset._create_date_list() + + +def test_invalid_folder_type_raises(mock_dataset: RBMNcDataSet): + """Invalid folder type should raise ValueError.""" + mock_dataset._folder_type = None + with pytest.raises(ValueError): + mock_dataset._create_file_path_stem() + + +def test_get_var_method(mock_dataset: RBMNcDataSet): + """Test get_var returns correct variable.""" + mock_dataset.Flux = np.array([4.0, 5.0]) + result = mock_dataset.get_var(VariableEnum.FLUX) + assert isinstance(result, np.ndarray) + assert (result == np.array([4.0, 5.0])).all() + + +def test_load_variable_real_file(): + start_time = dt.datetime(2025, 4, 1, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2025, 4, 30, tzinfo=dt.timezone.utc) + + dataset = RBMNcDataSet( + start_time=start_time, + end_time=end_time, + folder_path=Path("path/to/real/files"), # this does not matter for the test + satellite=SatelliteEnum.GOESSecondary, + instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + verbose=True, + ) + + dataset._load_variable(VariableEnum.ALPHA_LOCAL) + + assert hasattr(dataset, "alpha_local"), "Dataset should have 'alpha_local' attribute after loading." + assert isinstance(dataset.alpha_local, np.ndarray), "'alpha_local' should be a NumPy array." + + +def test_all_variables_in_dir(mock_dataset: RBMNcDataSet): + vars = [ + "datetime", + "time", + "energy_channels", + "alpha_local", + "alpha_eq_model", + "alpha_eq_real", + "InvMu", + "InvMu_real", + "InvK", + "InvV", + "Lstar", + "Flux", + "PSD", + "MLT", + "B_SM", + "B_total", + "B_sat", + "xGEO", + "P", + "R0", + "density", + ] + + for var in vars: + assert var in mock_dataset.__dir__() diff --git a/tests/io/RBMDataSet/test_utils.py b/tests/io/RBMDataSet/test_utils.py index a5d05628..8ba45aba 100644 --- a/tests/io/RBMDataSet/test_utils.py +++ b/tests/io/RBMDataSet/test_utils.py @@ -37,7 +37,7 @@ def test_get_file_path_any_format_multiple_match(tmp_path: Path): def test_get_file_path_any_format_no_match(tmp_path: Path): - (tmp_path / "nonexistent.tmp").touch() + (tmp_path / "nonexistent").touch() result = utils.get_file_path_any_format(tmp_path, "nonexistent", "pickle") assert result is None