Skip to content

first draft classifier-free guidance #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions src/cfp/model/_cellflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ def prepare_model(
flow: dict[Literal["constant_noise", "bridge"], float] | None = None,
match_fn: Callable[[ArrayLike, ArrayLike], ArrayLike] = match_linear,
optimizer: optax.GradientTransformation = optax.adam(1e-4),
cfg_p_resample: float = 0.0,
cfg_ode_weight: float = 0.0,
layer_norm_before_concatenation: bool = False,
linear_projection_before_concatenation: bool = False,
genot_source_layers: Layers_t | None = None,
Expand Down Expand Up @@ -346,6 +348,11 @@ def prepare_model(
data and return the optimal transport matrix, see e.g. :func:`cfp.utils.match_linear`.
optimizer
Optimizer used for training.
cfg_p_resample
Probability of replacing the condition with the null condition during training (classifier-free guidance).
cfg_ode_weight
Weighting factor of the null condition for classifier free guidance.
0 corresponds to no classifier-free guidance; the larger the value above 0, the stronger the guidance.
layer_norm_before_concatenation
If :obj:`True`, applies layer normalization before concatenating
the embedded time, embedded data, and condition embeddings.
Expand Down Expand Up @@ -447,6 +454,8 @@ def prepare_model(
match_fn=match_fn,
flow=flow,
optimizer=optimizer,
cfg_p_resample=cfg_p_resample,
cfg_ode_weight=cfg_ode_weight,
conditions=self.train_data.condition_data,
rng=jax.random.PRNGKey(seed),
**solver_kwargs,
Expand Down
44 changes: 42 additions & 2 deletions src/cfp/solvers/_otfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class OTFlowMatching:
time_sampler
Time sampler with a ``(rng, n_samples) -> time`` signature, see e.g.
:func:`ott.solvers.utils.uniform_sampler`.
cfg_p_resample
Probability of replacing the condition with the null condition during training (classifier-free guidance).
cfg_ode_weight
Weighting factor of the null condition for classifier free guidance.
0 corresponds to no classifier-free guidance; the larger the value above 0, the stronger the guidance.
kwargs
Keyword arguments for :meth:`cfp.networks.ConditionalVelocityField.create_train_state`.
"""
Expand All @@ -48,18 +53,33 @@ def __init__(
time_sampler: Callable[
[jax.Array, int], jnp.ndarray
] = solver_utils.uniform_sampler,
cfg_p_resample: float = 0.0,
cfg_ode_weight: float = 0.0,
**kwargs: Any,
):
self._is_trained: bool = False
self.vf = vf
self.flow = flow
self.time_sampler = time_sampler
if cfg_p_resample > 0 and cfg_ode_weight == 0:
raise ValueError(
"cfg_p_resample > 0 requires cfg_ode_weight > 0 for classifier free guidance."
)
if cfg_p_resample == 0 and cfg_ode_weight > 0:
raise ValueError(
"cfg_ode_weight > 0 requires cfg_p_resample > 0 for classifier free guidance."
)
if cfg_ode_weight < 0:
raise ValueError("cfg_ode_weight must be non-negative.")
self.cfg_p_resample = cfg_p_resample
self.cfg_ode_weight = cfg_ode_weight
self.match_fn = jax.jit(match_fn)

self.vf_state = self.vf.create_train_state(
input_dim=self.vf.output_dims[-1], **kwargs
)
self.vf_step_fn = self._get_vf_step_fn()
self.null_value_cfg = self.vf.mask_value

def _get_vf_step_fn(self) -> Callable: # type: ignore[type-arg]

Expand Down Expand Up @@ -125,7 +145,14 @@ def step_fn(
"""
src, tgt = batch["src_cell_data"], batch["tgt_cell_data"]
condition = batch.get("condition")
rng_resample, rng_step_fn = jax.random.split(rng, 2)
rng_resample, rng_cfg, rng_step_fn = jax.random.split(rng, 3)
cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample)
if cfg_null:
# TODO: adapt to null condition in transformer
condition = jax.tree_util.tree_map(
lambda x: jnp.full(x.shape, self.null_value_cfg), condition
)

if self.match_fn is not None:
tmat = self.match_fn(src, tgt)
src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
Expand Down Expand Up @@ -192,8 +219,21 @@ def vf(
params = self.vf_state.params
return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False)

# Classifier-free-guidance velocity field: combines the prediction under the
# given condition with the prediction under the null (masked) condition,
# weighted by ``self.cfg_ode_weight``. Passed to ``diffrax.ODETerm`` in place
# of ``vf`` when ``self.cfg_p_resample`` is truthy.
def vf_cfg(
t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None
) -> jnp.ndarray:
# Null condition: same pytree structure and leaf shapes as ``cond``, with
# every leaf filled with ``self.null_value_cfg``.
# NOTE: the lambda's ``x`` is a leaf of ``cond``; it shadows the state ``x``
# only inside the lambda.
cond_mask = jax.tree_util.tree_map(
lambda x: jnp.full(x.shape, self.null_value_cfg), cond
)
params = self.vf_state.params
# Guided velocity: (1 + w) * v(t, x, cond) - w * v(t, x, cond_mask),
# where w = self.cfg_ode_weight. Both evaluations use the same parameters
# and ``train=False``.
return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn(
{"params": params}, t, x, cond, train=False
) - self.cfg_ode_weight * self.vf_state.apply_fn(
{"params": params}, t, x, cond_mask, train=False
)

def solve_ode(x: jnp.ndarray, condition: jnp.ndarray | None) -> jnp.ndarray:
ode_term = diffrax.ODETerm(vf)
ode_term = diffrax.ODETerm(vf_cfg if self.cfg_p_resample else vf)
result = diffrax.diffeqsolve(
ode_term,
t0=0.0,
Expand Down
2 changes: 1 addition & 1 deletion src/cfp/training/_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def train(
loss = self.solver.step_fn(rng_step_fn, batch)
self.training_logs["loss"].append(float(loss))

if ((it - 1) % valid_freq == 0) and (it > 1):
if ((it + 1) % valid_freq == 0) and (it > 1):
# Get predictions from validation data
valid_true_data, valid_pred_data = self._validation_step(
valid_loaders, mode="on_log_iteration"
Expand Down
15 changes: 12 additions & 3 deletions tests/model/test_cellflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@

class TestCellFlow:
@pytest.mark.parametrize("solver", ["otfm", "genot"])
@pytest.mark.parametrize("use_classifier_free_guidance", [False, True])
def test_cellflow_solver(
self,
adata_perturbation,
solver,
self, adata_perturbation, solver, use_classifier_free_guidance
):
if solver == "genot" and use_classifier_free_guidance:
pytest.skip("Classifier free guidance is not implemented for GENOT")
if use_classifier_free_guidance:
cfg_p_resample = 0.3
cfg_ode_weight = 2.0
else:
cfg_p_resample = 0.0
cfg_ode_weight = 0.0
sample_rep = "X"
control_key = "control"
perturbation_covariates = {"drug": ["drug1", "drug2"]}
Expand All @@ -47,6 +54,8 @@ def test_cellflow_solver(
hidden_dims=(32, 32),
decoder_dims=(32, 32),
condition_encoder_kwargs=condition_encoder_kwargs,
cfg_p_resample=cfg_p_resample,
cfg_ode_weight=cfg_ode_weight,
)
assert cf._trainer is not None

Expand Down