diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..77776b16 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + project: + default: + # Set to informational only - will not block PRs + informational: true + patch: + default: + # Set to informational only - will not block PRs + informational: true + +comment: + # Still show coverage comments on PRs + layout: "diff, flags, files" + behavior: default diff --git a/.flake8 b/.flake8 index 29944067..968d55a9 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ ignore = E203, E266, E501, W503, F403, F401, F821 max-line-length = 89 max-complexity = 18 -select = B,C,E,F,W,T4,B9 \ No newline at end of file +select = B,C,E,F,W,T4,B9 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb52a54b..5216a411 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,5 +44,5 @@ jobs: uses: codecov/codecov-action@v5 with: files: coverage.xml - fail_ci_if_error: true + fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} \ No newline at end of file diff --git a/src/gfn/env.py b/src/gfn/env.py index fc88ddf3..8b6566b3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -322,7 +322,7 @@ def _step(self, states: States, actions: Actions) -> States: # We only step on states that are not sink states. # Note that exit actions directly set the states to the sink state, so they # are not included in the valid_states_idx. - new_valid_states_idx = valid_states_idx & ~actions.is_exit + new_valid_states_idx = valid_states_idx & ~actions.is_exit # boolean mask. # IMPORTANT: .clone() is used to ensure that the new states are a # distinct object from the old states. This is important for the sampler to @@ -330,7 +330,9 @@ def _step(self, states: States, actions: Actions) -> States: # method in your custom environment, you must ensure that the `new_states` # returned is a distinct object from the submitted states. 
not_done_states = states[new_valid_states_idx].clone() - not_done_actions = actions[new_valid_states_idx] + not_done_actions = actions[ + new_valid_states_idx + ] # NOTE: boolean indexing creates a copy! not_done_states = self.step(not_done_states, not_done_actions) assert isinstance( diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 9cb6839d..8ff5f824 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1,3 +1,4 @@ +import math from abc import ABC, abstractmethod from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable @@ -26,6 +27,12 @@ "prod": torch.prod, } +# Relative tolerance for detecting terminal time in diffusion estimators. +# Must match TERMINAL_TIME_EPS in gfn.gym.diffusion_sampling to ensure consistent +# exit action detection between the estimator and environment. TODO: we should handle this +# centrally somewhere. +_DIFFUSION_TERMINAL_TIME_EPS = 1e-2 + class RolloutContext: """Structured per‑rollout state owned by estimators. @@ -1290,6 +1297,7 @@ def __init__( pf_module: nn.Module, sigma: float, num_discretization_steps: int, + n_variance_outputs: int = 0, ): """Initialize the PinnedBrownianMotionForward. @@ -1305,6 +1313,12 @@ def __init__( self.sigma = sigma self.num_discretization_steps = num_discretization_steps self.dt = 1.0 / self.num_discretization_steps + self.n_variance_outputs = n_variance_outputs + + @property + def expected_output_dim(self) -> int: + # Drift (s_dim) plus optional variance outputs. + return self.s_dim + self.n_variance_outputs def forward(self, input: States) -> torch.Tensor: """Forward pass of the module. 
@@ -1329,7 +1343,6 @@ def to_probability_distribution( states: States, module_output: torch.Tensor, **policy_kwargs: Any, - # TODO: add epsilon-noisy exploration ) -> IsotropicGaussian: """Transform the output of the module into a IsotropicGaussian distribution, which is the distribution of the next states under the pinned Brownian motion @@ -1339,24 +1352,75 @@ def to_probability_distribution( states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). module_output: The output of the module (actions), as a tensor of shape (*batch_shape, s_dim). - **policy_kwargs: Keyword arguments to modify the distribution. + **policy_kwargs: Keyword arguments to modify the distribution. Supported + keys: + - exploration_std: Optional callable or float controlling extra + exploration noise on top of the base diffusion std. The callable + should accept an integer step index and return a non-negative + standard deviation in state space. When provided, the extra noise + is combined in variance-space (logaddexp) with the base diffusion + variance; non-positive exploration is ignored. Returns: A IsotropicGaussian distribution (distribution of the next states) """ assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" - s_curr = states.tensor[:, :-1] + # s_curr = states.tensor[:, :-1] t_curr = states.tensor[:, [-1]] + # Check if the NEXT step would reach terminal time, not if we're already there. + # This matches the exit condition in DiffusionSampling.step() and ensures the + # sampled action is marked as an exit action (-inf) so trajectory masks align + # correctly in get_trajectory_pbs. + eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS + is_final_step = (t_curr + self.dt) >= (1.0 - eps) + # TODO: The old code followed this convention (below). I believe the change + # is slightly more correct, but I'd like to check this during review. 
+ # (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0 + module_output = torch.where( - (1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0 - torch.full_like(s_curr, -float("inf")), # This is the exit action + is_final_step, + torch.full_like(module_output, -float("inf")), # This is the exit action module_output, ) - fwd_mean = self.dt * module_output - fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=fwd_mean.device) - fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1) + drift = module_output[..., : self.s_dim] + if self.n_variance_outputs > 0: + var_part = module_output[..., self.s_dim :] + # Reduce extra variance dims to a single scalar (isotropic for now). + log_std = var_part.mean(dim=-1, keepdim=True) + fwd_std = torch.exp(log_std) * math.sqrt(self.dt) + else: + fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=drift.device) + fwd_std = fwd_std.repeat(drift.shape[0], 1) + + # Match reference behavior: scale diffusion noise (not drift) by t_scale if present. + t_scale_factor = getattr(self.module, "t_scale", 1.0) + if t_scale_factor != 1.0: + fwd_std = fwd_std * math.sqrt(t_scale_factor) + + fwd_mean = self.dt * drift + + # Optional exploration noise: combine variances (quadrature/logaddexp). + exploration_std = policy_kwargs.pop("exploration_std", None) + exploration_std_t = torch.as_tensor( + exploration_std if exploration_std is not None else 0.0, + device=fwd_std.device, + dtype=fwd_std.dtype, + ).clamp(min=0.0) + + # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2: + # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly. 
+ base_log_var = 2 * fwd_std.log() # log(σ_base^2) + extra_log_var = 2 * exploration_std_t.clamp(min=1e-12).log() # log(σ_expl^2) + extra_log_var_tensor = extra_log_var.expand_as(base_log_var) + combined_log_var = torch.logaddexp(base_log_var, extra_log_var_tensor) + fwd_std = torch.where( + exploration_std_t > 0, + torch.exp(0.5 * combined_log_var), + fwd_std, + ) + return IsotropicGaussian(fwd_mean, fwd_std) @@ -1367,30 +1431,34 @@ def __init__( pb_module: nn.Module, sigma: float, num_discretization_steps: int, + n_variance_outputs: int = 0, + pb_scale_range: float = 0.1, ): - """Initialize the PinnedBrownianMotionForward. + """Initialize the PinnedBrownianMotionBackward. Args: s_dim: The dimension of the states. pb_module: The neural network module to use for the backward policy. sigma: The diffusion coefficient parameter for the pinned Brownian motion. num_discretization_steps: The number of discretization steps. + n_variance_outputs: Number of variance outputs (0=fixed, 1=learned corr). + pb_scale_range: Scaling applied to learned corrections (tanh-bounded). """ super().__init__(s_dim=s_dim, module=pb_module, is_backward=True) # Pinned Brownian Motion related self.sigma = sigma self.dt = 1.0 / num_discretization_steps + self.n_variance_outputs = n_variance_outputs + self.pb_scale_range = pb_scale_range - def forward(self, input: States) -> torch.Tensor: - """Forward pass of the module. - - Args: - input: The input to the module as states. + @property + def expected_output_dim(self) -> int: + # Drift correction (s_dim) plus optional variance correction outputs. + return self.s_dim + self.n_variance_outputs - Returns: - The output of the module, as a tensor of shape (*batch_shape, output_dim). 
- """ + def forward(self, input: States) -> torch.Tensor: + """Forward pass of the module.""" out = self.module(self.preprocessor(input)) if self.expected_output_dim is not None: @@ -1411,6 +1479,7 @@ def to_probability_distribution( which is the distribution of the previous states under the pinned Brownian motion process, possibly controlled by the output of the backward module. If the module is a fixed backward module, the `module_output` is a zero vector (no control). + Includes optional learned corrections. Args: states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). @@ -1426,14 +1495,33 @@ def to_probability_distribution( t_curr = states.tensor[:, [-1]] # shape: (*batch_shape,) is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0 - bwd_mean = torch.where( + # Analytic Brownian bridge base + # Brownian bridge mean toward 0 at t=0: + # E[s_{t-dt} | s_t] = s_t * (1 - dt / t) and collapses to 0 at the start. + # Here, we calculate the *action* which moves the state in expectation toward 0 + # at t=0, so we scale s_curr by our distance to t=0. + base_mean = torch.where( is_s0, - s_curr, - s_curr * self.dt / t_curr, + torch.zeros_like(s_curr), + s_curr + * self.dt + / t_curr, # s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar. ) - bwd_std = torch.where( + base_std = torch.where( is_s0, torch.zeros_like(t_curr), self.sigma * (self.dt * (t_curr - self.dt) / t_curr).sqrt(), ) + + # Optional learned corrections (tanh-bounded); when n_variance_outputs==0, only mean corr. 
+ mean_corr = module_output[..., : self.s_dim] * self.pb_scale_range + if self.n_variance_outputs > 0 and module_output.shape[-1] >= self.s_dim + 1: + log_std_corr = module_output[..., [-1]] * self.pb_scale_range + corr_std = torch.exp(log_std_corr) + else: + corr_std = torch.zeros_like(base_std) + + bwd_mean = base_mean + mean_corr + bwd_std = (base_std**2 + corr_std**2).sqrt() + return IsotropicGaussian(bwd_mean, bwd_std) diff --git a/src/gfn/gflownet/__init__.py b/src/gfn/gflownet/__init__.py index 77fcf893..108c08c1 100644 --- a/src/gfn/gflownet/__init__.py +++ b/src/gfn/gflownet/__init__.py @@ -2,7 +2,11 @@ from .detailed_balance import DBGFlowNet, ModifiedDBGFlowNet from .flow_matching import FMGFlowNet from .sub_trajectory_balance import SubTBGFlowNet -from .trajectory_balance import LogPartitionVarianceGFlowNet, TBGFlowNet +from .trajectory_balance import ( + LogPartitionVarianceGFlowNet, + RelativeTrajectoryBalanceGFlowNet, + TBGFlowNet, +) __all__ = [ "GFlowNet", @@ -13,5 +17,6 @@ "FMGFlowNet", "SubTBGFlowNet", "LogPartitionVarianceGFlowNet", + "RelativeTrajectoryBalanceGFlowNet", "TBGFlowNet", ] diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 5542019b..24a976b8 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -11,7 +11,11 @@ from gfn.estimators import Estimator from gfn.samplers import Sampler from gfn.states import States -from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs +from gfn.utils.prob_calculations import ( + get_trajectory_pbs, + get_trajectory_pfs, + get_trajectory_pfs_and_pbs, +) TrainingSampleType = TypeVar("TrainingSampleType", bound=Container) @@ -343,6 +347,32 @@ def get_pfs_and_pbs( recalculate_all_logprobs, ) + def trajectory_log_probs_forward( + self, + trajectories: Trajectories, + fill_value: float = 0.0, + recalculate_all_logprobs: bool = True, + ) -> torch.Tensor: + """Evaluates forward logprobs only for each trajectory in the batch.""" + return get_trajectory_pfs( + 
+ self.pf, + trajectories, + fill_value=fill_value, + recalculate_all_logprobs=recalculate_all_logprobs, + ) + + def trajectory_log_probs_backward( + self, + trajectories: Trajectories, + fill_value: float = 0.0, + ) -> torch.Tensor: + """Evaluates backward logprobs only for each trajectory in the batch.""" + return get_trajectory_pbs( + self.pb, + trajectories, + fill_value=fill_value, + ) + def get_scores( self, trajectories: Trajectories, diff --git a/src/gfn/gflownet/mle.py b/src/gfn/gflownet/mle.py new file mode 100644 index 00000000..87a5d547 --- /dev/null +++ b/src/gfn/gflownet/mle.py @@ -0,0 +1,274 @@ +""" +MLE loss for diffusion GFlowNets (forward PF with optional PB). + +Key equations (per time step, shapes in comments): + - Backward bridge (s_t -> s_{t-dt}): + mean_bb = s_t * (1 - dt / t) # (B, s_dim) + std_bb = sigma * sqrt(dt*(t-dt)/t) # (B, 1) broadcast + With learned PB corrections: + mean = mean_bb + mean_corr + std = sqrt(std_bb^2 + corr_std^2) + - Forward PF log-prob for increment Δ = s_t - s_{t-dt}: + If PF predicts log_std: + σ = exp(log_std) * sqrt(dt) * sqrt(t_scale); optionally combine exploration + log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ] + Else (fixed variance): + σ = sigma * sqrt(dt) * sqrt(t_scale); optionally combine exploration + log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ)^2 + log(2π σ^2) ] + - Loss = -mean over batch of Σ_t log p_t + +Tensor conventions: + - terminal_states: (B, s_dim) or (B, s_dim + 1) with the last column holding + the time coordinate; we drop the last column if present. + - Times: scalar dt = 1/num_steps; t_curr = 1 - i*dt; t_fwd = 1 - (i+1)*dt. + +Usage (user owns optimizer/loop): +```python +gfn = MLEDiffusion(pf=pf, pb=None, num_steps=100, sigma=2.0, t_scale=1.0) +opt = torch.optim.Adam(gfn.parameters(), lr=1e-3) +for it in range(n_iterations): + # Sample a batch of terminal states.
+ batch = env.sample(batch_size) # batch shape (B, s_dim) + opt.zero_grad() + # Calculate the MLE loss under the backward / forward diffusion process. + loss = gfn.loss(env, batch, exploration_std=0.0) + loss.backward() + opt.step() +``` +""" + +from __future__ import annotations + +import math +from typing import Any, Optional + +import torch + +try: # torch._dynamo may be absent or flagged private by linters + from torch._dynamo import disable as dynamo_disable +except Exception: # pragma: no cover + + def dynamo_disable(fn): # type: ignore[return-type] + return fn + + +from gfn.env import Env +from gfn.estimators import ( + PinnedBrownianMotionBackward, + PinnedBrownianMotionForward, +) +from gfn.gflownet.base import GFlowNet +from gfn.samplers import Sampler +from gfn.states import States +from gfn.utils.modules import DiffusionFixedBackwardModule + +# Relative tolerance for detecting initial/terminal states in diffusion trajectories. +# Must be synchronized with TERMINAL_TIME_EPS in gfn.gym.diffusion_sampling and +# _DIFFUSION_TERMINAL_TIME_EPS in gfn.estimators. +_DIFFUSION_TERMINAL_TIME_EPS = 1e-2 + + +class MLEDiffusion(GFlowNet): + """ + Maximum-likelihood diffusion GFlowNet (PF with optional PB).
+ + The caller owns the training loop; this class provides: + - sampling via the forward PF (for API compatibility) + - `.loss(env, terminal_states, ...)` computing the MLE objective + """ + + def __init__( + self, + pf: PinnedBrownianMotionForward, + pb: Optional[PinnedBrownianMotionBackward] = None, + *, + num_steps: int, + sigma: float, + t_scale: float = 1.0, + pb_scale_range: float = 0.1, + learn_variance: bool = False, + reduction: str = "mean", + debug: bool = False, + ) -> None: + super().__init__() + self.pf = pf + if pb is None: + # Constant PB estimator (no learned parameters) + pb = PinnedBrownianMotionBackward( + s_dim=pf.s_dim, + pb_module=DiffusionFixedBackwardModule(pf.s_dim), + sigma=sigma, + num_discretization_steps=num_steps, + n_variance_outputs=0, + pb_scale_range=pb_scale_range, + ).to(next(pf.parameters()).device) + self.pb = pb + self.s_dim = pf.s_dim + self.num_steps = num_steps + self.dt = 1.0 / num_steps + self.sigma = sigma + self.t_scale = t_scale + self.pb_scale_range = pb_scale_range + self.learn_variance = learn_variance + self.reduction = reduction + self.debug = debug + + # Sampler for base-class API (sample_trajectories). + self.sampler = Sampler(estimator=self.pf) + + def sample_trajectories( + self, + env: Env, + n: int, + conditions: torch.Tensor | None = None, + save_logprobs: bool = False, + save_estimator_outputs: bool = False, + **policy_kwargs: Any, + ): + return self.sampler.sample_trajectories( + env, + n, + conditions=conditions, + save_logprobs=save_logprobs, + save_estimator_outputs=save_estimator_outputs, + **policy_kwargs, + ) + + def to_training_samples(self, trajectories): + return trajectories + + def loss( + self, + env: Env, + terminal_states: Any, + recalculate_all_logprobs: bool = True, + *, + exploration_std: float | torch.Tensor = 0.0, + ) -> torch.Tensor: + """ + Compute the MLE objective given terminal states sampled from the target. 
+ + Args: + terminal_states: torch.Tensor or States; shape (B, s_dim) or (B, s_dim+1). + exploration_std: extra state-space noise (combined in quadrature with PF std). + Returns: + Scalar loss (batch mean if `reduction == "mean"`, otherwise batch sum). + """ + del env # unused + del recalculate_all_logprobs # unused + device, dtype, s_curr = self._extract_samples(terminal_states) + + bsz, dim = s_curr.shape + assert dim == self.s_dim, f"Expected s_dim={self.s_dim}, got {dim}" + dt = self.dt + + # Tolerance for detecting initial state (t ≈ 0). Uses the module-level constant + # which must stay synchronized with TERMINAL_TIME_EPS in diffusion_sampling.py + # and _DIFFUSION_TERMINAL_TIME_EPS in estimators.py. + eps_s0 = dt * _DIFFUSION_TERMINAL_TIME_EPS + + sqrt_dt_t_scale = math.sqrt(dt * self.t_scale) + base_std_fixed = self.sigma * sqrt_dt_t_scale + log_2pi = math.log(2 * math.pi) + + logpf_sum = torch.zeros(bsz, device=device, dtype=dtype) + exploration_std_t = torch.as_tensor( + exploration_std, device=device, dtype=dtype + ).clamp(min=0.0) + exploration_var = exploration_std_t**2 + + # Precompute time grids to avoid per-step allocations. + all_t_fwd = torch.linspace( + 1.0 - dt, 0.0, self.num_steps, device=device, dtype=dtype + ) + all_t_curr = torch.linspace(1.0, dt, self.num_steps, device=device, dtype=dtype) + + for i in range(self.num_steps): + # Times: forward transition index t_fwd corresponds to s_prev -> s_curr. + t_fwd = all_t_fwd[i].expand(bsz, 1) + t_curr = all_t_curr[i].expand(bsz, 1) + + # Backward sampler: Brownian bridge base + optional PB corrections. + pb_inp = torch.cat([s_curr, t_curr], dim=1) + pb_out = self.pb.module(pb_inp) + + # Base Brownian bridge mean/std toward 0 at t=0. + is_s0 = (t_curr - dt) < eps_s0 + not_s0 = (~is_s0).float() + + base_mean = s_curr * (1.0 - dt / t_curr) * not_s0 + base_std = self.sigma * (dt * (t_curr - dt) / t_curr).sqrt() * not_s0 + + # Learned corrections (PB): mean_corr, optional log-std corr.
+ mean_corr = pb_out[..., :dim] * self.pb.pb_scale_range + if self.pb.n_variance_outputs > 0: + log_std_corr = pb_out[..., [-1]] * self.pb.pb_scale_range + corr_std = torch.exp(log_std_corr) + else: + corr_std = torch.zeros_like(base_std) + + bwd_std = (base_std**2 + corr_std**2).sqrt() + noise = torch.randn_like(s_curr, device=device, dtype=dtype) + s_prev = base_mean + mean_corr + bwd_std * noise + + # Forward log-prob under PF for the observed increment (s_prev -> s_curr). + model_inp = torch.cat([s_prev, t_fwd], dim=1) + module_out = self.pf.module(model_inp) + increment = s_curr - s_prev + + # Case where module outputs learned variance. + if self.pf.n_variance_outputs > 0: + drift = module_out[..., :dim] + log_std = module_out[..., [-1]] + std = torch.exp(log_std) * sqrt_dt_t_scale + std = torch.sqrt(std**2 + exploration_var) + diff = increment - dt * drift + logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum( + dim=1 + ) + # Fixed variance case. + else: + drift = module_out + std = torch.sqrt(base_std_fixed**2 + exploration_var) + diff = increment - dt * drift + logpf_step = -0.5 * ((diff / std) ** 2).sum(dim=1) - 0.5 * dim * ( + log_2pi + 2 * torch.log(std) + ) + + logpf_sum += logpf_step + s_curr = s_prev + + if self.debug and torch.isnan(logpf_sum).any(): + raise ValueError("NaNs in logpf_sum during MLE loss.") + + # TODO: Use included loss reduction helpers. + loss = -(logpf_sum.mean() if self.reduction == "mean" else logpf_sum.sum()) + if self.debug: + self._assert_no_nan(logpf_sum) + return loss + + @dynamo_disable + def _assert_no_nan(self, logpf_sum: torch.Tensor) -> None: + if torch.isnan(logpf_sum).any(): + raise ValueError("NaNs in logpf_sum during MLE loss.") + + @dynamo_disable + def _extract_samples( + self, terminal_states: Any + ) -> tuple[torch.device, torch.dtype, torch.Tensor]: + """ + Normalize input to a (B, s_dim) tensor. + Accepts torch.Tensor or States; drops a final column if size matches s_dim+1. 
+ """ + if isinstance(terminal_states, States): + tensor = terminal_states.tensor + elif torch.is_tensor(terminal_states): + tensor = terminal_states + else: + raise TypeError(f"Unsupported terminal_states type: {type(terminal_states)}") + + if tensor.shape[-1] == self.s_dim + 1: + tensor = tensor[..., :-1] + device = tensor.device + dtype = tensor.dtype + return device, dtype, tensor diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 8e2d6e6b..9c92134a 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -3,6 +3,7 @@ and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446). """ +import math from typing import cast import torch @@ -16,6 +17,7 @@ is_callable_exception_handler, warn_about_recalculating_logprobs, ) +from gfn.utils.prob_calculations import get_trajectory_pfs class TBGFlowNet(TrajectoryBasedGFlowNet): @@ -132,6 +134,120 @@ def loss( return loss +class RelativeTrajectoryBalanceGFlowNet(TrajectoryBasedGFlowNet): + r"""GFlowNet for the Relative Trajectory Balance (RTB) loss. + + This objective matches a posterior sampler to a prior diffusion (or other + sequential) model by minimizing + + .. math:: + + \left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) + - \beta \log r(x_T)\right)^2, + + where :math:`p_\theta` is a fixed prior process, :math:`p_\phi` is the + learnable posterior, :math:`r` is a positive reward/constraint on the + terminal state :math:`x_T`, and :math:`\log Z_\phi` is a learned scalar + normalizer. + """ + + def __init__( + self, + pf: Estimator, + prior_pf: Estimator, + *, + logZ: nn.Parameter | ScalarEstimator | None = None, + init_logZ: float = 0.0, + beta: float = 1.0, + log_reward_clip_min: float = -float("inf"), + debug: bool = False, + ): + """Initializes an RTB GFlowNet. + + Args: + pf: Posterior forward policy estimator :math:`p_\\phi`. + prior_pf: Fixed prior forward policy estimator :math:`p_\\theta`. 
+ logZ: Learnable log-partition parameter or ScalarEstimator for + conditional settings. Defaults to a scalar parameter. + init_logZ: Initial value for logZ if ``logZ`` is None. + beta: Optional scaling applied to the terminal log-reward. + log_reward_clip_min: If finite, clips terminal log-rewards. + debug: if True, enables extra checks at the cost of execution speed. + """ + super().__init__( + pf=pf, + pb=None, + constant_pb=True, + log_reward_clip_min=log_reward_clip_min, + ) + self.prior_pf = prior_pf + self.register_buffer("beta", torch.tensor(beta)) + self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) + self.debug = debug # TODO: to be passed to base classes. + + def logz_named_parameters(self) -> dict[str, torch.Tensor]: + """Returns named parameters containing 'logZ'.""" + return {k: v for k, v in dict(self.named_parameters()).items() if "logZ" in k} + + def logz_parameters(self) -> list[torch.Tensor]: + """Returns parameters containing 'logZ'.""" + return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k] + + def loss( + self, + env: Env, + trajectories: Trajectories, + recalculate_all_logprobs: bool = True, + reduction: str = "mean", + ) -> torch.Tensor: + """Computes the RTB loss on a batch of trajectories.""" + del env # unused + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + + # Posterior log-probs. + log_pf_post = self.trajectory_log_probs_forward( + trajectories, + recalculate_all_logprobs=recalculate_all_logprobs, + ) + log_pf_post = log_pf_post.sum(dim=0) # Sum along trajectory length. + + # Prior log-probs along the same trajectories. + # The prior is fixed; evaluate it without tracking gradients to keep its + # parameters out of the RTB optimization graph. + with torch.no_grad(): + log_pf_prior = get_trajectory_pfs( + self.prior_pf, + trajectories, + fill_value=0.0, + recalculate_all_logprobs=True, + ) + log_pf_prior = log_pf_prior.sum(dim=0) # Sum along trajectory length. 
+ + # Get the rewards. + log_rewards = trajectories.log_rewards + if self.debug: + assert log_rewards is not None + if math.isfinite(self.log_reward_clip_min): + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) # type: ignore + + # Get logZ. + if trajectories.conditions is not None: + with is_callable_exception_handler("logZ", self.logZ): + assert isinstance(self.logZ, ScalarEstimator) + logZ = self.logZ(trajectories.conditions) + else: + logZ = self.logZ + logZ = cast(torch.Tensor, logZ).squeeze() + + scores = 0.5 * (log_pf_post + logZ - log_pf_prior - self.beta * log_rewards).pow(2) # type: ignore + + loss = loss_reduce(scores, reduction) # Reduce across batch dimension. + if torch.isnan(loss).any(): + raise ValueError("loss is nan") + + return loss + + class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet): """GFlowNet for the Log Partition Variance loss. diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 68fc5385..93a83882 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -19,6 +19,17 @@ # Lightweight typing alias for the target registry entries. TargetEntry = tuple[type["BaseTarget"], dict[str, Any]] +# Relative tolerance (scaled by dt) for detecting initial/terminal states in diffusion +# trajectories. This ensures consistent boundary detection across the environment, +# estimators, and probability calculations. 
The tolerance is applied as: +# - Initial state: t < dt * TERMINAL_TIME_EPS +# - Terminal state: t >= 1.0 - dt * TERMINAL_TIME_EPS +# - Exit action trigger: t + dt >= 1.0 - dt * TERMINAL_TIME_EPS (next step reaches terminal) +TERMINAL_TIME_EPS = 1e-2 + +# Default output directory for saving visualizations +OUTPUT_DIR = "output" + ############################### ### Target energy functions ### @@ -399,9 +410,166 @@ def visualize( if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - plt.savefig(f"viz/{prefix}simple_gmm.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + plt.savefig(f"{OUTPUT_DIR}/{prefix}simple_gmm.png") + + plt.close() + + +class Grid25GaussianMixture(BaseTarget): + """Fixed 5x5 Gaussian mixture prior used for RTB demos.""" + + def __init__( + self, + device: torch.device, + dim: int = 2, + scale: float = math.sqrt(0.3), + plot_border: float = 15.0, + seed: int = 0, + ) -> None: + assert dim == 2, "Grid25GaussianMixture is defined for 2D." + self.locs = torch.tensor( + [(a, b) for a in [-10, -5, 0, 5, 10] for b in [-10, -5, 0, 5, 10]], + device=device, + dtype=torch.get_default_dtype(), + ) + mix = D.Categorical( + probs=torch.full( + (self.locs.shape[0],), 1.0 / self.locs.shape[0], device=device + ) + ) + comp = D.Independent(D.Normal(self.locs, scale * torch.ones_like(self.locs)), 1) + self.gmm = D.MixtureSameFamily(mix, comp) + + super().__init__( + device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border + ) + + def log_reward(self, x: torch.Tensor) -> torch.Tensor: + return self.gmm.log_prob(x).flatten() + + def sample(self, batch_size: int, seed: int | None = None) -> torch.Tensor: + ctx = nullcontext() + if seed is not None: + ctx = temporarily_set_seed(seed) + with ctx: + return self.gmm.sample((batch_size,)) + + def gt_logz(self) -> float: + return 0.0 + + def visualize( + self, + samples: torch.Tensor | None = None, + show: bool = False, + prefix: str = "", + grid_width_n_points: int = 100, + max_n_samples: int = 
1000, + ) -> None: + assert self.plot_border is not None, "Visualization requires a plot border." + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + viz_2d_slice( + ax, + self, + (0, 1), + samples, + plot_border=self.plot_border, + use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, + ) + plt.tight_layout() + if show: + plt.show() + else: + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}gmm25.png") + plt.close() + +class Posterior9of25GaussianMixture(BaseTarget): + """Posterior reward for the 25→9 GMM RTB demo.""" + + def __init__( + self, + device: torch.device, + dim: int = 2, + scale: float = math.sqrt(0.3), + plot_border: float = 15.0, + seed: int = 0, + ) -> None: + assert dim == 2, "Posterior9of25GaussianMixture is defined for 2D." + self.prior = Grid25GaussianMixture( + device=device, dim=dim, scale=scale, plot_border=plot_border, seed=seed + ) + + mean_ls = [ + [-10.0, -5.0], + [-5.0, -10.0], + [-5.0, 0.0], + [10.0, -5.0], + [0.0, 0.0], + [0.0, 5.0], + [5.0, -5.0], + [5.0, 0.0], + [5.0, 10.0], + ] + locs = torch.tensor(mean_ls, device=device, dtype=torch.get_default_dtype()) + weights = torch.tensor( + [4, 10, 4, 5, 10, 5, 4, 15, 4], + device=device, + dtype=torch.get_default_dtype(), + ) + weights = weights / weights.sum() + + mix = D.Categorical(probs=weights) + comp = D.Independent(D.Normal(locs, scale * torch.ones_like(locs)), 1) + self.posterior = D.MixtureSameFamily(mix, comp) + + super().__init__( + device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border + ) + + def log_reward(self, x: torch.Tensor) -> torch.Tensor: + # r(x) = p_post(x) / p_prior(x) + return self.posterior.log_prob(x).flatten() - self.prior.log_reward(x) + + def sample(self, batch_size: int, seed: int | None = None) -> torch.Tensor: + ctx = nullcontext() + if seed is not None: + ctx = temporarily_set_seed(seed) + with ctx: + return self.posterior.sample((batch_size,)) + + def gt_logz(self) -> 
float: + return 0.0 + + def visualize( + self, + samples: torch.Tensor | None = None, + show: bool = False, + prefix: str = "", + grid_width_n_points: int = 100, + max_n_samples: int = 1000, + ) -> None: + assert self.plot_border is not None, "Visualization requires a plot border." + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + viz_2d_slice( + ax, + self, + (0, 1), + samples, + plot_border=self.plot_border, + use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, + ) + plt.tight_layout() + if show: + plt.show() + else: + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}posterior9of25.png") plt.close() @@ -481,7 +649,7 @@ def visualize( samples: torch.Tensor | None = None, show: bool = False, prefix: str = "", - linspace_n_steps: int = 100, + grid_width_n_points: int = 100, max_n_samples: int = 500, ) -> None: """Visualize only supported for 2D (x0, x1).""" @@ -497,14 +665,16 @@ def visualize( samples, plot_border=self.plot_border, use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, ) plt.tight_layout() if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - fig.savefig(f"viz/{prefix}funnel.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}funnel.png") plt.close() @@ -640,7 +810,7 @@ def visualize( samples: torch.Tensor | None = None, show: bool = False, prefix: str = "", - linspace_n_steps: int = 100, + grid_width_n_points: int = 100, max_n_samples: int = 500, ) -> None: assert self.plot_border is not None, "Visualization requires a plot border." 
@@ -655,20 +825,22 @@ def visualize( samples, plot_border=self.plot_border, use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, ) plt.tight_layout() if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - fig.savefig(f"viz/{prefix}manywell.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}manywell.png") plt.close() ###################################### -### Diffusion Sampling Environment ### +# Diffusion Sampling Environment # ###################################### @@ -685,6 +857,11 @@ class DiffusionSampling(Env): "gmm2": (SimpleGaussianMixture, {"num_components": 2}), # 2D "gmm4": (SimpleGaussianMixture, {"num_components": 4}), # 2D "gmm8": (SimpleGaussianMixture, {"num_components": 8}), # 2D + "gmm25_prior": (Grid25GaussianMixture, {}), # 2D, fixed 25-mode grid + "gmm25_posterior9": ( + Posterior9of25GaussianMixture, + {}, + ), # 2D, 9-mode posterior reward "easy_funnel": (Funnel, {"std": 1.0}), # 10D "hard_funnel": (Funnel, {"std": 3.0}), # 10D "many_well": (ManyWell, {}), # 32D @@ -763,10 +940,27 @@ class DiffusionSamplingStates(States): def is_initial_state(self) -> torch.Tensor: """Returns a tensor that is True for states that are s0 - When time is close enought to 0.0 (considering floating point errors), + When time is close enough to 0.0 (considering floating point errors), the state is s0. """ - return (self.tensor[..., -1] - 0.0) < env.dt * 1e-2 + eps = env.dt * TERMINAL_TIME_EPS + return self.tensor[..., -1] < eps + + @property + def is_sink_state(self) -> torch.Tensor: + """Return True when time is effectively 1.0 or the sink padding. + + We treat two cases as sink: + - Physical terminal time: t >= 1.0 - eps. + - Padding/exit sink states produced by `make_sink_states`, which use + non-finite sentinel values (e.g., -inf). Using non-finite check keeps + masks aligned for padded rows. 
+ """ + time = self.tensor[..., -1] + eps = env.dt * TERMINAL_TIME_EPS + is_terminal_time = time >= (1.0 - eps) + is_padding_sink = ~torch.isfinite(time) + return is_terminal_time | is_padding_sink return DiffusionSamplingStates @@ -796,6 +990,19 @@ def step(self, states: States, actions: Actions) -> States: Returns: The next states. """ + if self.debug: + + eps = self.dt * TERMINAL_TIME_EPS + # Force exit when the next step would reach/exceed terminal time. + terminal_mask = (states.tensor[..., -1] + self.dt) >= (1.0 - eps) + if terminal_mask.any(): + raise AssertionError( + f"Estimator failed to output exit actions for {terminal_mask.sum().item()} " + f"states at terminal time. This will cause mask misalignment in " + f"get_trajectory_pbs(). Fix the estimator's exit condition to match " + f"TERMINAL_TIME_EPS={TERMINAL_TIME_EPS}." + ) + next_states_tensor = states.tensor.clone() next_states_tensor[..., :-1] = next_states_tensor[..., :-1] + actions.tensor next_states_tensor[..., -1] = next_states_tensor[..., -1] + self.dt @@ -832,15 +1039,18 @@ def is_action_valid( True if the actions are valid, False otherwise. """ time = states.tensor[..., -1].flatten()[0].item() - # TODO: support randomized discretization + eps = self.dt * TERMINAL_TIME_EPS + # TODO: support randomized discretization. 
assert ( states.tensor[..., -1] == time ).all(), "Time must be the same for all states in the batch" - if not backward and time == 1.0: # Terminate if time == 1.0 for forward steps + if not backward and time >= ( + 1.0 - eps + ): # Terminate if near 1.0 for forward steps sf = cast(torch.Tensor, self.sf) return bool((actions.tensor == sf[:-1]).all().item()) - elif backward and time == 0.0: # Return to s0 if time == 0.0 for backward steps + elif backward and time <= eps: # Return to s0 when near 0.0 for backward steps s0 = cast(torch.Tensor, self.s0) return bool((actions.tensor == s0[:-1]).all().item()) else: diff --git a/src/gfn/gym/helpers/diffusion_utils.py b/src/gfn/gym/helpers/diffusion_utils.py index 8e39067a..4d7274f8 100644 --- a/src/gfn/gym/helpers/diffusion_utils.py +++ b/src/gfn/gym/helpers/diffusion_utils.py @@ -27,6 +27,7 @@ def viz_2d_slice( grid_width_n_points=200, log_reward_clamp_min=-10000.0, use_log_reward=False, + max_n_samples: int | None = None, ) -> None: x_points_dim1 = torch.linspace(plot_border[0], plot_border[1], grid_width_n_points) x_points_dim2 = torch.linspace(plot_border[2], plot_border[3], grid_width_n_points) @@ -46,6 +47,8 @@ def viz_2d_slice( ax.contour(x_points_dim1, x_points_dim2, log_r_x, levels=n_contour_levels) if samples is not None: + if max_n_samples is not None: + samples = samples[:max_n_samples] samples = samples[:, dims].detach().cpu() samples[:, 0] = torch.clamp(samples[:, 0], plot_border[0], plot_border[1]) samples[:, 1] = torch.clamp(samples[:, 1], plot_border[2], plot_border[3]) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 72697457..393d807a 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1667,6 +1667,11 @@ def __init__( hidden_dim: int = 64, joint_layers: int = 2, zero_init: bool = False, + clipping: bool = False, + gfn_clip: float = 1e4, + t_scale: float = 1.0, + log_var_range: float = 4.0, # kept for parity with learned-var subclass + learn_variance: bool = 
False, # predict_flow: bool, # TODO: support predict flow for db or subtb # share_embeddings: bool = False, # flow_harmonics_dim: int = 64, @@ -1680,7 +1685,6 @@ def __init__( # clipping: bool = False, # TODO: support clipping # out_clip: float = 1e4, # lp_clip: float = 1e2, - # learn_variance: bool = True, # TODO: support learnable variance # log_var_range: float = 4.0, ): """Initialize the PISGradNetForward. @@ -1703,7 +1707,12 @@ def __init__( self.hidden_dim = hidden_dim self.joint_layers = joint_layers self.zero_init = zero_init - self.out_dim = s_dim # 2 * out_dim if learn_variance is True + self.learn_variance = learn_variance + self.out_dim = s_dim + 1 if self.learn_variance else s_dim + self.clipping = clipping + self.gfn_clip = gfn_clip + self.t_scale = t_scale + self.log_var_range = log_var_range assert ( self.s_emb_dim == self.t_emb_dim @@ -1740,10 +1749,18 @@ def forward( t_emb = self.t_model(t) out = self.joint_model(s_emb, t_emb) - # TODO: learn variance, lp, clipping, ... + if self.learn_variance: + drift, raw_log_std = out[..., :-1], out[..., [-1]] + if self.clipping: + drift = torch.clamp(drift, -self.gfn_clip, self.gfn_clip) + log_std = torch.tanh(raw_log_std) * self.log_var_range + out = torch.cat([drift, log_std], dim=-1) + else: + if self.clipping: + out = torch.clamp(out, -self.gfn_clip, self.gfn_clip) + if torch.isnan(out).any(): - print("+ out has {} nans".format(torch.isnan(out).sum())) - out = torch.nan_to_num(out) + raise ValueError("DiffusionPISGradNetForward produced NaNs") return out @@ -1774,3 +1791,83 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: The output of the module (shape: (*batch_shape, s_dim)). """ return torch.zeros_like(preprocessed_states[..., :-1]) + + +class DiffusionPISGradNetBackward(nn.Module): + """Learnable backward correction module (PIS-style) for diffusion. 
+ + Produces mean and optional log-std corrections that are tanh-scaled by + `pb_scale_range` to stay close to the analytic Brownian bridge. + """ + + def __init__( + self, + s_dim: int, + harmonics_dim: int = 64, + t_emb_dim: int = 64, + s_emb_dim: int = 64, + hidden_dim: int = 64, + joint_layers: int = 2, + zero_init: bool = False, + clipping: bool = False, + gfn_clip: float = 1e4, + pb_scale_range: float = 0.1, + log_var_range: float = 4.0, + learn_variance: bool = True, + ) -> None: + super().__init__() + self.s_dim = s_dim + self.out_dim = s_dim + (1 if learn_variance else 0) + self.harmonics_dim = harmonics_dim + self.t_emb_dim = t_emb_dim + self.s_emb_dim = s_emb_dim + self.hidden_dim = hidden_dim + self.joint_layers = joint_layers + self.zero_init = zero_init + self.clipping = clipping + self.gfn_clip = gfn_clip + self.pb_scale_range = pb_scale_range + self.log_var_range = log_var_range + self.learn_variance = learn_variance + + assert ( + self.s_emb_dim == self.t_emb_dim + ), "Dimensionality of state embedding and time embedding should be the same!" + + self.t_model = DiffusionPISTimeEncoding( + self.harmonics_dim, self.t_emb_dim, self.hidden_dim + ) + self.s_model = DiffusionPISStateEncoding(self.s_dim, self.s_emb_dim) + self.joint_model = DiffusionPISJointPolicy( + self.s_emb_dim, + self.hidden_dim, + self.out_dim, + self.joint_layers, + self.zero_init, + ) + + def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: + s = preprocessed_states[..., :-1] + t = preprocessed_states[..., -1] + s_emb = self.s_model(s) + t_emb = self.t_model(t) + out = self.joint_model(s_emb, t_emb) + + if self.clipping: + out = torch.clamp(out, -self.gfn_clip, self.gfn_clip) + + # Tanh-scale to stay near Brownian bridge; last dim (if present) is log-std corr. 
+ drift_corr = torch.tanh(out[..., : self.s_dim]) * self.pb_scale_range + if self.learn_variance and out.shape[-1] == self.s_dim + 1: + log_std_corr = torch.tanh(out[..., [-1]]) * self.pb_scale_range + log_std_corr = torch.clamp( + log_std_corr, -self.log_var_range, self.log_var_range + ) + out = torch.cat([drift_corr, log_std_corr], dim=-1) + else: + out = drift_corr + + if torch.isnan(out).any(): + raise ValueError("DiffusionPISGradNetBackward produced NaNs") + + return out diff --git a/testing/gflownet/test_mle_diffusion.py b/testing/gflownet/test_mle_diffusion.py new file mode 100644 index 00000000..d45755ef --- /dev/null +++ b/testing/gflownet/test_mle_diffusion.py @@ -0,0 +1,172 @@ +import math + +import torch + +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward +from gfn.gflownet.mle import MLEDiffusion +from gfn.gym.diffusion_sampling import DiffusionSampling +from gfn.utils.modules import DiffusionFixedBackwardModule + +ENV = DiffusionSampling( + target_str="gmm2", + target_kwargs=None, + num_discretization_steps=100, + device=torch.device("cpu"), + debug=True, +) + + +class ZeroDriftModule(torch.nn.Module): + """Returns zero drift (and optional zero log-std if learn_variance).""" + + def __init__(self, s_dim: int, learn_variance: bool = False): + super().__init__() + self.s_dim = s_dim + self.learn_variance = learn_variance + # Required by IdentityPreprocessor in estimators. 
+ self.input_dim = s_dim + 1 # state dim + time + + def forward(self, x: torch.Tensor) -> torch.Tensor: # x shape: (B, s_dim + 1) + batch = x.shape[0] + if self.learn_variance: + return torch.zeros(batch, self.s_dim + 1, device=x.device, dtype=x.dtype) + return torch.zeros(batch, self.s_dim, device=x.device, dtype=x.dtype) + + +def _build_estimators(s_dim: int, learn_variance: bool, num_steps: int = 1): + """Helper to build deterministic PF/PB for tests.""" + pf_module = ZeroDriftModule(s_dim=s_dim, learn_variance=learn_variance) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + n_variance_outputs=1 if learn_variance else 0, + ) + pb = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=DiffusionFixedBackwardModule(s_dim), + sigma=1.0, + num_discretization_steps=num_steps, + n_variance_outputs=0, + pb_scale_range=0.1, + ) + return pf, pb + + +def test_mle_loss_fixed_variance_zero_terminal(): + """ + With zero drift, fixed variance (sigma=1), num_steps=1, and terminal states at 0, + the loss is deterministic: log(2π) per dimension /2 summed over dim -> log(2π). 
+ """ + torch.manual_seed(0) + s_dim = 2 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=False, num_steps=1) + trainer = MLEDiffusion( + pf=pf, + pb=pb, + num_steps=1, + sigma=1.0, + t_scale=1.0, + pb_scale_range=0.1, + learn_variance=False, + ) + + batch = torch.zeros(4, s_dim) # terminal states near (0,0) + loss = trainer.loss(ENV, batch, exploration_std=0.0) + + expected_logp = -0.5 * s_dim * math.log(2 * math.pi) # log p for zero increment + expected_loss = -expected_logp # num_steps=1, loss = -logpf_sum.mean() + assert torch.isfinite(loss) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) + + +def test_mle_loss_learned_variance_zero_terminal(): + """ + Learned variance head returning log_std=0 should match the fixed-variance case + (std = exp(0)*sqrt(dt)*sqrt(t_scale) = 1 when num_steps=1, t_scale=1). + """ + torch.manual_seed(0) + s_dim = 2 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=True, num_steps=1) + trainer = MLEDiffusion( + pf=pf, + pb=pb, + num_steps=1, + sigma=1.0, + t_scale=1.0, + pb_scale_range=0.1, + learn_variance=True, + ) + batch = torch.zeros(3, s_dim) + loss = trainer.loss(ENV, batch, exploration_std=0.0) + + expected_logp = -0.5 * s_dim * math.log(2 * math.pi) + expected_loss = -expected_logp + assert torch.isfinite(loss) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) + + +def test_backward_bridge_mean_std_match_formula(): + """ + Validate Brownian bridge mean/std against closed form for num_steps=2 at t=1. + For s_curr=0, mean should be 0, std should be sigma*sqrt(dt*(t-dt)/t). + """ + s_dim = 2 + num_steps = 2 + sigma = 1.0 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=False, num_steps=num_steps) + + # Manually run the PB module once at t=1. 
+ dt = 1.0 / num_steps + bsz = 3 + s_curr = torch.zeros(bsz, s_dim) + t_curr = torch.full((bsz, 1), 1.0) + pb_inp = torch.cat([s_curr, t_curr], dim=1) + pb_out = pb.module(pb_inp) + + is_s0 = (t_curr - dt) < dt * 1e-2 + base_mean = torch.where( + is_s0, + torch.zeros_like(s_curr), + s_curr * (1.0 - dt / t_curr), + ) + base_std = torch.where( + is_s0, + torch.zeros_like(t_curr), + sigma * (dt * (t_curr - dt) / t_curr).sqrt(), + ) + + # For zero corrections, mean_corr=0, corr_std=0. + mean_corr = pb_out[..., :s_dim] * pb.pb_scale_range + assert torch.allclose(mean_corr, torch.zeros_like(mean_corr)) + assert torch.allclose(base_mean, torch.zeros_like(base_mean)) + expected_std = sigma * math.sqrt(dt * (1.0 - dt) / 1.0) + assert torch.allclose(base_std.squeeze(-1), torch.full((bsz,), expected_std)) + + +def test_forward_logprob_zero_increment_matches_formula(): + """ + For PF with zero drift/log_std=0, num_steps=1, t_scale=1, increment=0, + the log-prob per dim is -0.5*log(2π); total logp = that * s_dim. 
+ """ + s_dim = 2 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=True, num_steps=1) + trainer = MLEDiffusion( + pf=pf, + pb=pb, + num_steps=1, + sigma=1.0, + t_scale=1.0, + pb_scale_range=0.1, + learn_variance=True, + ) + + batch = torch.zeros(2, s_dim) + # Manually compute expected logp for zero increment: + # std = exp(0) * sqrt(dt) * sqrt(t_scale) = 1; logp = -0.5 * s_dim * log(2π) + expected_logp = -0.5 * s_dim * math.log(2 * math.pi) + expected_loss = -expected_logp + loss = trainer.loss(ENV, batch, exploration_std=0.0) + assert torch.isfinite(loss) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) diff --git a/testing/gym/test_diffusion_sampling_rtb.py b/testing/gym/test_diffusion_sampling_rtb.py new file mode 100644 index 00000000..8b4c9146 --- /dev/null +++ b/testing/gym/test_diffusion_sampling_rtb.py @@ -0,0 +1,43 @@ +import torch + +from gfn.gym.diffusion_sampling import ( + DiffusionSampling, + Grid25GaussianMixture, + Posterior9of25GaussianMixture, +) + + +def test_gmm25_prior_basic_sampling_and_log_reward(): + env = DiffusionSampling( + target_str="gmm25_prior", + target_kwargs=None, + num_discretization_steps=8, + device=torch.device("cpu"), + debug=True, + ) + assert isinstance(env.target, Grid25GaussianMixture) + x = env.target.sample(batch_size=16) + assert x.shape == (16, env.dim) + log_r = env.target.log_reward(x) + assert log_r.shape == (16,) + assert torch.isfinite(log_r).all() + + +def test_gmm25_posterior9_log_reward_matches_ratio(): + env = DiffusionSampling( + target_str="gmm25_posterior9", + target_kwargs=None, + num_discretization_steps=8, + device=torch.device("cpu"), + debug=True, + ) + assert isinstance(env.target, Posterior9of25GaussianMixture) + x = env.target.sample(batch_size=8) + assert x.shape == (8, env.dim) + + log_r = env.target.log_reward(x) + posterior_log = env.target.posterior.log_prob(x).flatten() + prior_log = env.target.prior.log_reward(x) + + assert torch.allclose(log_r, posterior_log - 
prior_log, atol=1e-5) + assert torch.isfinite(log_r).all() diff --git a/testing/test_diffusion_estimators.py b/testing/test_diffusion_estimators.py new file mode 100644 index 00000000..62597b15 --- /dev/null +++ b/testing/test_diffusion_estimators.py @@ -0,0 +1,270 @@ +import math + +import torch + +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward +from gfn.gym.diffusion_sampling import DiffusionSampling +from gfn.samplers import Sampler +from gfn.utils.modules import DiffusionPISGradNetBackward + + +class _Identity(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class _ConstantJoint(torch.nn.Module): + def __init__(self, output: torch.Tensor): + super().__init__() + self.register_buffer("output", output) + + def forward( + self, s_emb: torch.Tensor, t_emb: torch.Tensor + ) -> torch.Tensor: # noqa: ARG002 + batch = s_emb.shape[0] + return self.output.expand(batch, -1) # type: ignore + + +class _ConstantModule(torch.nn.Module): + def __init__(self, output: torch.Tensor, input_dim: int): + super().__init__() + self.register_buffer("output", output) + self.input_dim = input_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: # noqa: ARG002 + batch = x.shape[0] + return self.output.expand(batch, -1) # type: ignore + + +def test_diffusion_pis_gradnet_backward_scales_and_clamps_outputs(): + s_dim = 2 + pb_scale_range = 0.2 + log_var_range = 0.05 + model = DiffusionPISGradNetBackward( + s_dim=s_dim, + harmonics_dim=4, + t_emb_dim=4, + s_emb_dim=4, + hidden_dim=8, + joint_layers=1, + pb_scale_range=pb_scale_range, + log_var_range=log_var_range, + learn_variance=True, + ) + + # Replace heavy components with deterministic stubs. 
+ model.s_model = _Identity() + model.t_model = _Identity() + model.joint_model = _ConstantJoint( + torch.tensor([3.0, -4.0, 50.0], dtype=torch.float32) + ) + + preprocessed = torch.tensor([[0.1, -0.2, 0.3]], dtype=torch.float32) + out = model(preprocessed) + + drift = out[..., :s_dim] + log_std = out[..., -1] + + assert out.shape == (1, s_dim + 1) + assert torch.all(torch.abs(drift) <= pb_scale_range + 1e-6) + assert torch.allclose( + drift[0, 0], + torch.tanh(torch.tensor(3.0)) * pb_scale_range, + atol=1e-4, + ) + # Log-std correction is tanh-bounded then clamped to log_var_range. + assert torch.allclose(log_std, torch.full_like(log_std, log_var_range)) + + +def test_pinned_brownian_forward_marks_exit_on_final_step(): + s_dim = 2 + num_steps = 4 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 0}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim, dtype=torch.float32), + input_dim=s_dim + 1, + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + ) + + # t + dt reaches terminal time, so the drift should be converted to exit action (-inf). + terminal_states = env.states_from_tensor( + torch.tensor([[0.0, 0.0, 1.0 - pf.dt]], dtype=torch.float32) + ) + dist = pf.to_probability_distribution(terminal_states, pf(terminal_states)) + assert torch.isinf(dist.loc).all() + + # Earlier times should stay finite. 
+ mid_states = env.states_from_tensor( + torch.tensor([[0.0, 0.0, 0.5]], dtype=torch.float32) + ) + mid_dist = pf.to_probability_distribution(mid_states, pf(mid_states)) + assert torch.isfinite(mid_dist.loc).all() + + +def test_pinned_brownian_forward_exit_condition_matches_steps(): + """Exit masking triggers only on last step according to is_final_step logic.""" + s_dim = 2 + num_steps = 5 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 0}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim, dtype=torch.float32), + input_dim=s_dim + 1, + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + ) + + dt = pf.dt + eps = dt * 1e-2 # _DIFFUSION_TERMINAL_TIME_EPS + times = torch.tensor( + [ + 0.0, # initial + dt, # early + 1.0 - 2 * dt, # mid + 1.0 - dt - 0.5 * eps, # should trigger final step mask + 1.0 - dt, # last step before terminal time + ], + dtype=torch.float32, + ) + states = env.states_from_tensor( + torch.stack([torch.zeros_like(times), torch.zeros_like(times), times], dim=1) + ) + + dist = pf.to_probability_distribution(states, pf(states)) + exit_mask = torch.isinf(dist.loc).all(dim=-1) + expected = torch.tensor([False, False, False, True, True]) + assert torch.equal(exit_mask, expected) + + +def test_pinned_brownian_forward_combines_exploration_variance(): + s_dim = 2 + num_steps = 5 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 1}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim + 1, dtype=torch.float32), + input_dim=s_dim + 1, + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + n_variance_outputs=1, + ) + + states = env.states_from_tensor(torch.tensor([[0.0, 0.0, 
0.0]], dtype=torch.float32)) + base_std = math.sqrt(pf.dt) # log_std=0 -> exp(0) * sqrt(dt) + exploration_std = 0.4 + dist = pf.to_probability_distribution( + states, pf(states), exploration_std=exploration_std + ) + + expected = math.sqrt(base_std**2 + exploration_std**2) + assert torch.allclose(dist.scale, torch.full_like(dist.scale, expected), atol=1e-6) + + +def test_pinned_brownian_backward_applies_corrections_and_quadrature(): + s_dim = 2 + num_steps = 4 + pb_scale_range = 0.2 + sigma = 1.5 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 2}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pb_module = _ConstantModule( + output=torch.tensor([[5.0, -5.0, 1.0]], dtype=torch.float32), + input_dim=s_dim + 1, + ) + pb = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=sigma, + num_discretization_steps=num_steps, + n_variance_outputs=1, + pb_scale_range=pb_scale_range, + ) + + t_curr = 0.5 + states = env.states_from_tensor( + torch.tensor([[0.5, -0.25, t_curr]], dtype=torch.float32) + ) + dist = pb.to_probability_distribution(states, pb(states)) + + dt = pb.dt + s_curr = states.tensor[:, :-1] + base_mean = s_curr * dt / t_curr + base_std = sigma * math.sqrt(dt * (t_curr - dt) / t_curr) + + expected_mean = base_mean + torch.tensor([[1.0, -1.0]], dtype=torch.float32) + expected_std = math.sqrt(base_std**2 + math.exp(pb_scale_range) ** 2) + + assert torch.allclose(dist.loc, expected_mean, atol=1e-6) + assert torch.allclose( + dist.scale, torch.full_like(dist.scale, expected_std), atol=1e-6 + ) + + +def test_diffusion_sampler_completes_after_num_steps(): + num_steps = 6 + batch_size = 3 + s_dim = 2 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 3}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim, dtype=torch.float32), input_dim=s_dim + 1 + ) + pf = 
PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + ) + sampler = Sampler(estimator=pf) + + trajectories = sampler.sample_trajectories( + env, n=batch_size, save_logprobs=True, save_estimator_outputs=False + ) + + assert torch.all(trajectories.terminating_idx == num_steps) + # The sampler uses the estimator output directly (exit action = -inf) so the final + # state is the sink padding (non-finite). Verify sink detection and exit action. + final_states = trajectories.states[ + trajectories.terminating_idx, torch.arange(batch_size) + ] + assert final_states.is_sink_state.all() + assert trajectories.actions.is_exit[num_steps - 1].all() diff --git a/testing/test_environments.py b/testing/test_environments.py index 81f67428..05c78720 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -7,12 +7,16 @@ from gfn.actions import GraphActions, GraphActionType from gfn.env import Env, NonValidActionsError +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.diffusion_sampling import DiffusionSampling from gfn.gym.graph_building import GraphBuilding from gfn.gym.perfect_tree import PerfectBinaryTree from gfn.gym.set_addition import SetAddition from gfn.preprocessors import IdentityPreprocessor, KHotPreprocessor, OneHotPreprocessor +from gfn.samplers import Sampler from gfn.states import GraphStates +from gfn.utils.modules import DiffusionFixedBackwardModule, DiffusionPISGradNetForward # Utilities. @@ -893,3 +897,97 @@ def test_env_default_sf_bool_dtype(): assert env.sf.dtype == torch.bool assert isinstance(env.sf, torch.Tensor) assert torch.equal(env.sf, torch.zeros(state_shape, dtype=torch.bool)) + + +def test_diffusion_trajectory_mask_alignment(): + """Test that diffusion trajectory masks align correctly for PB calculation. 
+ + This verifies that the estimator's exit action detection matches the environment's + terminal state detection, ensuring valid_states and valid_actions have the same + count in get_trajectory_pbs. A mismatch would cause an AssertionError. + + The key invariant is: for each trajectory step where we compute PB, we need + exactly one valid state (at t+1) and one valid action (at t). Exit actions + must be properly marked so they're excluded from the action mask. + """ + # Use small config for fast testing. + num_steps = 8 + batch_size = 16 + s_dim = 2 + + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 42}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + ) + + pf_module = DiffusionPISGradNetForward( + s_dim=s_dim, + harmonics_dim=16, + t_emb_dim=16, + s_emb_dim=16, + hidden_dim=32, + joint_layers=1, + ) + pb_module = DiffusionFixedBackwardModule(s_dim=s_dim) + + pf_estimator = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=5.0, + num_discretization_steps=num_steps, + ) + pb_estimator = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=5.0, + num_discretization_steps=num_steps, + ) + + sampler = Sampler(estimator=pf_estimator) + + # Sample trajectories. + trajectories = sampler.sample_trajectories( + env, + n=batch_size, + save_logprobs=True, + save_estimator_outputs=False, + ) + + # Compute masks the same way get_trajectory_pbs does. + state_mask = ( + ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state + ) + state_mask[0, :] = False # Can't compute PB for first state row. + action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit + + valid_states_count = int(state_mask.sum()) + valid_actions_count = int(action_mask.sum()) + exit_count = int(trajectories.actions.is_exit.sum()) + + # Key assertions: + # 1. Exit actions should be detected (one per trajectory for fixed-length diffusion). 
+ assert exit_count == batch_size, ( + f"Expected {batch_size} exit actions (one per trajectory), got {exit_count}. " + "The estimator may not be marking exit actions correctly." + ) + + # 2. Valid states and actions must match for PB calculation. + assert valid_states_count == valid_actions_count, ( + f"Mask mismatch: {valid_states_count} valid states vs {valid_actions_count} valid actions. " + f"Exit count: {exit_count}. This would cause get_trajectory_pbs to fail." + ) + + # 3. Verify get_trajectory_pbs runs without error (the actual alignment check). + from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs + + log_pfs, log_pbs = get_trajectory_pfs_and_pbs( + pf_estimator, + pb_estimator, + trajectories, + recalculate_all_logprobs=False, + ) + # Shape is (T, N) = (num_steps, batch_size) - per-step log probs for each trajectory. + assert log_pfs.shape == (num_steps, batch_size) + assert log_pbs.shape == (num_steps, batch_size) diff --git a/testing/test_rtb.py b/testing/test_rtb.py new file mode 100644 index 00000000..28427835 --- /dev/null +++ b/testing/test_rtb.py @@ -0,0 +1,82 @@ +import torch + +from gfn.estimators import DiscretePolicyEstimator +from gfn.gflownet import RelativeTrajectoryBalanceGFlowNet +from gfn.gym import HyperGrid +from gfn.preprocessors import KHotPreprocessor +from gfn.samplers import Sampler +from gfn.utils.modules import MLP + + +def _make_hypergrid_estimators(): + """Build simple forward policies for HyperGrid prior/posterior.""" + env = HyperGrid(ndim=2, height=4) + preproc = KHotPreprocessor(env.height, env.ndim) + assert isinstance(preproc.output_dim, int) + + pf_module_post = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + pf_module_prior = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + + pf_post = DiscretePolicyEstimator( + module=pf_module_post, + n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, + ) + pf_prior = DiscretePolicyEstimator( + module=pf_module_prior, + 
n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, + ) + return env, pf_post, pf_prior + + +def test_rtb_loss_backward_and_grads(): + torch.manual_seed(0) + env, pf_post, pf_prior = _make_hypergrid_estimators() + + gfn = RelativeTrajectoryBalanceGFlowNet( + pf=pf_post, + prior_pf=pf_prior, + init_logZ=0.0, + beta=1.0, + ) + sampler = Sampler(estimator=pf_post) + trajectories = sampler.sample_trajectories( + env, n=8, save_logprobs=True, save_estimator_outputs=False + ) + + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=True) + assert torch.isfinite(loss) + + loss.backward() + + # Posterior parameters and logZ should receive gradients. + assert any(p.grad is not None for p in pf_post.parameters()) + assert any(p.grad is not None for p in gfn.logz_parameters()) + + # Prior parameters are not part of the RTB graph and should have no grads. + assert all(p.grad is None for p in pf_prior.parameters()) + + +def test_rtb_loss_forward_only_path(): + """Ensure RTB loss works with recalculate_all_logprobs=False.""" + torch.manual_seed(1) + env, pf_post, pf_prior = _make_hypergrid_estimators() + + gfn = RelativeTrajectoryBalanceGFlowNet( + pf=pf_post, + prior_pf=pf_prior, + init_logZ=0.0, + beta=0.5, + ) + sampler = Sampler(estimator=pf_post) + trajectories = sampler.sample_trajectories( + env, n=4, save_logprobs=True, save_estimator_outputs=False + ) + + # Use cached log_probs; should not rely on any backward policy. 
+ loss = gfn.loss(env, trajectories, recalculate_all_logprobs=False) + assert torch.isfinite(loss) + loss.backward() diff --git a/tutorials/examples/output/.gitignore b/tutorials/examples/output/.gitignore new file mode 100644 index 00000000..e5513ab3 --- /dev/null +++ b/tutorials/examples/output/.gitignore @@ -0,0 +1,5 @@ +*.pt +*.jpg +*.jpeg +*.png +*.zip \ No newline at end of file diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py new file mode 100644 index 00000000..c1c9d475 --- /dev/null +++ b/tutorials/examples/train_diffusion_rtb.py @@ -0,0 +1,548 @@ +#!/usr/bin/env python +""" +Minimal end-to-end Relative Trajectory Balance (RTB) fine-tuning training script for +diffusion models. + +- Prior is pre-trained (auto-runs if the prior checkpoint is missing), so + finetuning starts from a learned prior. +- Posterior is fine-tuned from this prior (pf). + +By default, uses the 25→9 GMM posterior target (`gmm25_posterior9`) with a +learnable posterior forward policy and a fixed prior forward policy. Loss is RTB (no +backward policy). This script outputs the prior weights alongside plots of samples +from both the prior and posterior distributions. 
def resolve_output_paths(args: argparse.Namespace) -> argparse.Namespace:
    """Resolve all output paths and create the output directory.

    ``args.output_dir`` may be absolute, ``~``-prefixed, or relative. Relative
    values are anchored at the directory containing this script rather than at
    the current working directory.

    Args:
        args: Parsed CLI namespace; only ``output_dir`` is read.

    Returns:
        The same namespace, mutated in place: ``output_dir`` becomes a resolved
        ``Path`` and ``prior_ckpt_path``, ``pretrain_save_fig_path`` and
        ``save_fig_path`` are set to files inside that directory.
    """
    # Expand "~" BEFORE the absolute-path check. "~/runs" is not absolute, so
    # expanding after joining with the script directory leaves the tilde as a
    # literal directory name ("<script_dir>/~/runs").
    output_dir = Path(args.output_dir).expanduser()
    if not output_dir.is_absolute():
        output_dir = Path(__file__).resolve().parent / output_dir
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    args.output_dir = output_dir
    args.prior_ckpt_path = output_dir / "train_diffusion_rtb_prior_ckpt.pt"
    args.pretrain_save_fig_path = output_dir / "train_diffusion_rtb_prior_samples.png"
    args.save_fig_path = output_dir / "train_diffusion_rtb_posterior_samples.png"

    return args


def forward_kwargs(
    args: argparse.Namespace,
    s_dim: int,
    num_steps: int,
    sigma: float,
    device: torch.device,
) -> dict:
    """Assemble the keyword arguments for ``build_forward_estimator``.

    Run-specific values (``s_dim``, ``num_steps``, ``sigma``, ``device``) are
    passed explicitly; the network hyper-parameters are read from ``args``.
    Note that the CLI attribute ``learn_var`` is exposed under the key
    ``learn_variance``.
    """
    return {
        "s_dim": s_dim,
        "num_steps": num_steps,
        "sigma": sigma,
        "harmonics_dim": args.harmonics_dim,
        "t_emb_dim": args.t_emb_dim,
        "s_emb_dim": args.s_emb_dim,
        "hidden_dim": args.hidden_dim,
        "joint_layers": args.joint_layers,
        "zero_init": args.zero_init,
        "learn_variance": args.learn_var,
        "clipping": args.clipping,
        "gfn_clip": args.gfn_clip,
        "t_scale": args.t_scale,
        "log_var_range": args.log_var_range,
        "device": device,
    }


def get_debug_metrics(estimator: torch.nn.Module) -> tuple[torch.Tensor, bool]:
    """Compute the global gradient norm of a module.

    Args:
        estimator: Module whose parameters' ``.grad`` tensors are inspected.

    Returns:
        ``(total_norm, has_nan)``: the 2-norm of the per-parameter gradient
        norms (a scalar zero tensor when no parameter has a gradient) and
        whether that norm is NaN.
    """
    grad_norms = [p.grad.norm() for p in estimator.parameters() if p.grad is not None]
    if grad_norms:
        total_norm = torch.norm(torch.stack(grad_norms))
    else:
        # No gradients yet. Report zero on the module's device when it has
        # parameters; a parameter-free module must not raise StopIteration.
        first_param = next(estimator.parameters(), None)
        if first_param is None:
            total_norm = torch.tensor(0.0)
        else:
            total_norm = torch.tensor(0.0, device=first_param.device)
    return total_norm, bool(torch.isnan(total_norm))
module; return (total_norm, has_nan).""" + grad_list = [p.grad.norm() for p in estimator.parameters() if p.grad is not None] + if grad_list: + total_norm = torch.norm(torch.stack(grad_list)) + else: + total_norm = torch.tensor(0.0, device=next(estimator.parameters()).device) + has_nan = torch.isnan(total_norm) + return total_norm, bool(has_nan) + + +def get_exploration_std( + iteration: int, + exploration_factor: float = 0.1, + warm_down_start: int = 500, + warm_down_end: int = 4500, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Return a callable exploration std schedule for state-space noise. + + When exploration is enabled, return a step-index function that emits a fixed + std for the current training iteration, optionally linearly warmed down + after warm_down_start iters toward 0 by warm_down_end iters. + """ + device = device or torch.get_default_device() + dtype = dtype or torch.get_default_dtype() + + # Tensor ops only (torch.compile-friendly): no Python branching on iteration. + iter_t = torch.tensor(iteration, device=device, dtype=dtype) + # Clamp negatives to zero to avoid Python-side checks/overhead. + factor_t = torch.clamp( + torch.tensor(exploration_factor, device=device, dtype=dtype), min=0.0 + ) + start_t = torch.tensor(warm_down_start, device=device, dtype=dtype) + end_t = torch.tensor(warm_down_end, device=device, dtype=dtype) + + # Phase indicator: 1 before warm_down_start, linear decay afterward. 
def build_forward_estimator(
    s_dim: int,
    num_steps: int,
    sigma: float,
    harmonics_dim: int,
    t_emb_dim: int,
    s_emb_dim: int,
    hidden_dim: int,
    joint_layers: int,
    zero_init: bool,
    learn_variance: bool,
    clipping: bool,
    gfn_clip: float,
    t_scale: float,
    log_var_range: float,
    device: torch.device,
) -> PinnedBrownianMotionForward:
    """Build a forward estimator around a ``DiffusionPISGradNetForward`` module.

    The network hyper-parameters are forwarded unchanged to
    ``DiffusionPISGradNetForward``; ``sigma`` and ``num_steps`` configure the
    wrapping ``PinnedBrownianMotionForward``. One variance output is requested
    when ``learn_variance`` is true, none otherwise.

    Returns:
        The estimator, moved to ``device``.
    """
    pf_module = DiffusionPISGradNetForward(
        s_dim=s_dim,
        harmonics_dim=harmonics_dim,
        t_emb_dim=t_emb_dim,
        s_emb_dim=s_emb_dim,
        hidden_dim=hidden_dim,
        joint_layers=joint_layers,
        zero_init=zero_init,
        clipping=clipping,
        gfn_clip=gfn_clip,
        t_scale=t_scale,
        log_var_range=log_var_range,
        learn_variance=learn_variance,
    )

    return PinnedBrownianMotionForward(
        s_dim=s_dim,
        pf_module=pf_module,
        sigma=sigma,
        num_discretization_steps=num_steps,
        n_variance_outputs=1 if learn_variance else 0,
    ).to(device)


def pretrain_prior(args: argparse.Namespace, device: torch.device, s_dim: int) -> None:
    """Pretrain the prior forward policy with the MLE trainer if needed.

    Does nothing when ``args.prior_ckpt_path`` already exists, unless
    ``args.clobber_pretrained_prior`` is set, in which case the checkpoint is
    deleted and rebuilt. After training, a single checkpoint (forward/backward
    state dicts, optimizer state, step count) is written to
    ``args.prior_ckpt_path`` and a sample plot to
    ``args.pretrain_save_fig_path``.

    Args:
        args: Parsed CLI namespace with paths already resolved.
        device: Device on which the environment and estimators are created.
        s_dim: State dimensionality passed to the estimators.

    Raises:
        ValueError: In debug mode (``__debug__``), if the gradient norm is NaN.
    """
    ckpt_path = Path(args.prior_ckpt_path)

    if ckpt_path.exists():
        if not args.clobber_pretrained_prior:
            return
        print(f"[pretrain] Clobbering existing prior checkpoint at {ckpt_path}")
        ckpt_path.unlink()

    print(f"[pretrain] Prior checkpoint missing at {ckpt_path}, starting pretraining...")

    env_prior = DiffusionSampling(
        target_str=args.pretrain_target,
        target_kwargs=None,
        num_discretization_steps=args.pretrain_num_steps,
        device=device,
        debug=__debug__,
    )

    pf_prior = build_forward_estimator(
        **forward_kwargs(
            args,
            s_dim=s_dim,
            num_steps=args.pretrain_num_steps,
            sigma=args.pretrain_sigma,
            device=device,
        )
    )

    # Backward estimator: a learned module if enabled, else the fixed module.
    if args.pretrain_learn_pb:
        pb_module = DiffusionPISGradNetBackward(
            s_dim=s_dim,
            harmonics_dim=args.harmonics_dim,
            t_emb_dim=args.t_emb_dim,
            s_emb_dim=args.s_emb_dim,
            hidden_dim=args.hidden_dim,
            joint_layers=args.joint_layers,
            zero_init=args.zero_init,
            clipping=args.clipping,
            gfn_clip=args.gfn_clip,
            pb_scale_range=args.pb_scale_range,
            log_var_range=args.log_var_range,
            learn_variance=args.learn_var,
        )
        n_var_outputs = 1 if args.learn_var else 0
        pb_scale_range = args.pb_scale_range
    else:
        pb_module = DiffusionFixedBackwardModule(s_dim)
        n_var_outputs = 0
        pb_scale_range = 0.0

    pb_prior = PinnedBrownianMotionBackward(
        s_dim=s_dim,
        pb_module=pb_module,
        sigma=args.pretrain_sigma,
        num_discretization_steps=args.pretrain_num_steps,
        n_variance_outputs=n_var_outputs,
        pb_scale_range=pb_scale_range,
    ).to(device)

    # The backward parameters are optimized only when the backward is learned.
    optim_params = [{"params": pf_prior.parameters(), "lr": args.lr}]
    if args.pretrain_learn_pb:
        optim_params.append({"params": pb_prior.parameters(), "lr": args.lr})
    optimizer = torch.optim.Adam(
        optim_params,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # MLE trainer (uses forward PF and optional PB).
    # NOTE(review): the trainer receives args.pb_scale_range / args.learn_var
    # even when the fixed backward module was built with pb_scale_range=0.0 and
    # no variance outputs -- confirm MLEDiffusion ignores them in that case.
    mle_trainer = MLEDiffusion(
        pf=pf_prior,
        pb=pb_prior,
        num_steps=args.pretrain_num_steps,
        sigma=args.pretrain_sigma,
        t_scale=args.t_scale,
        pb_scale_range=args.pb_scale_range,
        learn_variance=args.learn_var,
        debug=__debug__,
    )

    pf_prior.train()
    pbar = tqdm(range(args.pretrain_steps), dynamic_ncols=True, desc="pretrain_prior")

    # Tracked separately from the loop variable so the checkpoint can still be
    # written when pretrain_steps <= 0 and the loop body never runs.
    steps_done = 0
    for it in pbar:
        with torch.no_grad():
            batch = env_prior.target.sample(args.batch_size)
        optimizer.zero_grad()
        loss = mle_trainer.loss(
            env_prior, batch, exploration_std=args.pretrain_exploration_factor
        )
        loss.backward()
        # __debug__ is True unless Python runs with -O, so this prints every step.
        if __debug__:
            total_norm, has_nan = get_debug_metrics(pf_prior)
            print(
                f"[pretrain][debug] step={it} loss={loss.item():.4e} grad_norm={total_norm.item():.4e}"
            )
            if has_nan:
                raise ValueError("NaN grad norm in pretrain.")

        optimizer.step()
        steps_done = it + 1

        # Log progress only.
        if steps_done % args.pretrain_log_interval == 0 or it == args.pretrain_steps - 1:
            pbar.set_postfix({"loss": float(loss.item())})

    # Final checkpoint after pretraining (no intermediate resume support).
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "pf_state_dict": pf_prior.state_dict(),
            "pb_state_dict": pb_prior.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "step": steps_done,
        },
        ckpt_path,
    )
    print(f"[pretrain] Saved prior to {ckpt_path}")

    # Quick visual check of the learned prior.
    with torch.no_grad():
        sampler_prior = Sampler(estimator=pf_prior)
        trajectories = sampler_prior.sample_trajectories(
            env=env_prior,
            n=args.pretrain_vis_n,
        )
        # Drop the last column of the terminating states before plotting
        # (presumably the time coordinate -- verify against DiffusionSampling).
        xs = trajectories.terminating_states.tensor[:, :-1]
        plot_samples(
            xs,
            env_prior.target,
            "RTB Prior Samples",
            args.pretrain_save_fig_path,
            return_fig=False,
        )
        print(f"[pretrain] Saved prior samples plot to {args.pretrain_save_fig_path}")
def plot_samples(
    xs: torch.Tensor,
    target,
    title: str,
    save_path: Path | str,
    return_fig: bool = False,
):
    """Plot samples over the target density with ``viz_2d_slice`` and save it.

    When ``target`` has a ``posterior`` attribute, the density is taken from
    ``target.posterior.log_prob`` instead of ``target.log_reward``. The
    override is applied to a shallow copy, so the caller's ``target`` is left
    unmodified.

    Args:
        xs: Samples to draw.
        target: Object providing ``plot_border`` and the other attributes that
            ``viz_2d_slice`` reads.
        title: Axes title.
        save_path: Image path; parent directories are created.
        return_fig: Return the open figure instead of closing it.

    Returns:
        The figure if ``return_fig`` is true, otherwise ``None``.
    """
    import copy  # local import: only this function needs it

    assert target.plot_border is not None, "Target must define plot_border for plotting."

    viz_target = target
    if hasattr(target, "posterior"):
        # Shallow-copy before overriding log_reward. Assigning on `target`
        # itself would permanently replace the caller's log_reward, and any
        # reward computed after plotting would silently use the posterior.
        viz_target = copy.copy(target)
        posterior = target.posterior

        def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor:
            return posterior.log_prob(x).flatten()

        viz_target.log_reward = _posterior_log_reward  # type: ignore[attr-defined]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    viz_2d_slice(
        ax,
        viz_target,
        (0, 1),
        samples=xs,
        plot_border=viz_target.plot_border,
        use_log_reward=True,
        grid_width_n_points=200,
        max_n_samples=2000,
    )

    ax.set_title(title)
    fig.tight_layout()
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path)

    if return_fig:
        return fig

    # Close figures nobody will use so repeated calls do not accumulate them.
    plt.close(fig)
    return None
def main(args: argparse.Namespace) -> None:
    """Run posterior fine-tuning with RTB, pretraining the prior first if needed.

    Steps: resolve output paths, seed, build the environment and two identical
    forward estimators (posterior and prior), make sure a prior checkpoint
    exists, load it into both estimators, freeze the prior, train the posterior
    with the RTB loss, then save a plot of posterior samples.

    Raises:
        FileNotFoundError: If no prior checkpoint exists after pretraining.
    """
    args = resolve_output_paths(args)
    set_seed(args.seed)
    device = torch.device(args.device)
    torch.set_default_device(device)

    # Environment / target
    env = DiffusionSampling(
        target_str=args.target,
        target_kwargs=None,
        num_discretization_steps=args.num_steps,
        device=device,
        debug=__debug__,
    )
    s_dim = env.dim

    # Posterior forward (trainable)
    pf_post = build_forward_estimator(
        **forward_kwargs(
            args, s_dim=s_dim, num_steps=args.num_steps, sigma=args.sigma, device=device
        )
    )

    # Prior forward.
    pf_prior = build_forward_estimator(
        **forward_kwargs(
            args, s_dim=s_dim, num_steps=args.num_steps, sigma=args.sigma, device=device
        )
    )

    # Pretrain prior if needed, then load weights into both prior and posterior so
    # finetuning starts from the learned prior.
    pretrain_prior(args, device, s_dim)

    if not args.prior_ckpt_path.exists():
        raise FileNotFoundError(
            f"pretrained weights not found at {args.prior_ckpt_path}, pretraining failed"
        )

    ckpt = torch.load(args.prior_ckpt_path, map_location=device)
    # Accept both the dict written by pretrain_prior and a bare state dict.
    state = ckpt.get("pf_state_dict", ckpt)
    missing, unexpected = pf_prior.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(f"[warn] prior load missing={missing}, unexpected={unexpected}")
    # Initialize posterior from the same prior weights.
    pf_post.load_state_dict(pf_prior.state_dict(), strict=False)

    # During finetuning, the prior is fixed, no grad.
    pf_prior.eval()
    for p in pf_prior.parameters():
        p.requires_grad_(False)

    gflownet = RelativeTrajectoryBalanceGFlowNet(
        pf=pf_post,
        prior_pf=pf_prior,
        init_logZ=0.0,
        beta=args.beta,
    ).to(device)

    sampler = Sampler(estimator=pf_post)

    # logZ gets its own learning rate.
    param_groups = [
        {"params": gflownet.pf.parameters(), "lr": args.lr},
        {"params": gflownet.logz_parameters(), "lr": args.lr_logz},
    ]
    optimizer = torch.optim.Adam(
        param_groups, lr=args.lr, weight_decay=args.weight_decay
    )

    for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)):
        trajectories = sampler.sample_trajectories(
            env,
            n=args.batch_size,
            save_logprobs=False,
            save_estimator_outputs=False,
            # Extra exploration noise (combined with base PF variance in estimator).
            exploration_std=get_exploration_std(
                iteration=it,
                exploration_factor=args.exploration_factor,
                warm_down_start=args.exploration_warm_down_start,
                warm_down_end=args.exploration_warm_down_end,
            ),
        )

        optimizer.zero_grad()
        # Log-probs were not saved during sampling, so the loss recomputes them.
        loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=True)
        loss.backward()
        optimizer.step()

        if (it + 1) % args.log_interval == 0 or it == args.n_iterations - 1:
            with torch.no_grad():
                term_states = gflownet.sample_terminating_states(env, n=args.eval_n)
                log_rewards = env.target.log_reward(term_states.tensor[:, :-1])
                avg_log_reward = log_rewards.mean().item()
            # Labelled as a log-reward: the value is the mean of log_reward().
            pbar.set_postfix(
                {"loss": float(loss.item()), "avg_log_reward": avg_log_reward}
            )
        else:
            pbar.set_postfix({"loss": float(loss.item())})

    # Final visualization
    with torch.no_grad():
        samples_states = gflownet.sample_terminating_states(env, n=args.vis_n)
        xs = samples_states.tensor[:, :-1]
        plot_samples(
            xs,
            env.target,
            "RTB Posterior Samples",
            args.save_fig_path,
            return_fig=False,
        )
        print(f"Saved final samples scatter to {args.save_fig_path}")
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the RTB fine-tuning script.

    Options are declared as ``(flags, add_argument kwargs)`` pairs grouped by
    purpose (system, model/finetune, pretraining, training, logging) and added
    to the parser in that order.

    ``--zero_init`` uses ``BooleanOptionalAction`` (like ``--clipping``) so its
    default of ``True`` can be turned off with ``--no-zero_init``. With
    ``store_true`` and ``default=True`` the flag could never be disabled.
    """
    # fmt: off
    system_args = [
        (("--device",), {"type": str, "default": "cpu", "choices": ["cpu", "cuda", "mps"], "help": "Device for training."}),
        (("--seed",), {"type": int, "default": 0, "help": "Random seed"}),
    ]

    finetune_args = [
        (("--target",), {"type": str, "default": "gmm25_posterior9", "help": "Diffusion target"}),
        (("--num_steps",), {"type": int, "default": 100, "help": "Discretization steps"}),
        (("--sigma",), {"type": float, "default": 2.0, "help": "Pinned Brownian motion sigma"}),
        (("--harmonics_dim",), {"type": int, "default": 64}),
        (("--t_emb_dim",), {"type": int, "default": 64}),
        (("--s_emb_dim",), {"type": int, "default": 64}),
        (("--hidden_dim",), {"type": int, "default": 64}),
        (("--joint_layers",), {"type": int, "default": 2}),
        (("--zero_init",), {"action": argparse.BooleanOptionalAction, "default": True, "help": "Enable zero_init (disable with --no-zero_init)"}),
        (("--learn_var",), {"action": "store_true", "default": False, "help": "Learned variance"}),
        (("--clipping",), {"action": argparse.BooleanOptionalAction, "default": False, "help": "Clip model outputs"}),
        (("--gfn_clip",), {"type": float, "default": 1e4, "help": "Drift clip value"}),
        (("--t_scale",), {"type": float, "default": 5.0, "help": "Diffusion std scale"}),
        (("--log_var_range",), {"type": float, "default": 4.0, "help": "Bound for learned log-std"}),
    ]

    pretrain_args = [
        (("--clobber_pretrained_prior",), {"action": "store_true", "default": False, "help": "Overwrite existing prior"}),
        (("--pretrain_learn_pb",), {"action": "store_true", "default": False, "help": "Enable learned backward policy"}),
        (("--pb_scale_range",), {"type": float, "default": 0.1, "help": "Tanh scaling for pb"}),
        (("--pretrain_target",), {"type": str, "default": "gmm25_prior", "help": "Target used for pretraining"}),
        (("--pretrain_num_steps",), {"type": int, "default": 100, "help": "Pretrain discretization steps"}),
        (("--pretrain_sigma",), {"type": float, "default": 2.0, "help": "Pretrain diffusion sigma"}),
        (("--pretrain_exploration_factor",), {"type": float, "default": 0.0, "help": "Pretrain std expansion"}),
        (("--pretrain_steps",), {"type": int, "default": 10000, "help": "Pretrain steps"}),
        (("--pretrain_log_interval",), {"type": int, "default": 100, "help": "Pretrain log interval"}),
        (("--pretrain_vis_n",), {"type": int, "default": 2000, "help": "Pretrain samples to plot"}),
    ]

    train_args = [
        (("--n_iterations",), {"type": int, "default": 5000}),
        (("--batch_size",), {"type": int, "default": 500}),
        (("--lr",), {"type": float, "default": 1e-3}),
        (("--lr_logz",), {"type": float, "default": 1e-1}),
        (("--weight_decay",), {"type": float, "default": 0.0, "help": "Weight decay"}),
        (("--beta",), {"type": float, "default": 1.0, "help": "RTB beta"}),
        (("--exploration_factor",), {"type": float, "default": 0.5, "help": "Step-wise std expansion"}),
        (("--exploration_warm_down_start",), {"type": float, "default": 500, "help": "Warmdown start iter"}),
        (("--exploration_warm_down_end",), {"type": float, "default": 4500, "help": "Warmdown end iter"}),
    ]

    log_args = [
        (("--log_interval",), {"type": int, "default": 100}),
        (("--eval_n",), {"type": int, "default": 500}),
        (("--vis_n",), {"type": int, "default": 2000, "help": "Samples for final plot"}),
        (("--output_dir",), {"type": str, "default": "output", "help": "relative output dir"}),
    ]
    # fmt: on

    parser = argparse.ArgumentParser()
    for group in (system_args, finetune_args, pretrain_args, train_args, log_args):
        for flags, options in group:
            parser.add_argument(*flags, **options)
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())