Skip to content

Commit

Permalink
move to neural module
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Jan 7, 2025
1 parent 5e5bbb3 commit 15f321c
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 85 deletions.
1 change: 0 additions & 1 deletion docs/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ Generic Problems
generic.SinkhornProblem
generic.GWProblem
generic.FGWProblem
generic.GENOTLinProblem

Plotting
~~~~~~~~
Expand Down
3 changes: 3 additions & 0 deletions src/moscot/neural/problems/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from moscot.neural.problems.generic import GENOTLinProblem

__all__ = ["GENOTLinProblem"]
3 changes: 3 additions & 0 deletions src/moscot/neural/problems/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from moscot.neural.problems.generic._generic import GENOTLinProblem

__all__ = ["GENOTLinProblem"]
93 changes: 93 additions & 0 deletions src/moscot/neural/problems/generic/_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import types
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union

from anndata import AnnData

from moscot import _constants
from moscot._types import (
CostKwargs_t,
OttCostFn_t,
OttCostFnMap_t,
Policy_t,
ProblemStage_t,
QuadInitializer_t,
ScaleCost_t,
SinkhornInitializer_t,
)
from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K
from moscot.base.problems.problem import CondOTProblem, OTProblem
from moscot.problems._utils import (
handle_conditional_attr,
handle_cost,
handle_cost_tmp,
handle_joint_attr,
handle_joint_attr_tmp,
)
from moscot.problems.generic._mixins import GenericAnalysisMixin

__all__ = ["GENOTLinProblem"]


class GENOTLinProblem(CondOTProblem):
"""Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""

def prepare(
self,
key: str,
joint_attr: Union[str, Mapping[str, Any]],
conditional_attr: Union[str, Mapping[str, Any]],
policy: Literal["sequential", "star", "explicit"] = "sequential",
a: Optional[str] = None,
b: Optional[str] = None,
cost: OttCostFn_t = "sq_euclidean",
cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
**kwargs: Any,
) -> "GENOTLinProblem":
"""Prepare the :class:`moscot.problems.generic.GENOTLinProblem`."""
self.batch_key = key
xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs)
conditions = handle_conditional_attr(conditional_attr)
xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs)
return super().prepare(
policy_key=key,
policy=policy,
xy=xy,
xx=xx,
conditions=conditions,
a=a,
b=b,
**kwargs,
)

def solve(
self,
batch_size: int = 1024,
seed: int = 0,
iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations
valid_freq: int = 50,
valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}),
train_size: float = 1.0,
**kwargs: Any,
) -> "GENOTLinProblem":
"""Solve."""
return super().solve(
batch_size=batch_size,
# tau_a=tau_a, # TODO: unbalancedness handler
# tau_b=tau_b,
seed=seed,
n_iters=iterations,
valid_freq=valid_freq,
valid_sinkhorn_kwargs=valid_sinkhorn_kwargs,
train_size=train_size,
solver_name="GENOTLinSolver",
**kwargs,
)

@property
def _base_problem_type(self) -> Type[CondOTProblem]:
return CondOTProblem

@property
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value]
2 changes: 0 additions & 2 deletions src/moscot/problems/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from moscot.problems.cross_modality import TranslationProblem
from moscot.problems.generic import GENOTLinProblem
from moscot.problems.space import AlignmentProblem, MappingProblem
from moscot.problems.spatiotemporal import SpatioTemporalProblem
from moscot.problems.time import LineageProblem, TemporalProblem
Expand All @@ -11,5 +10,4 @@
"SpatioTemporalProblem",
"LineageProblem",
"TemporalProblem",
"GENOTLinProblem",
]
8 changes: 1 addition & 7 deletions src/moscot/problems/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
from moscot.problems.generic._generic import (
FGWProblem,
GENOTLinProblem,
GWProblem,
SinkhornProblem,
)
from moscot.problems.generic._generic import FGWProblem, GWProblem, SinkhornProblem
from moscot.problems.generic._mixins import GenericAnalysisMixin

__all__ = [
"FGWProblem" "SinkhornProblem",
"GENOTLinProblem",
"GWProblem",
"GenericAnalysisMixin",
]
77 changes: 3 additions & 74 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import types
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union

from anndata import AnnData
Expand All @@ -16,17 +15,11 @@
SinkhornInitializer_t,
)
from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K
from moscot.base.problems.problem import CondOTProblem, OTProblem
from moscot.problems._utils import (
handle_conditional_attr,
handle_cost,
handle_cost_tmp,
handle_joint_attr,
handle_joint_attr_tmp,
)
from moscot.base.problems.problem import OTProblem
from moscot.problems._utils import handle_cost, handle_joint_attr
from moscot.problems.generic._mixins import GenericAnalysisMixin

__all__ = ["SinkhornProblem", "GWProblem", "GENOTLinProblem", "FGWProblem"]
__all__ = ["SinkhornProblem", "GWProblem", "FGWProblem"]


def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, str]:
Expand Down Expand Up @@ -774,67 +767,3 @@ def _base_problem_type(self) -> Type[B]:
@property
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GENOTLinProblem(CondOTProblem):
"""Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""

def prepare(
self,
key: str,
joint_attr: Union[str, Mapping[str, Any]],
conditional_attr: Union[str, Mapping[str, Any]],
policy: Literal["sequential", "star", "explicit"] = "sequential",
a: Optional[str] = None,
b: Optional[str] = None,
cost: OttCostFn_t = "sq_euclidean",
cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
**kwargs: Any,
) -> "GENOTLinProblem":
"""Prepare the :class:`moscot.problems.generic.GENOTLinProblem`."""
self.batch_key = key
xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs)
conditions = handle_conditional_attr(conditional_attr)
xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs)
return super().prepare(
policy_key=key,
policy=policy,
xy=xy,
xx=xx,
conditions=conditions,
a=a,
b=b,
**kwargs,
)

def solve(
self,
batch_size: int = 1024,
seed: int = 0,
iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations
valid_freq: int = 50,
valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}),
train_size: float = 1.0,
**kwargs: Any,
) -> "GENOTLinProblem":
"""Solve."""
return super().solve(
batch_size=batch_size,
# tau_a=tau_a, # TODO: unbalancedness handler
# tau_b=tau_b,
seed=seed,
n_iters=iterations,
valid_freq=valid_freq,
valid_sinkhorn_kwargs=valid_sinkhorn_kwargs,
train_size=train_size,
solver_name="GENOTLinSolver",
**kwargs,
)

@property
def _base_problem_type(self) -> Type[CondOTProblem]:
return CondOTProblem

@property
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.problems import CondOTProblem
from moscot.problems.generic import GENOTLinProblem # type: ignore[attr-defined]
from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined]
from moscot.utils.tagged_array import DistributionCollection, DistributionContainer
from tests._utils import ATOL, RTOL
from tests.problems.conftest import neurallin_cond_args_1
Expand Down

0 comments on commit 15f321c

Please sign in to comment.