Skip to content

Commit

Permalink
docs: add type hints for fourier module
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 24, 2025
1 parent fe5e62a commit 5aa0885
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 13 deletions.
5 changes: 3 additions & 2 deletions qpretrieve/fourier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa: F401
import warnings

from .base import FFTFilter
from .ff_numpy import FFTFilterNumpy

try:
Expand All @@ -11,7 +12,7 @@
PREFERRED_INTERFACE = None


def get_available_interfaces():
def get_available_interfaces() -> list:
"""Return a list of available FFT algorithms"""
interfaces = [
FFTFilterPyFFTW,
Expand All @@ -24,7 +25,7 @@ def get_available_interfaces():
return interfaces_available


def get_best_interface():
def get_best_interface() -> FFTFilter:
"""Return the fastest refocusing interface available
If `pyfftw` is installed, :class:`.FFTFilterPyFFTW`
Expand Down
8 changes: 4 additions & 4 deletions qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self,
data: np.ndarray,
subtract_mean: bool = True,
padding: int = 2,
copy: bool = True):
copy: bool = True) -> None:
r"""
Parameters
----------
Expand Down Expand Up @@ -135,13 +135,13 @@ def __init__(self,
self.fft_used = None

@property
def shape(self):
def shape(self) -> tuple:
"""Shape of the Fourier transform data"""
return self.fft_origin.shape

@property
@abstractmethod
def is_available(self):
def is_available(self) -> bool:
"""Whether this method is available given current hardware/software"""
return True

Expand Down Expand Up @@ -169,7 +169,7 @@ def _init_fft(self, data):

def filter(self, filter_name: str, filter_size: float,
freq_pos: (float, float),
scale_to_filter: bool | float = False):
scale_to_filter: bool | float = False) -> np.ndarray:
"""
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions qpretrieve/fourier/ff_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class FFTFilterNumpy(FFTFilter):
# always available, because numpy is a dependency
is_available = True

def _init_fft(self, data):
def _init_fft(self, data: np.ndarray) -> np.ndarray:
"""Perform initial Fourier transform of the input data
Parameters
Expand All @@ -25,6 +25,6 @@ def _init_fft(self, data):
"""
return np.fft.fft2(data, axes=(-2, -1))

def _ifft(self, data):
def _ifft(self, data: np.ndarray) -> np.ndarray:
"""Perform inverse Fourier transform"""
return np.fft.ifft2(data, axes=(-2, -1))
11 changes: 6 additions & 5 deletions qpretrieve/fourier/ff_pyfftw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import multiprocessing as mp
import numpy as np

import pyfftw

Expand All @@ -11,7 +12,7 @@ class FFTFilterPyFFTW(FFTFilter):
# always available, because numpy is a dependency
is_available = True

def _init_fft(self, data):
def _init_fft(self, data: np.ndarray) -> np.ndarray:
"""Perform initial Fourier transform of the input data
Parameters
Expand All @@ -33,13 +34,13 @@ def _init_fft(self, data):
fft_obj()
return out_arr

def _ifft(self, data):
def _ifft(self, data: np.ndarray) -> np.ndarray:
"""Perform inverse Fourier transform"""
in_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
ou_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
fft_obj = pyfftw.FFTW(in_arr, ou_arr, axes=(-2, -1),
out_arr = pyfftw.empty_aligned(data.shape, dtype='complex128')
fft_obj = pyfftw.FFTW(in_arr, out_arr, axes=(-2, -1),
direction="FFTW_BACKWARD",
)
in_arr[:] = data
fft_obj()
return ou_arr
return out_arr

0 comments on commit 5aa0885

Please sign in to comment.