diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 91de5e7..c6c7a01 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -228,6 +228,15 @@ def __getattr__(self, name: str) -> NDArray[np.float64]: raise AttributeError(msg) + def load(self, name_or_var: str | VariableEnum) -> None: + """ Load data into memory """ + + if isinstance(name_or_var, VariableEnum): + getattr(self, name_or_var.var_name) + else: + getattr(self, name_or_var) + + def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[str, Any]]: levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} sat_variable = None diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py index a172ab4..d4862cf 100644 --- a/swvo/io/RBMDataSet/RBMNcDataSet.py +++ b/swvo/io/RBMDataSet/RBMNcDataSet.py @@ -110,6 +110,28 @@ def __init__( verbose=verbose, ) + mfm_str = mfm if isinstance(mfm, str) else mfm.mfm_name + + self.variable_lut = { + "time": "time", + "datetime": "datetime", + "flux/FEDU": "Flux", + "flux/alpha_eq": "alpha_eq_model", + "flux/energy": "energy_channels", + "flux/alpha_local": "alpha_local", + "position/xGEO": "xGEO", + "psd/PSD": "PSD", + "density/density_local": "density", + + f"position/{mfm_str}/MLT": "MLT", + f"position/{mfm_str}/R0": "R0", + f"position/{mfm_str}/Lstar": "Lstar", + f"position/{mfm_str}/Lm": "Lm", + f"mag_field/{mfm_str}/B_local": "B_total", + f"psd/{mfm_str}/inv_mu": "InvMu", + f"psd/{mfm_str}/inv_K": "InvK", + } + def _create_file_path_stem(self) -> Path: # implement special cases here # if self._satellite == SatelliteEnum.THEMIS: @@ -196,20 +218,23 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: if var_name == "datetime": loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # ty:ignore[invalid-assignment] - rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type] + rbm_var_names = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type] - if rbm_var_name is not None: - setattr(self, rbm_var_name, loaded_var_arrs[var_name]) + if rbm_var_names is not None: + for name in rbm_var_names: + setattr(self, name, loaded_var_arrs[var_name]) @classmethod - def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral | None: + def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral | None | list[VariableLiteral]: match var_name: case "time": return "time" case "datetime": return "datetime" case "flux/FEDU": - return "Flux" + return ["Flux", "FEDU"] + case "flux/FEIU": + return ["Flux", "FEIU"] case "flux/alpha_eq": return "alpha_eq_model" case "flux/energy": diff --git a/swvo/io/RBMDataSet/custom_enums.py b/swvo/io/RBMDataSet/custom_enums.py index a701106..c42bfdc 100644 --- a/swvo/io/RBMDataSet/custom_enums.py +++ b/swvo/io/RBMDataSet/custom_enums.py @@ -57,6 +57,10 @@ class VariableEnum(Variable, Enum): P = "P", "mlt", with_B R_0 = "R0", "R0", with_B DENSITY = "density", "density", without_B + # NC only variables + FEDU = "FEDU", "", without_B + FEIU = "FEIU", "", without_B + Lm = "Lm", "", without_B VariableLiteral = Literal[ @@ -72,6 +76,9 @@ class VariableEnum(Variable, Enum): "InvV", "Lstar", "Flux", + "FEDU", + "FEIU", + "Lm", "PSD", "MLT", "B_SM", diff --git a/swvo/io/RBMDataSet/utils.py b/swvo/io/RBMDataSet/utils.py index 645fab0..4fb8585 100644 --- a/swvo/io/RBMDataSet/utils.py +++ b/swvo/io/RBMDataSet/utils.py @@ -31,8 +31,13 @@ def join_var(var1: NDArray[np.generic], var2: NDArray[np.generic]) -> NDArray[np def get_file_path_any_format(folder_path: Path, file_stem: str, preferred_ext: str) -> Path | None: """Get the file path for a given file stem and preferred extension.""" pattern = re.compile(fnmatch.translate(file_stem + ".*"), re.IGNORECASE) + + if not folder_path.exists(): + return None + all_files = [p for p in folder_path.iterdir() if pattern.match(p.name)] + if len(all_files) == 0: warnings.warn(f"File not found: {folder_path / (file_stem + '.*')}", stacklevel=2) return None diff --git a/tests/io/RBMDataSet/test_RBMNcDataset.py b/tests/io/RBMDataSet/test_RBMNcDataset.py index 8850f48..c01190f 100644 --- a/tests/io/RBMDataSet/test_RBMNcDataset.py +++ b/tests/io/RBMDataSet/test_RBMNcDataset.py @@ -260,7 +260,7 @@ def test_load_variable_real_file(): assert hasattr(dataset, "alpha_local"), "Dataset should have 'alpha_local' attribute after loading." assert isinstance(dataset.alpha_local, np.ndarray), "'alpha_local' should be a NumPy array." - + assert hasattr(dataset, "FEDU") def test_all_variables_in_dir(mock_dataset: RBMNcDataSet): vars = [