From b64aa82672acf23e575c709505d0135efd2f8bae Mon Sep 17 00:00:00 2001 From: Eoghan O'Connell Date: Fri, 24 Jan 2025 15:09:56 +0100 Subject: [PATCH] docs: add type hints for utils module; ref: make 2d funcs private --- qpretrieve/filter.py | 10 ++++++---- qpretrieve/utils.py | 24 +++++++++++++++++++----- tests/test_utils.py | 6 +++--- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/qpretrieve/filter.py b/qpretrieve/filter.py index d14c0eb..2fc7fe6 100644 --- a/qpretrieve/filter.py +++ b/qpretrieve/filter.py @@ -3,7 +3,6 @@ import numpy as np from scipy import signal - available_filters = [ "disk", "smooth disk", @@ -15,7 +14,10 @@ @lru_cache(maxsize=32) -def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): +def get_filter_array( + filter_name: str, filter_size: float, + freq_pos: tuple[float, float], + fft_shape: tuple[int, int]) -> np.ndarray: """Create a Fourier filter for holography Parameters @@ -55,7 +57,7 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): raise ValueError("The Fourier transformed data must have a squared " + f"shape, but the input shape is '{fft_shape}'! " + "Please pad your data properly before FFT.") - if not (0 < filter_size < max(fft_shape)/2): + if not (0 < filter_size < max(fft_shape) / 2): raise ValueError("The filter size cannot exceed more than half of " + "the Fourier space or be negative. Got a filter " + f"size of '{filter_size}' and a shape of " @@ -63,7 +65,7 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape): if not (0 <= min(np.abs(freq_pos)) <= max(np.abs(freq_pos)) - < max(fft_shape)/2): + < max(fft_shape) / 2): raise ValueError("The frequency position must be within the Fourier " + f"domain. Got '{freq_pos}' and shape " + f"'{fft_shape}'!") diff --git a/qpretrieve/utils.py b/qpretrieve/utils.py index e8142e5..44f0bef 100644 --- a/qpretrieve/utils.py +++ b/qpretrieve/utils.py @@ -1,20 +1,22 @@ import numpy as np -def mean_2d(data): +def _mean_2d(data): + """Exists for testing against mean_3d""" data -= data.mean() return data -def mean_3d(data): - # calculate mean of the images along the z-axis. +def mean_3d(data: np.ndarray) -> np.ndarray: + """Calculate mean of the data along the z-axis.""" # The mean array here is (1000,), so we need to add newaxes for subtraction # (1000, 5, 5) -= (1000, 1, 1) data -= data.mean(axis=(-2, -1))[:, np.newaxis, np.newaxis] return data -def padding_2d(data, order, dtype): +def _padding_2d(data, order, dtype): + """Exists for testing against padding_3d""" # this is faster than np.pad datapad = np.zeros((order, order), dtype=dtype) # we could of course use np.atleast_3d here @@ -22,7 +24,19 @@ def padding_2d(data, order, dtype): return datapad -def padding_3d(data, order, dtype): +def padding_3d(data: np.ndarray, order: int, dtype: np.dtype) -> np.ndarray: + """Calculate padding of the data along the z-axis. + + Parameters + ---------- + data + 3d array. The padding will be applied to the axes (y,x) only. + order + The data will be padded to this size. + dtype + data type of the padded array. + + """ z, y, x = data.shape # this is faster than np.pad datapad = np.zeros((z, order, order), dtype=dtype) diff --git a/tests/test_utils.py b/tests/test_utils.py index c7be6d5..d5ccdab 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ import numpy as np -from qpretrieve.utils import padding_2d, padding_3d, mean_2d, mean_3d +from qpretrieve.utils import _padding_2d, padding_3d, _mean_2d, mean_3d def test_mean_subtraction(): @@ -8,7 +8,7 @@ def test_mean_subtraction(): ind = 5 data_2d = data_3d.copy()[ind] - data_2d = mean_2d(data_2d) + data_2d = _mean_2d(data_2d) data_3d = mean_3d(data_3d) assert np.array_equal(data_3d[ind], data_2d) @@ -38,7 +38,7 @@ def test_batch_padding(): order = 512 dtype = float - data_2d_padded = padding_2d(data_2d, order, dtype) + data_2d_padded = _padding_2d(data_2d, order, dtype) data_3d_padded = padding_3d(data_3d, order, dtype) assert np.array_equal(data_3d_padded[ind], data_2d_padded)