Skip to content

Commit

Permalink
docs: add type hints for array layout module
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 24, 2025
1 parent 560054c commit 880591a
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions qpretrieve/data_array_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@ def get_allowed_array_layouts() -> list:
]


def convert_data_to_3d_array_layout(data):
"""Convert the data to the 3d array_layout."""
def convert_data_to_3d_array_layout(
data: np.ndarray) -> tuple[np.ndarray, str]:
"""Convert the data to the 3d array_layout
Returns
-------
data
3d version of the data
array_layout
original array layout for future reference
"""
if len(data.shape) == 3:
if data.shape[-1] in [1, 2, 3]:
# take the first slice (we have alpha or RGB information)
Expand All @@ -41,9 +51,15 @@ def convert_data_to_3d_array_layout(data):
return data.copy(), array_layout


def convert_3d_data_to_array_layout(data, array_layout):
def convert_3d_data_to_array_layout(
data: np.ndarray, array_layout: str) -> np.ndarray:
"""Convert the 3d data to the desired `array_layout`.
Returns
-------
data_out
input `data` with the given `array layout`
Notes
-----
Currently, this function is limited to converting from 3d to other
Expand All @@ -68,7 +84,7 @@ def convert_3d_data_to_array_layout(data, array_layout):
return data


def _convert_rgb_to_3d(data_input):
def _convert_rgb_to_3d(data_input: np.ndarray) -> tuple[np.ndarray, str]:
data = data_input[:, :, 0]
data = data[np.newaxis, :, :]
array_layout = "rgb"
Expand All @@ -77,29 +93,29 @@ def _convert_rgb_to_3d(data_input):
return data, array_layout


def _convert_rgba_to_3d(data_input):
def _convert_rgba_to_3d(data_input: np.ndarray) -> tuple[np.ndarray, str]:
data, _ = _convert_rgb_to_3d(data_input)
array_layout = "rgba"
return data, array_layout


def _convert_2d_to_3d(data_input):
def _convert_2d_to_3d(data_input: np.ndarray) -> tuple[np.ndarray, str]:
data = data_input[np.newaxis, :, :]
array_layout = "2d"
return data, array_layout


def _convert_3d_to_rgb(data_input):
def _convert_3d_to_rgb(data_input: np.ndarray) -> np.ndarray:
data = data_input[0]
data = np.dstack((data, data, data))
return data


def _convert_3d_to_rgba(data_input):
def _convert_3d_to_rgba(data_input: np.ndarray) -> np.ndarray:
data = data_input[0]
data = np.dstack((data, data, data, np.ones_like(data)))
return data


def _convert_3d_to_2d(data_input):
def _convert_3d_to_2d(data_input: np.ndarray) -> np.ndarray:
return data_input[0]

0 comments on commit 880591a

Please sign in to comment.