From 15f321ce2b9e0a4c11a05cd13fff4530b71861d8 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 7 Jan 2025 22:12:53 +0100 Subject: [PATCH] move to neural module --- docs/user.rst | 1 - src/moscot/neural/problems/__init__.py | 3 + .../neural/problems/generic/__init__.py | 3 + .../neural/problems/generic/_generic.py | 93 +++++++++++++++++++ src/moscot/problems/__init__.py | 2 - src/moscot/problems/generic/__init__.py | 8 +- src/moscot/problems/generic/_generic.py | 77 +-------------- .../test_conditional_neural_problem.py | 2 +- 8 files changed, 104 insertions(+), 85 deletions(-) create mode 100644 src/moscot/neural/problems/__init__.py create mode 100644 src/moscot/neural/problems/generic/__init__.py create mode 100644 src/moscot/neural/problems/generic/_generic.py rename tests/{ => neural}/problems/generic/test_conditional_neural_problem.py (97%) diff --git a/docs/user.rst b/docs/user.rst index ccc769697..2c8d19448 100644 --- a/docs/user.rst +++ b/docs/user.rst @@ -27,7 +27,6 @@ Generic Problems generic.SinkhornProblem generic.GWProblem generic.FGWProblem - generic.GENOTLinProblem Plotting ~~~~~~~~ diff --git a/src/moscot/neural/problems/__init__.py b/src/moscot/neural/problems/__init__.py new file mode 100644 index 000000000..ccd9ebcd9 --- /dev/null +++ b/src/moscot/neural/problems/__init__.py @@ -0,0 +1,3 @@ +from moscot.neural.problems.generic import GENOTLinProblem + +__all__ = ["GENOTLinProblem"] \ No newline at end of file diff --git a/src/moscot/neural/problems/generic/__init__.py b/src/moscot/neural/problems/generic/__init__.py new file mode 100644 index 000000000..4b26a873b --- /dev/null +++ b/src/moscot/neural/problems/generic/__init__.py @@ -0,0 +1,3 @@ +from moscot.neural.problems.generic._generic import GENOTLinProblem + +__all__ = ["GENOTLinProblem"] \ No newline at end of file diff --git a/src/moscot/neural/problems/generic/_generic.py b/src/moscot/neural/problems/generic/_generic.py new file mode 100644 index 000000000..5597c0331 --- /dev/null +++ b/src/moscot/neural/problems/generic/_generic.py @@ -0,0 +1,93 @@ +import types +from types import MappingProxyType +from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union + +from anndata import AnnData + +from moscot import _constants +from moscot._types import ( + CostKwargs_t, + OttCostFn_t, + OttCostFnMap_t, + Policy_t, + ProblemStage_t, + QuadInitializer_t, + ScaleCost_t, + SinkhornInitializer_t, +) +from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K +from moscot.base.problems.problem import CondOTProblem, OTProblem +from moscot.problems._utils import ( + handle_conditional_attr, + handle_cost, + handle_cost_tmp, + handle_joint_attr, + handle_joint_attr_tmp, +) +from moscot.problems.generic._mixins import GenericAnalysisMixin + +__all__ = ["GENOTLinProblem"] + + +class GENOTLinProblem(CondOTProblem): + """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" + + def prepare( + self, + key: str, + joint_attr: Union[str, Mapping[str, Any]], + conditional_attr: Union[str, Mapping[str, Any]], + policy: Literal["sequential", "star", "explicit"] = "sequential", + a: Optional[str] = None, + b: Optional[str] = None, + cost: OttCostFn_t = "sq_euclidean", + cost_kwargs: CostKwargs_t = types.MappingProxyType({}), + **kwargs: Any, + ) -> "GENOTLinProblem": + """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" + self.batch_key = key + xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) + conditions = handle_conditional_attr(conditional_attr) + xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) + return super().prepare( + policy_key=key, + policy=policy, + xy=xy, + xx=xx, + conditions=conditions, + a=a, + b=b, + **kwargs, + ) + + def solve( + self, + batch_size: int = 1024, + seed: int = 0, + iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations + valid_freq: int = 50, + valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), + train_size: float = 1.0, + **kwargs: Any, + ) -> "GENOTLinProblem": + """Solve.""" + return super().solve( + batch_size=batch_size, + # tau_a=tau_a, # TODO: unbalancedness handler + # tau_b=tau_b, + seed=seed, + n_iters=iterations, + valid_freq=valid_freq, + valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, + train_size=train_size, + solver_name="GENOTLinSolver", + **kwargs, + ) + + @property + def _base_problem_type(self) -> Type[CondOTProblem]: + return CondOTProblem + + @property + def _valid_policies(self) -> Tuple[Policy_t, ...]: + return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value] diff --git a/src/moscot/problems/__init__.py b/src/moscot/problems/__init__.py index 14f4422f5..b96b993b4 100644 --- a/src/moscot/problems/__init__.py +++ b/src/moscot/problems/__init__.py @@ -1,5 +1,4 @@ from moscot.problems.cross_modality import TranslationProblem -from moscot.problems.generic import GENOTLinProblem from moscot.problems.space import AlignmentProblem, MappingProblem from moscot.problems.spatiotemporal import SpatioTemporalProblem from moscot.problems.time import LineageProblem, TemporalProblem @@ -11,5 +10,4 @@ "SpatioTemporalProblem", "LineageProblem", "TemporalProblem", - "GENOTLinProblem", ] diff --git a/src/moscot/problems/generic/__init__.py b/src/moscot/problems/generic/__init__.py index d96dc4db6..ef9b78951 100644 --- a/src/moscot/problems/generic/__init__.py +++ b/src/moscot/problems/generic/__init__.py @@ -1,14 +1,8 @@ -from moscot.problems.generic._generic import ( - FGWProblem, - GENOTLinProblem, - GWProblem, - SinkhornProblem, -) +from moscot.problems.generic._generic import FGWProblem, GWProblem, SinkhornProblem from moscot.problems.generic._mixins import GenericAnalysisMixin __all__ = [ "FGWProblem" "SinkhornProblem", - "GENOTLinProblem", "GWProblem", "GenericAnalysisMixin", ] diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index fcd3d8b2e..53733bb1f 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -1,5 +1,4 @@ import types -from types import MappingProxyType from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union from anndata import AnnData @@ -16,17 +15,11 @@ SinkhornInitializer_t, ) from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K -from moscot.base.problems.problem import CondOTProblem, OTProblem -from moscot.problems._utils import ( - handle_conditional_attr, - handle_cost, - handle_cost_tmp, - handle_joint_attr, - handle_joint_attr_tmp, -) +from moscot.base.problems.problem import OTProblem +from moscot.problems._utils import handle_cost, handle_joint_attr from moscot.problems.generic._mixins import GenericAnalysisMixin -__all__ = ["SinkhornProblem", "GWProblem", "GENOTLinProblem", "FGWProblem"] +__all__ = ["SinkhornProblem", "GWProblem", "FGWProblem"] def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, str]: @@ -774,67 +767,3 @@ def _base_problem_type(self) -> Type[B]: @property def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] - - -class GENOTLinProblem(CondOTProblem): - """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" - - def prepare( - self, - key: str, - joint_attr: Union[str, Mapping[str, Any]], - conditional_attr: Union[str, Mapping[str, Any]], - policy: Literal["sequential", "star", "explicit"] = "sequential", - a: Optional[str] = None, - b: Optional[str] = None, - cost: OttCostFn_t = "sq_euclidean", - cost_kwargs: CostKwargs_t = types.MappingProxyType({}), - **kwargs: Any, - ) -> "GENOTLinProblem": - """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" - self.batch_key = key - xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) - conditions = handle_conditional_attr(conditional_attr) - xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) - return super().prepare( - policy_key=key, - policy=policy, - xy=xy, - xx=xx, - conditions=conditions, - a=a, - b=b, - **kwargs, - ) - - def solve( - self, - batch_size: int = 1024, - seed: int = 0, - iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations - valid_freq: int = 50, - valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), - train_size: float = 1.0, - **kwargs: Any, - ) -> "GENOTLinProblem": - """Solve.""" - return super().solve( - batch_size=batch_size, - # tau_a=tau_a, # TODO: unbalancedness handler - # tau_b=tau_b, - seed=seed, - n_iters=iterations, - valid_freq=valid_freq, - valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, - train_size=train_size, - solver_name="GENOTLinSolver", - **kwargs, - ) - - @property - def _base_problem_type(self) -> Type[CondOTProblem]: - return CondOTProblem - - @property - def _valid_policies(self) -> Tuple[Policy_t, ...]: - return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value] diff --git a/tests/problems/generic/test_conditional_neural_problem.py b/tests/neural/problems/generic/test_conditional_neural_problem.py similarity index 97% rename from tests/problems/generic/test_conditional_neural_problem.py rename to tests/neural/problems/generic/test_conditional_neural_problem.py index 5a7297de0..f624d0756 100644 --- a/tests/problems/generic/test_conditional_neural_problem.py +++ b/tests/neural/problems/generic/test_conditional_neural_problem.py @@ -8,7 +8,7 @@ from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems import CondOTProblem -from moscot.problems.generic import GENOTLinProblem # type: ignore[attr-defined] +from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] from moscot.utils.tagged_array import DistributionCollection, DistributionContainer from tests._utils import ATOL, RTOL from tests.problems.conftest import neurallin_cond_args_1