From e3af00d99385fba6298e765122e858e9315b5f75 Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Wed, 21 May 2025 16:41:09 +0200 Subject: [PATCH 1/3] Extract file object validator and file object initializing logic. --- src/mcstastox/ReadNeXus.py | 81 ++++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/src/mcstastox/ReadNeXus.py b/src/mcstastox/ReadNeXus.py index 621f922..7a46a11 100644 --- a/src/mcstastox/ReadNeXus.py +++ b/src/mcstastox/ReadNeXus.py @@ -3,7 +3,9 @@ import re from dataclasses import dataclass from types import MappingProxyType +from typing import cast +import h5py import numpy as np @@ -43,6 +45,35 @@ def _get_mcstas_version_settings( raise ValueError(f"McStas version {version} not supported by this tool.") +def _validate_file(file_handle: h5py.File) -> None: + # Check file is formatted as expected + if "entry1" not in list(file_handle.keys()): + raise ValueError("h5 file not formatted as expected, lacks 'entry1'.") + + entry_obj = cast(h5py.Group, file_handle["entry1"]) + entry_keys = tuple(entry_obj.keys()) + + mandatory_entry_keys = ("data", "simulation", "instrument") + missing_keys = tuple(key for key in mandatory_entry_keys if key not in entry_keys) + if any(missing_keys): + raise ValueError( + "'entry1' not formatted as expected, lacks keys: " + f"[{', '.join(missing_keys)}]." + ) from None + + simluation_keys = tuple(entry_obj["simulation"].keys()) + if "Param" not in simluation_keys: + raise ValueError( + "'entry1/simulation' not formatted as expected, lacks 'Param'." + ) + + instrument_keys = tuple(entry_obj["instrument"].keys()) + if "components" not in instrument_keys: + raise ValueError( + "'entry1/instrument' not formatted as expected, lacks 'components'." + ) + + class McStasNeXus: """ Reads a McStas NeXus files and provides methods to retrieve data or entries @@ -57,9 +88,7 @@ def __init__( mcstas_setting_registry: _McStasVersionSettingTp = _MCSTAS_VERSION_SETTINGS, ): self.file_handle = file_handle - f = self.file_handle - - self.mcstas_version = mcstas_version or self.read_mcstas_version() + self.mcstas_version = mcstas_version or self._read_mcstas_version() # Load settings appropriate for this McStas version self.settings = _get_mcstas_version_settings( @@ -67,39 +96,31 @@ def __init__( ) # Check file is formatted as expected - if "entry1" not in list(f.keys()): - raise ValueError("h5 file not formatted as expected, lacks 'entry1'.") - - if "data" not in list(f["entry1"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'data'.") - - if "simulation" not in list(f["entry1"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'simulation'.") - - if "Param" not in list(f["entry1"]["simulation"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'Param'.") - - if "instrument" not in list(f["entry1"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'instrument'.") - - if "components" not in list(f["entry1"]["instrument"].keys()): - raise ValueError("h5 file not formatted as expected, lacks 'components'.") + _validate_file(self.file_handle) # Grab basic information + self.component_names: list + self.component_path_names: dict + self._read_component_name_and_path() + + def _read_component_name_and_path(self) -> None: if self.settings.component_numbers is None: - self.component_names = f["entry1"]["instrument"]["components"].keys() + self.component_names = list( + self.file_handle["entry1"]["instrument"]["components"].keys() + ) self.component_path_names = {name: name for name in self.component_names} else: comp_name_start_index = self.settings.component_numbers + 1 - self.component_names = [] - self.component_path_names = {} - full_comp_names = f["entry1"]["instrument"]["components"].keys() - for name in full_comp_names: - component_name = name[comp_name_start_index:] - self.component_names.append(component_name) - self.component_path_names[component_name] = name - - def read_mcstas_version(self) -> tuple[int, int, int]: + components = self.file_handle["entry1"]["instrument"]["components"].keys() + component_paths = list(components) + self.component_names = [ + name[comp_name_start_index:] for name in component_paths + ] + self.component_path_names = dict( + zip(self.component_names, component_paths, strict=True) + ) + + def _read_mcstas_version(self)-> tuple[int, int, int]: f = self.file_handle # First attempt at reading version From 95e2219a3684b9de81e5f73fd6287b6cf7cfa10e Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Wed, 21 May 2025 16:56:17 +0200 Subject: [PATCH 2/3] Formatting. --- src/mcstastox/ReadNeXus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcstastox/ReadNeXus.py b/src/mcstastox/ReadNeXus.py index 7a46a11..279381d 100644 --- a/src/mcstastox/ReadNeXus.py +++ b/src/mcstastox/ReadNeXus.py @@ -120,7 +120,7 @@ def _read_component_name_and_path(self) -> None: zip(self.component_names, component_paths, strict=True) ) - def _read_mcstas_version(self)-> tuple[int, int, int]: + def _read_mcstas_version(self) -> tuple[int, int, int]: f = self.file_handle # First attempt at reading version From 333f81671543c81aecf0a5b5415a0c4c328654f3 Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Wed, 21 May 2025 16:59:57 +0200 Subject: [PATCH 3/3] Remove repeating validation step. --- src/mcstastox/ReadNeXus.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/mcstastox/ReadNeXus.py b/src/mcstastox/ReadNeXus.py index 279381d..9b928b6 100644 --- a/src/mcstastox/ReadNeXus.py +++ b/src/mcstastox/ReadNeXus.py @@ -88,6 +88,8 @@ def __init__( mcstas_setting_registry: _McStasVersionSettingTp = _MCSTAS_VERSION_SETTINGS, ): self.file_handle = file_handle + # Check file is formatted as expected + _validate_file(self.file_handle) self.mcstas_version = mcstas_version or self._read_mcstas_version() # Load settings appropriate for this McStas version @@ -95,9 +97,6 @@ def __init__( self.mcstas_version, mcstas_setting_registry ) - # Check file is formatted as expected - _validate_file(self.file_handle) - # Grab basic information self.component_names: list self.component_path_names: dict @@ -122,19 +121,10 @@ def _read_component_name_and_path(self) -> None: def _read_mcstas_version(self) -> tuple[int, int, int]: f = self.file_handle - - # First attempt at reading version - if "entry1" not in list(f.keys()): - raise ValueError("h5 file not formatted as expected, lacks 'entry1'") - - if "simulation" not in list(f["entry1"].keys()): - raise ValueError( - "h5 file not formatted as expected, lacks 'entry1/simulation'" - ) - if "program" not in list(f["entry1"]["simulation"].attrs): raise ValueError( - "h5 file not formatted as expected, lacks 'entry1/simulation/program'" + "h5 file not formatted as expected, " + "lacks 'program' attribute in 'entry1/simulation/program'" ) version_string = f["entry1"]["simulation"].attrs["program"].decode("utf-8")