Skip to content

Commit

Permalink
enh: interfere base allows str input for getting original data array …
Browse files Browse the repository at this point in the history
…layout
  • Loading branch information
Eoghan O'Connell committed Jan 22, 2025
1 parent fc91620 commit d3ba3fb
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions qpretrieve/interfere/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import warnings
from abc import ABC, abstractmethod

import numpy as np

from ..fourier import get_best_interface, get_available_interfaces
from ..fourier.base import FFTFilter
from ..data_input import check_data_input_format, revert_to_data_input_format
from ..data_array_layout import (
convert_data_to_3d_array_layout, convert_3d_data_to_array_layout
)


class BadFFTFilterError(ValueError):
Expand Down Expand Up @@ -80,7 +83,7 @@ def __init__(self, data, fft_interface: str | FFTFilter = "auto",
f"available interface.")

# figure out what type of data we have, change it to 3d-stack
data, self.orig_data_fmt = check_data_input_format(data)
data, self.orig_array_layout = convert_data_to_3d_array_layout(data)

#: qpretrieve Fourier transform interface class
self.fft = self.ff_iface(data=data,
Expand All @@ -98,8 +101,16 @@ def __init__(self, data, fft_interface: str | FFTFilter = "auto",
self._phase = None
self._amplitude = None

def get_orig_data_fmt(self, data_attr):
return revert_to_data_input_format(self.orig_data_fmt, data_attr)
def get_array_with_input_layout(self, data):
if isinstance(data, str):
if data == "fft":
data = "fft_filtered"
warnings.warn(
"You have asked for 'fft' which is a class. "
"Returning 'fft_filtered'. "
"Alternatively you could use 'fft_origin'.")
data = getattr(self, data)
return convert_3d_data_to_array_layout(data, self.orig_array_layout)

@property
def phase(self):
Expand Down

0 comments on commit d3ba3fb

Please sign in to comment.