diff --git a/src/nxmx/__init__.py b/src/nxmx/__init__.py index 5fc185a..515f3f0 100644 --- a/src/nxmx/__init__.py +++ b/src/nxmx/__init__.py @@ -12,14 +12,15 @@ import dateutil.parser import h5py import numpy as np +import numpy.typing as npt import pint from scipy.spatial.transform import Rotation # NeXus field type for type annotations # https://manual.nexusformat.org/nxdl-types.html#nxdl-field-types-and-units -NXBoolT: TypeAlias = bool | np.ndarray -NXFloatT: TypeAlias = float | np.ndarray -NXIntT: TypeAlias = int | np.ndarray +NXBoolT: TypeAlias = np.bool_ | npt.NDArray[np.bool_] +NXFloatT: TypeAlias = np.floating | npt.NDArray[np.floating] +NXIntT: TypeAlias = np.integer | npt.NDArray[np.integer] NXNumberT: TypeAlias = NXFloatT | NXIntT @@ -46,20 +47,52 @@ def __len__(self): return len(self._handle) -def h5scalar(ds: h5py.Dataset): - """Read a scalar value from an HDF5 dataset +def h5scalar(ds: h5py.Dataset) -> np.generic: + """ + Read a scalar value from an HDF5 dataset - Sometimes scalars are stored as a length-1 1D dataset instead of a proper - scalar. This function allows for that, since NumPy has got stricter about - converting to scalars. + Raises: ValueError, if the dataset does not contain a scalar. """ - if ds.size != 1: - raise ValueError("only length-1 arrays can be converted to Python scalars") - arr = np.squeeze(ds[()]) - return arr.item() + value = h5_maybe_scalar(ds) + if not isinstance(value, np.generic): + raise ValueError("Cannot be converted to numpy scalar") + return value + + +def h5_maybe_scalar(ds: h5py.Dataset) -> np.generic | npt.NDArray[np.generic]: + """ + Coerce an HDF5 dataset to a NumPy scalar, if appropriate. + + Sometimes scalars are stored as a length-1 1D dataset instead of a + scalar, this converts those instances to a scalar. + + Only changes when: + - It is already a 0-D array + - It is a 1-D array of length 1 + + Otherwise, the array is returned as-is, with dtype preserved. + """ + data = ds[()] + + if isinstance(data, np.ndarray): + if data.ndim == 0: + # 0-D array → NumPy scalar + return data[()] + elif data.ndim == 1 and data.size == 1: + # 1-D length-1 array → NumPy scalar + return data[0] + else: + # multi-element array → leave as-is + return data + else: + if isinstance(data, np.generic): + # already a NumPy scalar → leave as-is + return data + # Explicitly promote this into a numpy scalar + return np.asarray(data).flat[0] -def h5str(h5_value: str | np.bytes_ | bytes | None) -> str | None: +def h5str(h5_value: str | np.bytes_ | np.str_ | bytes | None) -> str | None: """ Convert a value returned from an h5py attribute to str. @@ -69,6 +102,9 @@ def h5str(h5_value: str | np.bytes_ | bytes | None) -> str | None: """ if isinstance(h5_value, np.bytes_ | bytes): return h5_value.decode("utf-8") + if isinstance(h5_value, np.str_): + return str(h5_value) + assert isinstance(h5_value, str) or h5_value is None return h5_value @@ -246,7 +282,7 @@ def signal(self) -> str | None: return self._handle.attrs.get("signal") @cached_property - def data_scale_factor(self) -> str | None: + def data_scale_factor(self) -> NXNumberT | None: """ An optional scaling factor to apply to the values in ``data``. @@ -278,17 +314,17 @@ def data_scale_factor(self) -> str | None: When omitted, the scaling factor is assumed to be 1. """ if "data_scale_factor" in self._handle: - return h5scalar(self._handle["data_scale_factor"]) + return h5_maybe_scalar(self._handle["data_scale_factor"]) @cached_property - def data_offset(self) -> str | None: + def data_offset(self) -> NXNumberT | None: """ An optional offset to apply to the values in data. When omitted, the offset is assumed to be 0. """ if "data_offset" in self._handle: - return h5scalar(self._handle["data_offset"]) + return h5_maybe_scalar(self._handle["data_offset"]) class NXtransformations(H5Mapping): @@ -575,7 +611,7 @@ def depends_on(self) -> NXtransformationsAxis | None: def temperature(self) -> pint.Quantity | None: """The temperature of the sample.""" if temperature := self._handle.get("temperature"): - return h5scalar(temperature) * units(temperature) + return h5_maybe_scalar(temperature) * units(temperature) return None @cached_property