Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 242 additions & 1 deletion osipy/dce/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

Expand All @@ -39,6 +40,8 @@
from osipy.dce.models.binding import BoundDCEModel
from osipy.dce.models.registry import get_model

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from collections.abc import Callable

Expand Down Expand Up @@ -75,6 +78,45 @@ class DCEFitResult:
fitting_stats: dict[str, Any] = field(default_factory=dict)


@dataclass
class ModelSelectionResult:
"""Result container for model selection comparison.

Holds the per-model fitting results and information-criterion
score maps so the caller can inspect which model won at each
voxel and why.

Attributes
----------
best_model_map : NDArray[np.intp]
Integer map where each voxel value is the index into
``model_names`` of the winning model. Winner is determined
by lowest AIC/BIC or highest R-squared, depending on
``criterion``.
model_names : list[str]
Ordered list of successfully compared model names.
criterion : str
Criterion used for comparison
(``'aic'``, ``'bic'``, or ``'r_squared'``).
aic_maps : dict[str, NDArray[np.floating]]
Per-model AIC score maps.
bic_maps : dict[str, NDArray[np.floating]]
Per-model BIC score maps.
r_squared_maps : dict[str, NDArray[np.floating]]
Per-model R-squared maps.
fit_results : dict[str, DCEFitResult]
Full fitting results for each model.
"""

best_model_map: NDArray[np.intp]
model_names: list[str]
criterion: str
aic_maps: dict[str, NDArray[np.floating[Any]]] = field(default_factory=dict)
bic_maps: dict[str, NDArray[np.floating[Any]]] = field(default_factory=dict)
r_squared_maps: dict[str, NDArray[np.floating[Any]]] = field(default_factory=dict)
fit_results: dict[str, DCEFitResult] = field(default_factory=dict)


def fit_model(
model_name: str,
concentration: NDArray[np.floating[Any]],
Expand Down Expand Up @@ -156,6 +198,201 @@ def fit_model(
)


def compare_models(
model_names: list[str],
concentration: NDArray[np.floating[Any]],
aif: ArterialInputFunction | NDArray[np.floating[Any]],
time: NDArray[np.floating[Any]],
mask: NDArray[np.bool_] | None = None,
criterion: str = "r_squared",
fitter: BaseFitter | str | None = None,
bounds_override: dict[str, tuple[float, float]] | None = None,
) -> ModelSelectionResult:
"""Fit multiple models and select the best one per voxel.

Fits each model independently using ``fit_model()``, computes
AIC, BIC, and R-squared at every voxel, and returns a map
indicating which model won at each spatial location.

Parameters
----------
model_names : list[str]
Models to compare, e.g. ``['tofts', 'extended_tofts', 'patlak']``.
Must contain at least two unique names.
concentration : NDArray[np.floating]
Concentration data, shape ``(x, y, z, t)`` or ``(x, y, t)``
or ``(n_voxels, t)``.
aif : ArterialInputFunction or NDArray[np.floating]
Arterial input function. Can be an ArterialInputFunction
object or a 1D array of concentration values.
time : NDArray[np.floating]
Time points in seconds.
mask : NDArray[np.bool_] | None
Optional mask of voxels to fit. If None, fits all voxels.
criterion : str
Selection criterion: ``'aic'``, ``'bic'``, or ``'r_squared'``.
Default ``'r_squared'``. Note that ``'r_squared'`` does not
penalize model complexity and may favor overfitting.
fitter : BaseFitter | str | None
Fitter instance or registry name (e.g., ``'lm'``).
Uses LevenbergMarquardtFitter by default.
bounds_override : dict[str, tuple[float, float]] | None
Optional per-parameter bound overrides.

Returns
-------
ModelSelectionResult
Comparison results including best-model map, AIC/BIC/R-squared
score maps, and the full ``DCEFitResult`` for each model.

Raises
------
FittingError
If fewer than two models are provided, criterion is invalid,
duplicate model names are given, or fewer than two models
fit successfully.
DataValidationError
If a model name is not recognized in the registry.

Notes
-----
AIC and BIC are computed from the residual sum of squares
assuming Gaussian errors:

AIC = n * ln(SS_res / n) + 2 * k
BIC = n * ln(SS_res / n) + k * ln(n)

where *n* is the number of time points and *k* is the number
of model parameters. Lower values indicate a better fit.
BIC penalizes complexity more heavily than AIC, especially
for larger *n*.

Examples
--------
>>> import numpy as np
>>> from osipy.dce.fitting import compare_models
>>> t = np.linspace(0, 300, 60)
>>> aif = np.exp(-t / 30)
>>> conc = np.random.rand(10, 10, 5, 60) * 0.01
>>> result = compare_models(
... ["tofts", "extended_tofts"],
... conc, aif, t,
... criterion="bic",
... )
>>> result.best_model_map.shape
(10, 10, 5)
"""
# --- Input validation ---------------------------------------------------
if len(model_names) < 2:
msg = "compare_models requires at least two model names."
raise FittingError(msg)

if len(model_names) != len(set(model_names)):
msg = "Duplicate model names are not allowed."
raise FittingError(msg)

criterion = criterion.lower()
if criterion not in ("aic", "bic", "r_squared"):
msg = f"criterion must be 'aic', 'bic', or 'r_squared', got '{criterion}'"
raise FittingError(msg)

n_time = len(time)
spatial_shape = concentration.shape[:-1]

# --- Pre-compute SS_tot once (shared across all models) -----------------
ct = np.asarray(concentration)
ct_mean = np.mean(ct, axis=-1, keepdims=True)
ss_tot = np.sum((ct - ct_mean) ** 2, axis=-1)

# Guard against constant-signal voxels (SS_tot = 0)
ss_tot_safe = np.where(ss_tot > 1e-30, ss_tot, 1e-30)

# --- Fit each model and compute AIC/BIC ---------------------------------
fit_results: dict[str, DCEFitResult] = {}
aic_maps: dict[str, np.ndarray] = {}
bic_maps: dict[str, np.ndarray] = {}
r_squared_maps: dict[str, np.ndarray] = {}
succeeded: list[str] = []

for name in model_names:
logger.info("Fitting model '%s' for model comparison", name)
try:
result = fit_model(
name,
concentration,
aif,
time,
mask=mask,
fitter=fitter,
bounds_override=bounds_override,
)
except Exception:
logger.warning(
"Model '%s' failed to fit; skipping it in comparison.",
name,
exc_info=True,
)
continue

fit_results[name] = result
succeeded.append(name)

# Number of free parameters for this model
n_params = len(result.parameter_maps)

# Get R-squared map (may be None if computation failed)
r2 = result.r_squared_map
if r2 is None:
r2 = np.zeros(spatial_shape)

# Clamp R-squared to [0, 1] to prevent negative SS_res
r2 = np.clip(r2, 0.0, 1.0)
r_squared_maps[name] = r2

# SS_res = (1 - R-squared) * SS_tot
ss_res = (1.0 - r2) * ss_tot_safe

# Avoid log(0) by clamping
ss_res_safe = np.where(ss_res > 1e-30, ss_res, 1e-30)

# AIC = n * ln(SS_res / n) + 2 * k
aic = n_time * np.log(ss_res_safe / n_time) + 2.0 * n_params
# BIC = n * ln(SS_res / n) + k * ln(n)
bic = n_time * np.log(ss_res_safe / n_time) + n_params * np.log(n_time)

aic_maps[name] = aic
bic_maps[name] = bic

# --- Post-fit validation ------------------------------------------------
if len(succeeded) < 2:
msg = (
f"Only {len(succeeded)} model(s) fitted successfully "
f"({succeeded}); need at least 2 for comparison."
)
raise FittingError(msg)

# --- Build best-model map -----------------------------------------------
if criterion == "r_squared":
# Higher R-squared = better fit
stacked = np.stack([r_squared_maps[n] for n in succeeded], axis=0)
best_model_map = np.argmax(stacked, axis=0).astype(np.intp)
else:
# Lower AIC/BIC = better fit
score_maps = aic_maps if criterion == "aic" else bic_maps
stacked = np.stack([score_maps[n] for n in succeeded], axis=0)
best_model_map = np.argmin(stacked, axis=0).astype(np.intp)

return ModelSelectionResult(
best_model_map=best_model_map,
model_names=list(succeeded),
criterion=criterion,
aic_maps=aic_maps,
bic_maps=bic_maps,
r_squared_maps=r_squared_maps,
fit_results=fit_results,
)


def _fit_model_impl(
model: Any,
concentration: NDArray[np.floating[Any]],
Expand Down Expand Up @@ -420,7 +657,11 @@ def _compute_r_squared_vectorized(
r_squared[quality_mask] = r2_values

except Exception:
pass
logger.warning(
"R-squared computation failed; returning zero map. "
"Parameter maps are still valid.",
exc_info=True,
)

return r_squared

Expand Down
Loading