Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ requests
wget
distance
tqdm
pytest-mock
pytest-mock
netcdf4
54 changes: 45 additions & 9 deletions swvo/io/RBMDataSet/RBMDataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import typing
from datetime import timedelta, timezone
from pathlib import Path
from typing import Any
from typing import Any, Literal

import distance
import numpy as np
Expand All @@ -19,7 +19,9 @@
FileCadenceEnum,
FolderTypeEnum,
InstrumentEnum,
InstrumentLike,
MfmEnum,
MfmLike,
SatelliteEnum,
SatelliteLike,
Variable,
Expand Down Expand Up @@ -92,9 +94,9 @@ def __init__(
end_time: dt.datetime,
folder_path: Path,
satellite: SatelliteLike,
instrument: InstrumentEnum,
mfm: MfmEnum,
preferred_extension: str = "pickle",
instrument: InstrumentLike,
mfm: MfmLike,
preferred_extension: Literal["mat", "pickle"] = "pickle",
*,
verbose: bool = True,
) -> None:
Expand All @@ -110,8 +112,14 @@ def __init__(
if isinstance(satellite, str):
satellite = SatelliteEnum[satellite.upper()]
self._satellite = satellite
if isinstance(instrument, str):
instrument = InstrumentEnum[instrument.upper()]
self._instrument = instrument

if mfm.upper() in MfmEnum.__members__:
mfm = MfmEnum[mfm.upper()]
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Case handling for MfmEnum is incorrect and can leave self._mfm as a str (e.g., 'T04s' becomes 'T04S' and misses members), which later breaks when accessing .value. Handle strings explicitly and avoid uppercasing for mixed-case names. Example: if isinstance(mfm, str): key = mfm if mfm in MfmEnum.members else mfm.upper(); if key in MfmEnum.members: mfm = MfmEnum[key]; elif not isinstance(mfm, MfmEnum): raise ValueError('Invalid mfm'); then assign self._mfm = mfm.

Suggested change
if mfm.upper() in MfmEnum.__members__:
mfm = MfmEnum[mfm.upper()]
if isinstance(mfm, str):
key = mfm if mfm in MfmEnum.__members__ else mfm.upper()
if key in MfmEnum.__members__:
mfm = MfmEnum[key]
else:
raise ValueError(f"Invalid mfm: {mfm}")
elif not isinstance(mfm, MfmEnum):
raise ValueError(f"Invalid mfm: {mfm}")

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DoctorRabbit55 can you check this?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this review sort of make sense. I tried it locally and if the mfm is a invalid string, it is still set to self._mfm. And even if the mfm is a valid str, value self._mfm is set to a str mfm istead of the corresponding MfmEnum, which I think is wrong as it will break the self._mfm.value

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this also fails for mfm: str = "T04s"

self._mfm = mfm

self._folder_path = Path(folder_path)

self._preferred_ext = preferred_extension
Expand All @@ -130,7 +138,7 @@ def __str__(self):
return self.__repr__()

def __dir__(self):
return super().__dir__() + [var.var_name for var in VariableEnum]
return list(super().__dir__()) + [var.var_name for var in VariableEnum]

def __getattr__(self, name: str):
# check if a sat variable is requested
Expand Down Expand Up @@ -252,10 +260,10 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:
next_month = start_month + relativedelta(months=1, days=-1)
date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d")

file_name_no_format = self._file_name_stem + date_str + "_" + var.data_server_file_prefix
file_name_no_format = self._file_name_stem + date_str + "_" + var.mat_file_prefix

if var.data_server_has_B:
file_name_no_format += "_" + self._mfm.value
if var.mat_has_B:
file_name_no_format += "_n4_4_" + self._mfm.value

file_name_no_format += "_ver4"
else:
Expand All @@ -264,7 +272,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:
full_file_path = get_file_path_any_format(self._file_path_stem, file_name_no_format, self._preferred_ext)

if full_file_path is None:
print(f"File not found {full_file_path}")
print(f"File not found: {self._file_path_stem}, {file_name_no_format}")
continue

if self._verbose:
Expand Down Expand Up @@ -318,5 +326,33 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:

setattr(self, var_name, loaded_var_arrs[var_name])

def __eq__(self, other: RBMDataSet) -> bool:
if self._satellite != other._satellite:
return False
if self._instrument != other._instrument:
return False
if self._mfm != other._mfm:
return False

for var in VariableEnum:
self_var = getattr(self, var.var_name)
other_var = getattr(other, var.var_name)

if isinstance(self_var, list) and isinstance(other_var, list):
if len(self_var) != len(other_var):
return False
for dt1, dt2 in zip(self_var, other_var):
if dt1 != dt2:
return False
elif isinstance(self_var, np.ndarray) and isinstance(other_var, np.ndarray):
if self_var.shape != other_var.shape:
return False
if not np.allclose(self_var, other_var, equal_nan=True):
return False
elif self_var != other_var:
return False

return True

from .bin_and_interpolate_to_model_grid import bin_and_interpolate_to_model_grid
from .interp_functions import interp_flux
236 changes: 236 additions & 0 deletions swvo/io/RBMDataSet/RBMNcDataSet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences
#
# SPDX-License-Identifier: Apache-2.0

import datetime as dt
import typing
from pathlib import Path
from typing import Any

import netCDF4
import numpy as np
from dateutil.relativedelta import relativedelta
from numpy.typing import NDArray

from swvo.io.RBMDataSet import (
FolderTypeEnum,
InstrumentLike,
MfmLike,
RBMDataSet,
SatelliteLike,
Variable,
VariableEnum,
Comment thread
sahiljhawar marked this conversation as resolved.
)
from swvo.io.RBMDataSet.custom_enums import MfmEnumLiteral, VariableLiteral
from swvo.io.RBMDataSet.utils import join_var


def _read_all_datasets_netcdf(file_path: str | Path) -> dict[str, Any]:
"""Reads all datasets (variables) from a NetCDF file, including those in groups.

This function recursively traverses all groups and variables in a NetCDF-4
file and stores their data in a dictionary. The key for each dataset is its
full hierarchical path.

Args:
file_path (str | Path): The path to the NetCDF file.

Returns:
Dict[str, Any]: A dictionary where keys are the full variable paths
and values are the corresponding NumPy arrays.
"""
datasets: dict[str, Any] = {}
file_path = Path(file_path)

def _read_all_recursively(group: netCDF4.Group | netCDF4.Dataset, path: str = ""):
for var_name, var_obj in group.variables.items():
full_path = f"{path}/{var_name}" if path else var_name
datasets[full_path] = var_obj[:]

for group_name, group_obj in group.groups.items():
new_path = f"{path}/{group_name}" if path else group_name
_read_all_recursively(group_obj, new_path)

if not file_path.exists():
print(f"File not found: {file_path}")
return {}
Comment thread
sahiljhawar marked this conversation as resolved.

with netCDF4.Dataset(file_path, "r") as nc_file:
_read_all_recursively(nc_file)

return datasets


class RBMNcDataSet(RBMDataSet):
"""Class for handling RBM NetCDF data files."""

datetime: list[dt.datetime]
time: NDArray[np.float64]
energy_channels: NDArray[np.float64]
alpha_local: NDArray[np.float64]
alpha_eq_model: NDArray[np.float64]
alpha_eq_real: NDArray[np.float64]
InvMu: NDArray[np.float64]
InvMu_real: NDArray[np.float64]
InvK: NDArray[np.float64]
InvV: NDArray[np.float64]
Lstar: NDArray[np.float64]
Flux: NDArray[np.float64]
PSD: NDArray[np.float64]
MLT: NDArray[np.float64]
B_SM: NDArray[np.float64]
B_total: NDArray[np.float64]
B_sat: NDArray[np.float64]
xGEO: NDArray[np.float64]
P: NDArray[np.float64]
R0: NDArray[np.float64]
density: NDArray[np.float64]

def __init__(
self,
start_time: dt.datetime,
end_time: dt.datetime,
folder_path: Path,
satellite: SatelliteLike,
instrument: InstrumentLike,
mfm: MfmLike,
*,
verbose: bool = True,
) -> None:
super().__init__(
start_time,
end_time,
folder_path,
satellite,
instrument,
mfm=mfm,
verbose=verbose,
)

def _create_file_path_stem(self) -> Path:
# implement special cases here
# if self._satellite == SatelliteEnum.THEMIS:
# pass
if self._folder_type == FolderTypeEnum.DataServer:
return self._folder_path / self._satellite.mission / self._satellite.sat_name

if self._folder_type == FolderTypeEnum.SingleFolder:
return self._folder_path

msg = "Encountered invalid FolderTypeEnum!"
raise ValueError(msg)

def _load_variable(self, var: Variable | VariableEnum) -> None:
loaded_var_arrs: dict[str, NDArray[np.number]] = {}
var_names_storred: list[str] = []
Comment thread
sahiljhawar marked this conversation as resolved.
Outdated

# computed values
if isinstance(var, VariableEnum) and var == VariableEnum.INV_V:
inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1)

self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2
return

if isinstance(var, VariableEnum) and var == VariableEnum.P:
self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi)
return

for date in self._date_of_files:
if self._folder_type == FolderTypeEnum.DataServer:
start_month = date.replace(day=1)
next_month = start_month + relativedelta(months=1, days=-1)
date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d")

file_name = self._file_name_stem + date_str + "_" + self._mfm.value + ".nc"
else:
raise NotImplementedError

datasets = _read_all_datasets_netcdf(self._file_path_stem / file_name)

if datasets == {}:
continue

# also store python datetimes for binning
datetimes = typing.cast(
NDArray[np.object_],
np.asarray(
[dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in datasets["time"]]
),
) # type: ignore
datasets["datetime"] = datetimes

# limit in time
correct_time_idx = (datetimes >= self._start_time) & (datetimes <= self._end_time)

for key, var_arr in datasets.items():
if ((not isinstance(var_arr, np.ndarray)) or (not np.issubdtype(var_arr.dtype, np.number))) and (
key != "datetime"
):
# var represents some strings or metadata objects; don't read them
continue
var_arr = typing.cast("NDArray[np.number]", var_arr)

# check if var is time dependent
if var_arr.shape[0] == correct_time_idx.shape[0]:
var_arr_trimmed = var_arr[correct_time_idx.reshape(-1), ...]

joined_value = (
join_var(loaded_var_arrs[key], var_arr_trimmed) if key in loaded_var_arrs else var_arr_trimmed
)
else:
joined_value = var_arr

loaded_var_arrs[key] = joined_value
Comment thread
sahiljhawar marked this conversation as resolved.

if key not in var_names_storred:
var_names_storred.append(key)
Comment thread
sahiljhawar marked this conversation as resolved.
Outdated

# not a single file was found
if var.var_name not in var_names_storred:
setattr(self, var.var_name, np.asarray([]))

for var_name in var_names_storred:
if var_name == "datetime":
loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # type: ignore

rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.value)

setattr(self, rbm_var_name, loaded_var_arrs[var_name])

@classmethod
def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral:
Comment thread
sahiljhawar marked this conversation as resolved.
Outdated
match var_name:
case "time":
return "time"
case "datetime":
return "datetime"
case "flux/FEDU":
return "Flux"
case "flux/alpha_eq":
return "alpha_eq_model"
case "flux/energy":
return "energy_channels"
case "flux/alpha_local":
return "alpha_local"
case "position/xGEO":
return "xGEO"
case _ if var_name == f"position/{mag_field}/MLT":
return "MLT"
case _ if var_name == f"position/{mag_field}/R0":
return "R0"
case _ if var_name == f"position/{mag_field}/Lstar":
return "Lstar"
case _ if var_name == f"position/{mag_field}/Lm":
return "Lm"
case _ if var_name == f"mag_field/{mag_field}/B_local":
return "B_total"
case "psd/PSD":
return "PSD"
case _ if var_name == f"psd/{mag_field}/inv_mu":
return "InvMu"
case _ if var_name == f"psd/{mag_field}/inv_K":
return "InvK"
case "density/density_local":
return "density"
case _:
return "_invalid_name"
Comment thread
sahiljhawar marked this conversation as resolved.
Outdated
3 changes: 3 additions & 0 deletions swvo/io/RBMDataSet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
SatelliteLike as SatelliteLike,
SatelliteEnum as SatelliteEnum,
InstrumentEnum as InstrumentEnum,
InstrumentLike as InstrumentLike,
MfmEnum as MfmEnum,
MfmLike as MfmLike,
ElPasoMFMEnum as ElPasoMFMEnum,
SatelliteLiteral as SatelliteLiteral,
)
from swvo.io.RBMDataSet.RBMDataSetManager import RBMDataSetManager as RBMDataSetManager
from swvo.io.RBMDataSet.interp_functions import TargetType as TargetType
from swvo.io.RBMDataSet.scripts.create_RBSP_line_data import create_RBSP_line_data as create_RBSP_line_data
from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet as RBMDataSet
from swvo.io.RBMDataSet.RBMNcDataSet import RBMNcDataSet as RBMNcDataSet
from swvo.io.RBMDataSet.RBMDataSetElPaso import RBMDataSetElPaso as RBMDataSetElPaso
Loading