Skip to content

Typing #9

@swelborn

Description

@swelborn

@nikitinvv it would be very helpful if your package had type annotations. For example, when we specify args, we could give them an explicit type like so:

@dataclass
class StreamReconArgs:
    """Typed container for the options that ``StreamRecon`` receives as ``args``.

    This is an illustrative sketch: only a subset of the options is listed
    (see the placeholder comment at the end). Field order defines the
    positional order of the generated ``__init__``.

    NOTE: the ``X | Y`` annotations below are evaluated when the class is
    created, so they need Python 3.10+ (or ``from __future__ import
    annotations``).
    """

    n: int  # presumably detector width in pixels — confirm
    nproj: int  # presumably number of projections — confirm
    nz: int  # presumably number of detector rows / slices — confirm
    nflat: int  # presumably number of flat-field frames — confirm
    ndark: int  # presumably number of dark-field frames — confirm
    in_dtype: np.dtype | str  # NumPy dtype object or its name, e.g. "uint16"
    ngpus: int
    file_type: str
    nsino_per_chunk: int
    nproj_per_chunk: int
    dtype: np.dtype | str  # NumPy dtype object or its name, e.g. "float32"
    # more args here...


...
class StreamRecon:
    """Streaming reconstruction"""

    def __init__(self, args: StreamReconArgs):

You could do this with dataclasses, or add pydantic to validate the configuration. This would make it significantly easier to reason about what is going on when using streamtomocupy, and easier for other people to build on top of it. I started working on this, but it turned out to be a bigger project than I anticipated, because `args` is passed around to many places in the repo.

FYI, ChatGPT gave me the following when I asked it to write some pydantic models:

from enum import Enum
from typing import Literal

from pydantic import BaseModel


class FileType(str, Enum):
    """Layout of the input data, as supplied via ``args.file_type``.

    Subclasses ``str`` so that members compare equal to their raw values
    (``FileType.STANDARD == "standard"``) and can stand in for plain strings.
    """

    STANDARD = "standard"
    DOUBLE_FOV = "double_fov"  # presumably a double field-of-view scan — confirm

class ReconstructionAlgorithm(str, Enum):
    """Reconstruction backend, as supplied via ``args.reconstruction_algorithm``.

    Member values are the exact strings accepted on input; lookup by value is
    case-sensitive.
    """

    FOURIER = "fourierrec"
    LP = "lprec"
    LINE = "linerec"


class FBPFilter(str, Enum):
    """Filter choice for FBP reconstruction (the ``fbp_filter`` option).

    ``ReconstructionConfig`` defaults to ``FBPFilter.PARZEN``.
    """

    NONE = "none"
    RAMP = "ramp"
    SHEPP = "shepp"
    HANN = "hann"
    HAMMING = "hamming"
    PARZEN = "parzen"
    COSINE = "cosine"
    COSINE2 = "cosine2"


class PhaseMethod(str, Enum):
    """Phase-retrieval method (the ``retrieve_phase_method`` option).

    ``NONE`` is the default used by ``PhaseRetrievalConfig``.
    """

    NONE = "none"
    PAGANIN = "paganin"
    # Mixed-case value: enum lookup by value is case-sensitive, so the input
    # must be exactly "Gpaganin".
    GPAGANIN = "Gpaganin"


class StripeRemovalMethod(str, Enum):
    """Stripe-removal method (the ``remove_stripe_method`` option).

    The values ``fw``, ``ti`` and ``vo-all`` line up with the ``fw_``, ``ti_``
    and ``vo_all_`` prefixed parameter models below.
    """

    NONE = "none"
    FW = "fw"
    TI = "ti"
    VO_ALL = "vo-all"  # value uses a hyphen even though the member name uses "_"


class WaveletFilter(str, Enum):
    """Wavelet used by the Fourier-wavelet stripe filter (the ``fw_filter`` option).

    ``FourierWaveletConfig`` defaults to ``WaveletFilter.SYM16``.
    """

    HAAR = "haar"
    DB5 = "db5"
    SYM5 = "sym5"
    SYM16 = "sym16"


class ReconstructionConfig(BaseModel):
    """Core reconstruction options: dataset shape, algorithm, filtering, chunking.

    String inputs for the enum-typed fields (e.g. ``"lprec"``) are converted to
    the corresponding enum member by pydantic validation.
    """

    # Dataset shape. A default of 0 presumably means "not set yet" and is
    # expected to be filled from the data — confirm.
    nproj: int = 0
    nz: int = 0
    n: int = 0
    nflat: int = 0
    ndark: int = 0
    in_dtype: str = "uint16"  # dtype name of the raw input data
    file_type: FileType = FileType.STANDARD
    reconstruction_algorithm: ReconstructionAlgorithm = ReconstructionAlgorithm.FOURIER
    rotation_axis: float = -1.0  # NOTE(review): -1.0 looks like a "use default centre" sentinel — confirm
    dtype: Literal["float32", "float16"] = "float32"  # only these two names validate
    fbp_filter: FBPFilter = FBPFilter.PARZEN
    dezinger: int = 0  # presumably 0 disables dezingering — confirm
    dezinger_threshold: int = 5000
    minus_log: bool = True
    nsino_per_chunk: int = 8
    nproj_per_chunk: int = 8
    ngpus: int = 1


class PhaseRetrievalConfig(BaseModel):
    """Options for the phase-retrieval step (disabled by default).

    Float-typed fields use float literals as defaults. Pydantic does not
    validate defaults, so an int literal here would leave ``energy`` etc. as
    ``int`` on a default-constructed model, while explicitly passed values
    are coerced to ``float``.
    """

    retrieve_phase_method: PhaseMethod = PhaseMethod.NONE
    energy: float = 20.0  # NOTE(review): presumably keV — confirm
    propagation_distance: float = 100.0  # NOTE(review): presumably mm — confirm
    retrieve_phase_alpha: float = 0.001
    retrieve_phase_delta_beta: float = 1500.0
    retrieve_phase_W: float = 2e-4
    retrieve_phase_pad: int = 1
    pixel_size: float = 1.0  # NOTE(review): presumably microns — confirm


class StripeRemovalConfig(BaseModel):
    """Selects the stripe-removal method; ``NONE`` (the default) skips it.

    The per-method parameters appear to live in ``FourierWaveletConfig``
    (``fw``), ``TitarenkoConfig`` (``ti``) and ``VOAllConfig`` (``vo-all``),
    judging by their field prefixes.
    """

    remove_stripe_method: StripeRemovalMethod = StripeRemovalMethod.NONE


class FourierWaveletConfig(BaseModel):
    """Parameters for the Fourier-wavelet (``fw``) stripe-removal method.

    ``fw_sigma`` defaults to the float ``1.0`` so the unvalidated default has
    the same type as an explicitly passed, pydantic-coerced value.
    """

    fw_sigma: float = 1.0
    fw_filter: WaveletFilter = WaveletFilter.SYM16
    fw_level: int = 7
    fw_pad: bool = True


class VOAllConfig(BaseModel):
    """Parameters for the ``vo-all`` stripe-removal method.

    ``vo_all_snr`` defaults to the float ``3.0`` so the unvalidated default has
    the same type as an explicitly passed, pydantic-coerced value.
    """

    vo_all_snr: float = 3.0
    vo_all_la_size: int = 61  # presumably a filter window size — confirm
    vo_all_sm_size: int = 21  # presumably a filter window size — confirm
    vo_all_dim: int = 1


class TitarenkoConfig(BaseModel):
    """Parameters for the Titarenko (``ti``) stripe-removal method.

    ``ti_mask`` defaults to the float ``1.0`` so the unvalidated default has
    the same type as an explicitly passed, pydantic-coerced value.
    """

    ti_beta: float = 0.022
    ti_mask: float = 1.0


class FullConfig(BaseModel):
    """Complete configuration, grouped into one sub-model per processing stage.

    Every section defaults to its own model's defaults, so ``FullConfig()`` is
    a valid all-defaults configuration.
    """

    reconstruction: ReconstructionConfig = ReconstructionConfig()
    phase_retrieval: PhaseRetrievalConfig = PhaseRetrievalConfig()
    stripe_removal: StripeRemovalConfig = StripeRemovalConfig()
    fourier_wavelet: FourierWaveletConfig = FourierWaveletConfig()
    vo_all: VOAllConfig = VOAllConfig()
    titarenko: TitarenkoConfig = TitarenkoConfig()

    @classmethod
    def from_args(cls, args) -> "FullConfig":
        """Build a FullConfig from an attribute-style ``args`` object.

        Args:
            args: Any object exposing the option names used below as
                attributes (e.g. an ``argparse.Namespace`` or a dataclass such
                as ``StreamReconArgs``). The dataset-shape attributes
                ``nproj``, ``nz``, ``n``, ``nflat``, ``ndark`` and
                ``in_dtype`` are optional; when absent, the model defaults
                are kept. ``minus_log`` may be a ``bool`` or the strings
                ``"True"``/``"False"``. ``dtype``/``in_dtype`` may be NumPy
                dtype objects or dtype names.

        Returns:
            A validated FullConfig.

        Raises:
            AttributeError: If a required option attribute is missing on ``args``.
            pydantic.ValidationError: If a value fails field validation.
        """
        # Copy the dataset-shape fields whenever args carries them; otherwise
        # they would silently stay at their 0 / "uint16" defaults.
        shape = {
            name: getattr(args, name)
            for name in ("nproj", "nz", "n", "nflat", "ndark")
            if hasattr(args, name)
        }
        if hasattr(args, "in_dtype"):
            # str(np.dtype("uint16")) == "uint16"; plain strings pass unchanged.
            shape["in_dtype"] = str(args.in_dtype)

        minus_log = args.minus_log
        if not isinstance(minus_log, bool):
            # String-valued flag: only "True" enables it. A real bool is used
            # as-is, because `True == "True"` would wrongly disable it.
            minus_log = str(minus_log) == "True"

        return cls(
            reconstruction=ReconstructionConfig(
                **shape,
                file_type=args.file_type,
                reconstruction_algorithm=args.reconstruction_algorithm,
                rotation_axis=args.rotation_axis,
                # Normalize NumPy dtype objects to their name so the
                # Literal["float32", "float16"] check can match.
                dtype=str(args.dtype),
                fbp_filter=args.fbp_filter,
                dezinger=args.dezinger,
                dezinger_threshold=args.dezinger_threshold,
                minus_log=minus_log,
                nsino_per_chunk=args.nsino_per_chunk,
                nproj_per_chunk=args.nproj_per_chunk,
                ngpus=args.ngpus,
            ),
            phase_retrieval=PhaseRetrievalConfig(
                retrieve_phase_method=args.retrieve_phase_method,
                energy=args.energy,
                propagation_distance=args.propagation_distance,
                retrieve_phase_alpha=args.retrieve_phase_alpha,
                retrieve_phase_delta_beta=args.retrieve_phase_delta_beta,
                retrieve_phase_W=args.retrieve_phase_W,
                retrieve_phase_pad=args.retrieve_phase_pad,
                pixel_size=args.pixel_size,
            ),
            stripe_removal=StripeRemovalConfig(
                remove_stripe_method=args.remove_stripe_method,
            ),
            fourier_wavelet=FourierWaveletConfig(
                fw_sigma=args.fw_sigma,
                fw_filter=args.fw_filter,
                fw_level=args.fw_level,
                fw_pad=args.fw_pad,
            ),
            vo_all=VOAllConfig(
                vo_all_snr=args.vo_all_snr,
                vo_all_la_size=args.vo_all_la_size,
                vo_all_sm_size=args.vo_all_sm_size,
                vo_all_dim=args.vo_all_dim,
            ),
            titarenko=TitarenkoConfig(
                ti_beta=args.ti_beta,
                ti_mask=args.ti_mask,
            ),
        )


class FlatConfig(BaseModel):
    """Single-level view of every option in FullConfig's sections.

    Field names are the union of the section models' field names (they are
    disjoint), so code that expects a flat ``args``-like object can use it
    directly. Float-typed fields use float literals as defaults, because
    pydantic does not validate (and so does not coerce) defaults.
    """

    # Reconstruction
    nproj: int = 0
    nz: int = 0
    n: int = 0
    nflat: int = 0
    ndark: int = 0
    in_dtype: str = "uint16"
    file_type: FileType = FileType.STANDARD
    reconstruction_algorithm: ReconstructionAlgorithm = ReconstructionAlgorithm.FOURIER
    rotation_axis: float = -1.0
    dtype: Literal["float32", "float16"] = "float32"
    fbp_filter: FBPFilter = FBPFilter.PARZEN
    dezinger: int = 0
    dezinger_threshold: int = 5000
    minus_log: bool = True
    nsino_per_chunk: int = 8
    nproj_per_chunk: int = 8
    ngpus: int = 1

    # Phase Retrieval
    retrieve_phase_method: PhaseMethod = PhaseMethod.NONE
    energy: float = 20.0
    propagation_distance: float = 100.0
    retrieve_phase_alpha: float = 0.001
    retrieve_phase_delta_beta: float = 1500.0
    retrieve_phase_W: float = 2e-4
    retrieve_phase_pad: int = 1
    pixel_size: float = 1.0

    # Stripe Removal
    remove_stripe_method: StripeRemovalMethod = StripeRemovalMethod.NONE

    # Fourier Wavelet
    fw_sigma: float = 1.0
    fw_filter: WaveletFilter = WaveletFilter.SYM16
    fw_level: int = 7
    fw_pad: bool = True

    # VO All
    vo_all_snr: float = 3.0
    vo_all_la_size: int = 61
    vo_all_sm_size: int = 21
    vo_all_dim: int = 1

    # Titarenko
    ti_beta: float = 0.022
    ti_mask: float = 1.0

    @classmethod
    def from_full_config(cls, config: FullConfig) -> "FlatConfig":
        """Flatten a sectioned FullConfig into a FlatConfig.

        Every field of every section is copied, including the dataset-shape
        fields (``nproj``, ``nz``, ``n``, ``nflat``, ``ndark``, ``in_dtype``).

        Args:
            config: The sectioned configuration to flatten.

        Returns:
            A FlatConfig carrying the same values as ``config``.
        """
        # Merge generically instead of listing each field by hand, so fields
        # cannot be forgotten and new section fields flow through
        # automatically. dict(model) yields the field values in both
        # pydantic v1 and v2, and section field names never collide.
        merged = {}
        for section in (
            config.reconstruction,
            config.phase_retrieval,
            config.stripe_removal,
            config.fourier_wavelet,
            config.vo_all,
            config.titarenko,
        ):
            merged.update(dict(section))
        return cls(**merged)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions