Skip to content

Commit

Permalink
Generalize base solver output classes, fix some mypy errors, refactor…
Browse files Browse the repository at this point in the history
…ing typing
  • Loading branch information
selmanozleyen committed Jan 7, 2025
1 parent 9c38829 commit b55a0e4
Show file tree
Hide file tree
Showing 16 changed files with 66 additions and 97 deletions.
2 changes: 0 additions & 2 deletions docs/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ Backends
backends.ott.GWSolver
backends.ott.OTTOutput
backends.ott.GraphOTTOutput
backends.ott.GENOTLinSolver
backends.ott.output.OTTNeuralOutput
backends.utils.get_solver
backends.utils.get_available_backends

Expand Down
11 changes: 5 additions & 6 deletions src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
from typing import Any, Literal, Mapping, Optional, Sequence, Union

import numpy as np
from jax import Array as JaxArray
from numpy.typing import DTypeLike as DTypeLikeNumpy
from numpy.typing import NDArray
from ott.initializers.linear.initializers import SinkhornInitializer
from ott.initializers.linear.initializers_lr import LRInitializer
from ott.initializers.quadratic.initializers import BaseQuadraticInitializer

# TODO(michalk8): polish

try:
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.floating]
except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc]
ArrayLike = Union[NDArray[np.floating], JaxArray]
DTypeLike = DTypeLikeNumpy

ProblemKind_t = Literal["linear", "quadratic", "unknown"]
Numeric_t = Union[int, float] # type of `time_key` arguments
Expand Down
12 changes: 10 additions & 2 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]
__all__ = [
"OTTOutput",
"GWSolver",
"SinkhornSolver",
"NeuralOutput",
"sinkhorn_divergence",
"GENOTLinSolver",
"GraphOTTOutput",
]


register_cost("euclidean", backend="ott")(costs.Euclidean)
Expand Down
5 changes: 3 additions & 2 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def alpha_to_fused_penalty(alpha: float) -> float:
return (1 - alpha) / alpha


def densify(arr: ArrayLike) -> jax.Array:
def densify(arr: Union[ArrayLike, sp.sparray, sp.spmatrix]) -> jax.Array:
"""If the input is sparse, convert it to dense.
Parameters
Expand All @@ -197,7 +197,8 @@ def densify(arr: ArrayLike) -> jax.Array:
dense :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.toarray() # type: ignore[attr-defined]
arr_sp: Union[sp.sparray, sp.spmatrix] = arr
arr = arr_sp.toarray()
elif isinstance(arr, jesp.BCOO):
arr = arr.todense()
return jnp.asarray(arr)
Expand Down
67 changes: 22 additions & 45 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from moscot.backends.ott._utils import get_nearest_neighbors
from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput

__all__ = ["OTTOutput", "GraphOTTOutput", "OTTNeuralOutput"]
__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"]


class OTTOutput(BaseDiscreteSolverOutput):
Expand Down Expand Up @@ -182,6 +182,9 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
axis=1 - forward,
).T # convert to batch first

def _apply_forward(self, x: ArrayLike) -> ArrayLike:
return self._apply(x, forward=True)

@property
def shape(self) -> Tuple[int, int]: # noqa: D102
if isinstance(self._output, sinkhorn.SinkhornOutput):
Expand Down Expand Up @@ -241,11 +244,11 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102
return jnp.ones((n,))


class OTTNeuralOutput(BaseNeuralOutput):
class NeuralOutput(BaseNeuralOutput):
"""Output wrapper for GENOT."""

def __init__(self, model: GENOT, logs: dict[str, list[float]]):
"""Initialize `OTTNeuralOutput`.
"""Initialize `NeuralOutput`.
Parameters
----------
Expand All @@ -269,8 +272,7 @@ def _project_transport_matrix(
self,
src_dist: ArrayLike,
tgt_dist: ArrayLike,
forward: bool,
func: Callable[[jnp.ndarray], jnp.ndarray],
func: Callable[[ArrayLike], ArrayLike],
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
batch_size: int = 1024,
k: int = 30,
Expand All @@ -279,9 +281,9 @@ def _project_transport_matrix(
recall_target: float = 0.95,
aggregate_to_topk: bool = True,
) -> sp.csr_matrix:
row_indices: Union[jnp.ndarray, List[jnp.ndarray]] = []
column_indices: Union[jnp.ndarray, List[jnp.ndarray]] = []
distances_list: Union[jnp.ndarray, List[jnp.ndarray]] = []
row_indices: List[ArrayLike] = []
column_indices: List[ArrayLike] = []
distances_list: List[ArrayLike] = []
if length_scale is None:
key = jax.random.PRNGKey(seed)
src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
Expand All @@ -306,20 +308,14 @@ def _project_transport_matrix(
row_indices = jnp.concatenate(row_indices)
column_indices = jnp.concatenate(column_indices)
tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
if forward:
if save_transport_matrix:
self._transport_matrix = tm
else:
tm = tm.T
if save_transport_matrix:
self._inverse_transport_matrix = tm
if save_transport_matrix:
self._transport_matrix = tm
return tm

def project_to_transport_matrix( # type:ignore[override]
self,
src_cells: ArrayLike,
tgt_cells: ArrayLike,
forward: bool = True,
condition: ArrayLike = None,
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
batch_size: int = 1024,
Expand Down Expand Up @@ -351,7 +347,7 @@ def project_to_transport_matrix( # type:ignore[override]
save_transport_matrix
Whether to save the transport matrix.
batch_size
Number of data points in the source distribution the neighborhoodgraph is computed
Number of data points in the source distribution the neighborhood graph is computed
for in parallel.
k
Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
Expand All @@ -375,13 +371,12 @@ def project_to_transport_matrix( # type:ignore[override]
The projected transport matrix.
"""
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
push = self.push if condition is None else lambda x: self.push(x, condition)
pull = self.pull if condition is None else lambda x: self.pull(x, condition)
func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells)
conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
push = self.push if condition is None else conditioned_fn
func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
return self._project_transport_matrix(
src_dist=src_dist,
tgt_dist=tgt_dist,
forward=forward,
func=func,
save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
batch_size=batch_size,
Expand All @@ -406,31 +401,13 @@ def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
-------
Pushed distribution.
"""
if isinstance(x, (bool, int, float, complex)):
raise ValueError("Expected array, found scalar value.")
if x.ndim not in (1, 2):
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
return self._apply(x, cond=cond, forward=True)

def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
"""Pull distribution `x` conditioned on condition `cond`.
This does not make sense for some neural models and is therefore left unimplemented.
Parameters
----------
x
Distribution to push.
cond
Condition of conditional neural OT.
Raises
------
NotImplementedError
"""
raise NotImplementedError("`pull` does not make sense for neural OT.")
return self._apply_forward(x, cond=cond)

def _apply(self, x: ArrayLike, forward: bool, cond: Optional[ArrayLike] = None) -> ArrayLike:
if not forward:
raise NotImplementedError("Backward i.e., pull on neural OT is not supported.")
def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
return self._model.transport(x, condition=cond)

@property
Expand All @@ -445,7 +422,7 @@ def shape(self) -> Tuple[int, int]:
def to(
self,
device: Optional[Device_t] = None,
) -> "OTTNeuralOutput":
) -> "NeuralOutput":
"""Transfer the output to another device or change its data type.
Parameters
Expand All @@ -471,7 +448,7 @@ def to(
# raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err

# out = jax.device_put(self._model, device)
# return OTTNeuralOutput(out)
# return NeuralOutput(out)
return self # TODO(ilan-gold) move model to device

@property
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
densify,
ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
from moscot.base.problems._utils import TimeScalesHeatKernel
from moscot.base.solver import OTSolver
from moscot.costs import get_cost
Expand Down Expand Up @@ -699,10 +699,10 @@ def solver(self) -> genot.GENOT:
def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value]

def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> OTTNeuralOutput: # type: ignore[override]
def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override]
seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests
rng = jax.random.PRNGKey(seed)
logs = self.solver(
data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng
) # TODO(ilan-gold): validation and figure out defualts
return OTTNeuralOutput(self.solver, logs)
return NeuralOutput(self.solver, logs)
3 changes: 1 addition & 2 deletions src/moscot/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def register_solver(
return _REGISTRY.register(backend) # type: ignore[return-value]


# TODO(@MUCDK) fix mypy error
@register_solver("ott") # type: ignore[arg-type]
@register_solver("ott")
def _(
problem_kind: Literal["linear", "quadratic"],
solver_name: Optional[Literal["GENOTLinSolver"]] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike:
f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, "
f"setting them to the maximum value `{maxx}`."
)
cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload]
cost = np.nan_to_num(cost, nan=maxx) # type: ignore[arg-type, type-var]
if np.any(cost < 0):
raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.")
return cost
Expand Down
32 changes: 9 additions & 23 deletions src/moscot/base/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@
from scipy.sparse.linalg import LinearOperator

from moscot._logging import logger
from moscot._types import ArrayLike, Device_t, DTypeLike # type: ignore[attr-defined]
from moscot._types import ArrayLike, Device_t, DTypeLike

__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"]


class BaseSolverOutput(abc.ABC):
"""Base class for all solver outputs."""

@abc.abstractmethod
def pull(self, x: ArrayLike, **kwargs) -> ArrayLike:
"""Pull the solution based on a condition."""

@abc.abstractmethod
def push(self, x: ArrayLike, **kwargs) -> ArrayLike:
"""Push the solution based on a condition."""

@abc.abstractmethod
def _apply_forward(self, x: ArrayLike) -> ArrayLike:
"""Apply the transport matrix in the forward direction."""

@property
@abc.abstractmethod
def shape(self) -> tuple[int, int]:
Expand Down Expand Up @@ -348,6 +348,9 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
return self.transport_matrix.T @ x
return self.transport_matrix @ x

def _apply_forward(self, x: ArrayLike) -> ArrayLike:
return self._apply(x, forward=True)

@property
def transport_matrix(self) -> ArrayLike: # noqa: D102
return self._transport_matrix
Expand Down Expand Up @@ -393,7 +396,7 @@ def _ones(self, n: int) -> ArrayLike:
return jnp.ones((n,), dtype=self.transport_matrix.dtype)


class BaseNeuralOutput(BaseDiscreteSolverOutput, abc.ABC):
class BaseNeuralOutput(BaseSolverOutput, abc.ABC):
"""Base class for output of."""

@abstractmethod
Expand All @@ -402,27 +405,10 @@ def project_to_transport_matrix(
source: Optional[ArrayLike] = None,
target: Optional[ArrayLike] = None,
condition: Optional[ArrayLike] = None,
forward: bool = True,
save_transport_matrix: bool = False,
batch_size: int = 1024,
k: int = 30,
length_scale: Optional[float] = None,
seed: int = 42,
) -> sp.csr_matrix:
"""Project transport matrix."""
pass

@property
def transport_matrix(self): # noqa: D102
raise NotImplementedError("Neural output does not require a transport matrix.")

@property
def cost(self): # noqa: D102
raise NotImplementedError("Neural output does not implement a cost property.")

@property
def potentials(self): # noqa: D102
raise NotImplementedError("Neural output does not need to implement a potentials property.")

def _ones(self, n: int): # noqa: D102
raise NotImplementedError("Neural output does not need to implement a `_ones` property.")
2 changes: 1 addition & 1 deletion src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr
corr_bs = np.concatenate(corr_bs, axis=0)
corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile(corr_bs, q=qh, axis=0)

return pvals, corr_ci_low, corr_ci_high # type:ignore[return-value]
return pvals, corr_ci_low, corr_ci_high

if not (0 <= confidence_level <= 1):
raise ValueError(f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`.")
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, s
raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.")


class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):
"""Class for solving a :term:`linear problem`.
Parameters
Expand Down Expand Up @@ -257,7 +257,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def _create_problem(
adata_tgt=self.adata_sc,
src_obs_mask=src_mask,
tgt_obs_mask=None,
src_var_mask=self.filtered_vars, # type: ignore[arg-type]
tgt_var_mask=self.filtered_vars, # type: ignore[arg-type]
src_var_mask=self.filtered_vars,
tgt_var_mask=self.filtered_vars,
src_key=src,
tgt_key=tgt,
**kwargs,
Expand Down
Loading

0 comments on commit b55a0e4

Please sign in to comment.