@nikitinvv It would be very helpful if your package were typed. For example, when we specify `args`, we could add a type like so:
@dataclass
class StreamReconArgs:
n: int
nproj: int
nz: int
nflat: int
ndark: int
in_dtype: np.dtype | str
ngpus: int
file_type: str
nsino_per_chunk: int
nproj_per_chunk: int
dtype: np.dtype | str
# more args here...
...
class StreamRecon:
    """Streaming reconstruction"""

    def __init__(self, args: StreamReconArgs):
        """Create the reconstructor from a typed ``StreamReconArgs`` bundle.

        NOTE(review): the constructor body is elided in this snippet; only the
        typed signature (``args: StreamReconArgs``) is being illustrated.
        """
You could do this with dataclasses, or add pydantic to validate the configuration. This would make it significantly easier to reason about what is going on when using streamtomocupy, and for other people to build on top of it. I started working on this, but it ended up being a bigger project than I anticipated, because `args` is passed around to many places in the repo.
FYI, ChatGPT gave me the following when I asked it to write some pydantic models:
from enum import Enum
from typing import Literal
from pydantic import BaseModel
class FileType(str, Enum):
    """Input file layout; the values are the strings accepted for ``file_type``.

    Mixing in ``str`` lets members compare equal to their plain string values.
    """

    STANDARD = "standard"
    DOUBLE_FOV = "double_fov"
class ReconstructionAlgorithm(str, Enum):
    """Reconstruction algorithm names accepted for ``reconstruction_algorithm``."""

    FOURIER = "fourierrec"
    LP = "lprec"
    LINE = "linerec"
class FBPFilter(str, Enum):
    """Filter names accepted for ``fbp_filter``; ``"none"`` disables filtering."""

    NONE = "none"
    RAMP = "ramp"
    SHEPP = "shepp"
    HANN = "hann"
    HAMMING = "hamming"
    PARZEN = "parzen"
    COSINE = "cosine"
    COSINE2 = "cosine2"
class PhaseMethod(str, Enum):
    """Method names accepted for ``retrieve_phase_method``.

    Values are matched case-sensitively; note the capital ``G`` in ``"Gpaganin"``.
    """

    NONE = "none"
    PAGANIN = "paganin"
    GPAGANIN = "Gpaganin"
class StripeRemovalMethod(str, Enum):
    """Method names accepted for ``remove_stripe_method``.

    The value for ``VO_ALL`` uses a hyphen (``"vo-all"``), unlike the member name.
    """

    NONE = "none"
    FW = "fw"
    TI = "ti"
    VO_ALL = "vo-all"
class WaveletFilter(str, Enum):
    """Wavelet names accepted for ``fw_filter`` (Fourier-wavelet stripe removal)."""

    HAAR = "haar"
    DB5 = "db5"
    SYM5 = "sym5"
    SYM16 = "sym16"
class ReconstructionConfig(BaseModel):
    """Core reconstruction settings, including the size/dtype of the input data."""

    # Input data description. The sizes default to 0
    # (NOTE(review): presumably "unknown until the data is read" — confirm).
    nproj: int = 0
    nz: int = 0
    n: int = 0
    nflat: int = 0
    ndark: int = 0
    in_dtype: str = "uint16"

    # Algorithm selection.
    file_type: FileType = FileType.STANDARD
    reconstruction_algorithm: ReconstructionAlgorithm = ReconstructionAlgorithm.FOURIER
    # NOTE(review): -1.0 looks like a sentinel value (e.g. "auto") — confirm with consumer.
    rotation_axis: float = -1.0
    dtype: Literal["float32", "float16"] = "float32"
    fbp_filter: FBPFilter = FBPFilter.PARZEN

    # Pre-processing switches.
    dezinger: int = 0
    dezinger_threshold: int = 5000
    minus_log: bool = True

    # Chunking and device count.
    nsino_per_chunk: int = 8
    nproj_per_chunk: int = 8
    ngpus: int = 1
class PhaseRetrievalConfig(BaseModel):
    """Phase-retrieval settings.

    Defaults for ``float`` fields are written as float literals. Pydantic does
    not validate (and so does not coerce) default values, so an int literal
    would leave a default instance holding an ``int`` where user-supplied
    values come out as ``float``.
    """

    retrieve_phase_method: PhaseMethod = PhaseMethod.NONE
    # NOTE(review): units are not stated here (presumably keV / mm / microns) — confirm.
    energy: float = 20.0
    propagation_distance: float = 100.0
    retrieve_phase_alpha: float = 0.001
    retrieve_phase_delta_beta: float = 1500.0
    retrieve_phase_W: float = 2e-4
    retrieve_phase_pad: int = 1
    pixel_size: float = 1.0
class StripeRemovalConfig(BaseModel):
    """Selects the stripe-removal method; ``"none"`` (the default) disables it.

    Per-method parameters live in their own models
    (Fourier-wavelet, vo-all, Titarenko).
    """

    remove_stripe_method: StripeRemovalMethod = StripeRemovalMethod.NONE
class FourierWaveletConfig(BaseModel):
    """Parameters for Fourier-wavelet (``"fw"``) stripe removal.

    ``fw_sigma`` defaults to a float literal so a default instance holds a
    ``float``, matching what pydantic produces for user-supplied values
    (defaults themselves are not validated/coerced).
    """

    fw_sigma: float = 1.0
    fw_filter: WaveletFilter = WaveletFilter.SYM16
    fw_level: int = 7
    fw_pad: bool = True
class VOAllConfig(BaseModel):
    """Parameters for the ``"vo-all"`` stripe-removal method.

    ``vo_all_snr`` defaults to a float literal so a default instance holds a
    ``float``, matching what pydantic produces for user-supplied values
    (defaults themselves are not validated/coerced).
    """

    vo_all_snr: float = 3.0
    vo_all_la_size: int = 61
    vo_all_sm_size: int = 21
    vo_all_dim: int = 1
class TitarenkoConfig(BaseModel):
    """Parameters for Titarenko (``"ti"``) stripe removal.

    ``ti_mask`` defaults to a float literal so a default instance holds a
    ``float``, matching what pydantic produces for user-supplied values
    (defaults themselves are not validated/coerced).
    """

    ti_beta: float = 0.022
    ti_mask: float = 1.0
class FullConfig(BaseModel):
    """Complete configuration, grouped into one sub-model per processing stage.

    Every group defaults to that sub-model's own defaults.
    """

    # NOTE(review): the defaults are model instances; this relies on pydantic
    # giving each FullConfig its own copy. ``Field(default_factory=...)`` would
    # make that explicit.
    reconstruction: ReconstructionConfig = ReconstructionConfig()
    phase_retrieval: PhaseRetrievalConfig = PhaseRetrievalConfig()
    stripe_removal: StripeRemovalConfig = StripeRemovalConfig()
    fourier_wavelet: FourierWaveletConfig = FourierWaveletConfig()
    vo_all: VOAllConfig = VOAllConfig()
    titarenko: TitarenkoConfig = TitarenkoConfig()

    @classmethod
    def from_args(cls, args) -> "FullConfig":
        """Build a FullConfig from a flat ``args`` object.

        Args:
            args: any object exposing the option names used below as
                attributes (e.g. an ``argparse.Namespace``).

        Returns:
            A new FullConfig whose groups are populated from ``args``.

        NOTE(review): ``nproj``, ``nz``, ``n``, ``nflat``, ``ndark`` and
        ``in_dtype`` are not read from ``args``, so ``reconstruction`` keeps
        their defaults; set them afterwards if they are known.
        """
        # A string-valued option is compared against "True". A real bool is
        # passed through unchanged: ``True == "True"`` is False, so comparing
        # it would silently turn ``minus_log=True`` into False.
        minus_log = args.minus_log
        if not isinstance(minus_log, bool):
            minus_log = minus_log == "True"

        return cls(
            reconstruction=ReconstructionConfig(
                file_type=args.file_type,
                reconstruction_algorithm=args.reconstruction_algorithm,
                rotation_axis=args.rotation_axis,
                dtype=args.dtype,
                fbp_filter=args.fbp_filter,
                dezinger=args.dezinger,
                dezinger_threshold=args.dezinger_threshold,
                minus_log=minus_log,
                nsino_per_chunk=args.nsino_per_chunk,
                nproj_per_chunk=args.nproj_per_chunk,
                ngpus=args.ngpus,
            ),
            phase_retrieval=PhaseRetrievalConfig(
                retrieve_phase_method=args.retrieve_phase_method,
                energy=args.energy,
                propagation_distance=args.propagation_distance,
                retrieve_phase_alpha=args.retrieve_phase_alpha,
                retrieve_phase_delta_beta=args.retrieve_phase_delta_beta,
                retrieve_phase_W=args.retrieve_phase_W,
                retrieve_phase_pad=args.retrieve_phase_pad,
                pixel_size=args.pixel_size,
            ),
            stripe_removal=StripeRemovalConfig(
                remove_stripe_method=args.remove_stripe_method,
            ),
            fourier_wavelet=FourierWaveletConfig(
                fw_sigma=args.fw_sigma,
                fw_filter=args.fw_filter,
                fw_level=args.fw_level,
                fw_pad=args.fw_pad,
            ),
            vo_all=VOAllConfig(
                vo_all_snr=args.vo_all_snr,
                vo_all_la_size=args.vo_all_la_size,
                vo_all_sm_size=args.vo_all_sm_size,
                vo_all_dim=args.vo_all_dim,
            ),
            titarenko=TitarenkoConfig(
                ti_beta=args.ti_beta,
                ti_mask=args.ti_mask,
            ),
        )
class FlatConfig(BaseModel):
    """All ``FullConfig`` settings flattened into a single model.

    Field names match the sub-config field names one-to-one, so settings are
    read as ``cfg.fw_sigma`` rather than ``cfg.fourier_wavelet.fw_sigma``.
    ``float`` fields use float-literal defaults because pydantic does not
    validate/coerce defaults.
    """

    # Reconstruction
    nproj: int = 0
    nz: int = 0
    n: int = 0
    nflat: int = 0
    ndark: int = 0
    in_dtype: str = "uint16"
    file_type: FileType = FileType.STANDARD
    reconstruction_algorithm: ReconstructionAlgorithm = ReconstructionAlgorithm.FOURIER
    rotation_axis: float = -1.0
    dtype: Literal["float32", "float16"] = "float32"
    fbp_filter: FBPFilter = FBPFilter.PARZEN
    dezinger: int = 0
    dezinger_threshold: int = 5000
    minus_log: bool = True
    nsino_per_chunk: int = 8
    nproj_per_chunk: int = 8
    ngpus: int = 1
    # Phase Retrieval
    retrieve_phase_method: PhaseMethod = PhaseMethod.NONE
    energy: float = 20.0
    propagation_distance: float = 100.0
    retrieve_phase_alpha: float = 0.001
    retrieve_phase_delta_beta: float = 1500.0
    retrieve_phase_W: float = 2e-4
    retrieve_phase_pad: int = 1
    pixel_size: float = 1.0
    # Stripe Removal
    remove_stripe_method: StripeRemovalMethod = StripeRemovalMethod.NONE
    # Fourier Wavelet
    fw_sigma: float = 1.0
    fw_filter: WaveletFilter = WaveletFilter.SYM16
    fw_level: int = 7
    fw_pad: bool = True
    # VO All
    vo_all_snr: float = 3.0
    vo_all_la_size: int = 61
    vo_all_sm_size: int = 21
    vo_all_dim: int = 1
    # Titarenko
    ti_beta: float = 0.022
    ti_mask: float = 1.0

    @classmethod
    def from_full_config(cls, config: FullConfig) -> "FlatConfig":
        """Flatten a grouped ``FullConfig`` into a ``FlatConfig``.

        Every field of every sub-config is copied, including the data-size
        and input-dtype fields of ``config.reconstruction``. Previously these
        were dropped and silently reset to their defaults.

        Args:
            config: the grouped configuration to flatten.

        Returns:
            A new FlatConfig carrying the same values.
        """
        rec = config.reconstruction
        phase = config.phase_retrieval
        fw = config.fourier_wavelet
        vo = config.vo_all
        ti = config.titarenko
        return cls(
            # Reconstruction
            nproj=rec.nproj,
            nz=rec.nz,
            n=rec.n,
            nflat=rec.nflat,
            ndark=rec.ndark,
            in_dtype=rec.in_dtype,
            file_type=rec.file_type,
            reconstruction_algorithm=rec.reconstruction_algorithm,
            rotation_axis=rec.rotation_axis,
            dtype=rec.dtype,
            fbp_filter=rec.fbp_filter,
            dezinger=rec.dezinger,
            dezinger_threshold=rec.dezinger_threshold,
            minus_log=rec.minus_log,
            nsino_per_chunk=rec.nsino_per_chunk,
            nproj_per_chunk=rec.nproj_per_chunk,
            ngpus=rec.ngpus,
            # Phase Retrieval
            retrieve_phase_method=phase.retrieve_phase_method,
            energy=phase.energy,
            propagation_distance=phase.propagation_distance,
            retrieve_phase_alpha=phase.retrieve_phase_alpha,
            retrieve_phase_delta_beta=phase.retrieve_phase_delta_beta,
            retrieve_phase_W=phase.retrieve_phase_W,
            retrieve_phase_pad=phase.retrieve_phase_pad,
            pixel_size=phase.pixel_size,
            # Stripe Removal
            remove_stripe_method=config.stripe_removal.remove_stripe_method,
            # Fourier Wavelet
            fw_sigma=fw.fw_sigma,
            fw_filter=fw.fw_filter,
            fw_level=fw.fw_level,
            fw_pad=fw.fw_pad,
            # VO All
            vo_all_snr=vo.vo_all_snr,
            vo_all_la_size=vo.vo_all_la_size,
            vo_all_sm_size=vo.vo_all_sm_size,
            vo_all_dim=vo.vo_all_dim,
            # Titarenko
            ti_beta=ti.ti_beta,
            ti_mask=ti.ti_mask,
        )
@nikitinvv It would be very helpful if your package were typed. For example, when we specify
`args`, we could add a type like so (see the `StreamReconArgs` example above). You could do this with dataclasses, or add
pydantic to validate the configuration. This would make it significantly easier to reason about what is going on when using streamtomocupy, and for other people to build on top of it. I started working on this, but it ended up being a bigger project than I anticipated, because `args` is passed around to many places in the repo. FYI, ChatGPT gave me the following when I asked it to write some pydantic models: