diff --git a/src/mcstastox/ReadNeXus.py b/src/mcstastox/ReadNeXus.py index 621f922..9b928b6 100644 --- a/src/mcstastox/ReadNeXus.py +++ b/src/mcstastox/ReadNeXus.py @@ -3,7 +3,9 @@ import re from dataclasses import dataclass from types import MappingProxyType +from typing import cast +import h5py import numpy as np @@ -43,6 +45,35 @@ def _get_mcstas_version_settings( raise ValueError(f"McStas version {version} not supported by this tool.") +def _validate_file(file_handle: h5py.File) -> None: + # Check file is formatted as expected + if "entry1" not in list(file_handle.keys()): + raise ValueError("h5 file not formatted as expected, lacks 'entry1'.") + + entry_obj = cast(h5py.Group, file_handle["entry1"]) + entry_keys = tuple(entry_obj.keys()) + + mandatory_entry_keys = ("data", "simulation", "instrument") + missing_keys = tuple(key for key in mandatory_entry_keys if key not in entry_keys) + if any(missing_keys): + raise ValueError( + "'entry1' not formatted as expected, lacks keys: " + f"[{', '.join(missing_keys)}]." + ) from None + + simluation_keys = tuple(entry_obj["simulation"].keys()) + if "Param" not in simluation_keys: + raise ValueError( + "'entry1/simulation' not formatted as expected, lacks 'Param'." + ) + + instrument_keys = tuple(entry_obj["instrument"].keys()) + if "components" not in instrument_keys: + raise ValueError( + "'entry1/instrument' not formatted as expected, lacks 'components'." + ) + + class McStasNeXus: """ Reads a McStas NeXus files and provides methods to retrieve data or entries @@ -57,63 +88,43 @@ def __init__( mcstas_setting_registry: _McStasVersionSettingTp = _MCSTAS_VERSION_SETTINGS, ): self.file_handle = file_handle - f = self.file_handle - - self.mcstas_version = mcstas_version or self.read_mcstas_version() + # Check file is formatted as expected + _validate_file(self.file_handle) + self.mcstas_version = mcstas_version or self._read_mcstas_version() # Load settings appropriate for this McStas version self.settings = _get_mcstas_version_settings( self.mcstas_version, mcstas_setting_registry ) - # Check file is formatted as expected - if "entry1" not in list(f.keys()): - raise ValueError("h5 file not formatted as expected, lacks 'entry1'.") - - if "data" not in list(f["entry1"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'data'.") - - if "simulation" not in list(f["entry1"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'simulation'.") - - if "Param" not in list(f["entry1"]["simulation"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'Param'.") - - if "instrument" not in list(f["entry1"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'instrument'.") - - if "components" not in list(f["entry1"]["instrument"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'components'.") - # Grab basic information + self.component_names: list + self.component_path_names: dict + self._read_component_name_and_path() + + def _read_component_name_and_path(self) -> None: if self.settings.component_numbers is None: - self.component_names = f["entry1"]["instrument"]["components"].keys() + self.component_names = list( + self.file_handle["entry1"]["instrument"]["components"].keys() + ) self.component_path_names = {name: name for name in self.component_names} else: comp_name_start_index = self.settings.component_numbers + 1 - self.component_names = [] - self.component_path_names = {} - full_comp_names = f["entry1"]["instrument"]["components"].keys() - for name in full_comp_names: - component_name = name[comp_name_start_index:] - self.component_names.append(component_name) - self.component_path_names[component_name] = name - - def read_mcstas_version(self) -> tuple[int, int, int]: - f = self.file_handle - - # First attempt at reading version - if "entry1" not in list(f.keys()): - raise ValueError("h5 file not formatted as expected, lacks 'entry1'") - - if "simulation" not in list(f["entry1"].keys()): - raise ValueError( - "h5 file not formatted as expected, lacks 'entry1/simulation'" + components = self.file_handle["entry1"]["instrument"]["components"].keys() + component_paths = list(components) + self.component_names = [ + name[comp_name_start_index:] for name in component_paths + ] + self.component_path_names = dict( + zip(self.component_names, component_paths, strict=True) ) + def _read_mcstas_version(self) -> tuple[int, int, int]: + f = self.file_handle if "program" not in list(f["entry1"]["simulation"].attrs): raise ValueError( - "h5 file not formatted as expected, lacks 'entry1/simulation/program'" + "h5 file not formatted as expected, " + "lacks 'program' attribute in 'entry1/simulation/program'" ) version_string = f["entry1"]["simulation"].attrs["program"].decode("utf-8")