Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 54 additions & 18 deletions src/nxmx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import dateutil.parser
import h5py
import numpy as np
import numpy.typing as npt
import pint
from scipy.spatial.transform import Rotation

# NeXus field type for type annotations
# https://manual.nexusformat.org/nxdl-types.html#nxdl-field-types-and-units
NXBoolT: TypeAlias = bool | np.ndarray
NXFloatT: TypeAlias = float | np.ndarray
NXIntT: TypeAlias = int | np.ndarray
NXBoolT: TypeAlias = np.bool_ | npt.NDArray[np.bool_]
NXFloatT: TypeAlias = np.floating | npt.NDArray[np.floating]
NXIntT: TypeAlias = np.integer | npt.NDArray[np.integer]
NXNumberT: TypeAlias = NXFloatT | NXIntT


Expand All @@ -46,20 +47,52 @@ def __len__(self):
return len(self._handle)


def h5scalar(ds: h5py.Dataset):
"""Read a scalar value from an HDF5 dataset
def h5scalar(ds: h5py.Dataset) -> np.generic:
"""
Read a scalar value from an HDF5 dataset

Sometimes scalars are stored as a length-1 1D dataset instead of a proper
scalar. This function allows for that, since NumPy has got stricter about
converting to scalars.
Raises: ValueError, if the dataset does not contain a scalar.
"""
if ds.size != 1:
raise ValueError("only length-1 arrays can be converted to Python scalars")
arr = np.squeeze(ds[()])
return arr.item()
value = h5_maybe_scalar(ds)
if not isinstance(value, np.generic):
raise ValueError("Cannot be converted to numpy scalar")
return value


def h5_maybe_scalar(ds: h5py.Dataset) -> np.generic | npt.NDArray[np.generic]:
"""
Coerce an HDF5 dataset to a NumPy scalar, if appropriate.

Sometimes scalars are stored as a length-1 1D dataset instead of a
scalar, this converts those instances to a scalar.

Only changes when:
- It is already a 0-D array
- It is a 1-D array of length 1

Otherwise, the array is returned as-is, with dtype preserved.
"""
data = ds[()]

if isinstance(data, np.ndarray):
if data.ndim == 0:
# 0-D array → NumPy scalar
return data[()]
elif data.ndim == 1 and data.size == 1:
# 1-D length-1 array → NumPy scalar
return data[0]
else:
# multi-element array → leave as-is
return data
else:
if isinstance(data, np.generic):
# already a NumPy scalar → leave as-is
return data
# Explicitly promote this into a numpy scalar
return np.asarray(data).flat[0]


def h5str(h5_value: str | np.bytes_ | bytes | None) -> str | None:
def h5str(h5_value: str | np.bytes_ | np.str_ | bytes | None) -> str | None:
"""
Convert a value returned from an h5py attribute to str.

Expand All @@ -69,6 +102,9 @@ def h5str(h5_value: str | np.bytes_ | bytes | None) -> str | None:
"""
if isinstance(h5_value, np.bytes_ | bytes):
return h5_value.decode("utf-8")
if isinstance(h5_value, np.str_):
return str(h5_value)
assert isinstance(h5_value, str) or h5_value is None
return h5_value


Expand Down Expand Up @@ -246,7 +282,7 @@ def signal(self) -> str | None:
return self._handle.attrs.get("signal")

@cached_property
def data_scale_factor(self) -> str | None:
def data_scale_factor(self) -> NXNumberT | None:
"""
An optional scaling factor to apply to the values in ``data``.

Expand Down Expand Up @@ -278,17 +314,17 @@ def data_scale_factor(self) -> str | None:
When omitted, the scaling factor is assumed to be 1.
"""
if "data_scale_factor" in self._handle:
return h5scalar(self._handle["data_scale_factor"])
return h5_maybe_scalar(self._handle["data_scale_factor"])

@cached_property
def data_offset(self) -> str | None:
def data_offset(self) -> NXNumberT | None:
"""
An optional offset to apply to the values in data.

When omitted, the offset is assumed to be 0.
"""
if "data_offset" in self._handle:
return h5scalar(self._handle["data_offset"])
return h5_maybe_scalar(self._handle["data_offset"])


class NXtransformations(H5Mapping):
Expand Down Expand Up @@ -575,7 +611,7 @@ def depends_on(self) -> NXtransformationsAxis | None:
def temperature(self) -> pint.Quantity | None:
"""The temperature of the sample."""
if temperature := self._handle.get("temperature"):
return h5scalar(temperature) * units(temperature)
return h5_maybe_scalar(temperature) * units(temperature)
return None

@cached_property
Expand Down
Loading