Relative trajectory balance #457
The first changed file extends the pinned Brownian motion forward estimator with optional learned variance outputs and exploration noise:

```diff
@@ -1,3 +1,4 @@
+import math
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable
@@ -1290,6 +1291,7 @@ def __init__(
         pf_module: nn.Module,
         sigma: float,
         num_discretization_steps: int,
+        n_variance_outputs: int = 0,
     ):
         """Initialize the PinnedBrownianMotionForward.
@@ -1305,6 +1307,12 @@ def __init__(
         self.sigma = sigma
         self.num_discretization_steps = num_discretization_steps
         self.dt = 1.0 / self.num_discretization_steps
+        self.n_variance_outputs = n_variance_outputs
+
+    @property
+    def expected_output_dim(self) -> int:
+        # Drift (s_dim) plus optional variance outputs.
+        return self.s_dim + self.n_variance_outputs

     def forward(self, input: States) -> torch.Tensor:
         """Forward pass of the module.
@@ -1329,7 +1337,6 @@ def to_probability_distribution(
         states: States,
         module_output: torch.Tensor,
         **policy_kwargs: Any,
-        # TODO: add epsilon-noisy exploration
     ) -> IsotropicGaussian:
         """Transform the output of the module into an IsotropicGaussian distribution,
         which is the distribution of the next states under the pinned Brownian motion
@@ -1339,24 +1346,66 @@ def to_probability_distribution(
             states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1).
             module_output: The output of the module (actions), as a tensor of shape
                 (*batch_shape, s_dim).
-            **policy_kwargs: Keyword arguments to modify the distribution.
+            **policy_kwargs: Keyword arguments to modify the distribution. Supported
+                keys:
+                - exploration_std: Optional callable or float controlling extra
+                  exploration noise on top of the base diffusion std. The callable
+                  should accept an integer step index and return a non-negative
+                  standard deviation in state space. When provided, the extra noise
+                  is combined in variance-space (logaddexp) with the base diffusion
+                  variance; non-positive exploration is ignored.

         Returns:
             An IsotropicGaussian distribution (distribution of the next states)
         """
         assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1"
-        s_curr = states.tensor[:, :-1]
+        # s_curr = states.tensor[:, :-1]
         t_curr = states.tensor[:, [-1]]

         module_output = torch.where(
             (1.0 - t_curr) < self.dt * 1e-2,  # sf case; when t_curr is 1.0
-            torch.full_like(s_curr, -float("inf")),  # This is the exit action
+            # torch.full_like(s_curr, -float("inf")),  # This is the exit action
+            torch.full_like(module_output, -float("inf")),  # This is the exit action
             module_output,
         )

-        fwd_mean = self.dt * module_output
-        fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=fwd_mean.device)
-        fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1)
+        drift = module_output[..., : self.s_dim]
+        if self.n_variance_outputs > 0:
+            var_part = module_output[..., self.s_dim :]
+            # Reduce extra variance dims to a single scalar (isotropic for now).
+            log_std = var_part.mean(dim=-1, keepdim=True)
+            fwd_std = torch.exp(log_std) * math.sqrt(self.dt)
+        else:
+            fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=drift.device)
+            fwd_std = fwd_std.repeat(drift.shape[0], 1)
+
+        # Match reference behavior: scale diffusion noise (not drift) by t_scale if present.
+        t_scale_factor = getattr(self.module, "t_scale", 1.0)
+        if t_scale_factor != 1.0:
+            fwd_std = fwd_std * math.sqrt(t_scale_factor)
+
+        fwd_mean = self.dt * drift
+
+        # Optional exploration noise: combine variances (quadrature/logaddexp).
+        exploration_std = policy_kwargs.pop("exploration_std", None)
+        exploration_std_t = torch.as_tensor(
+            exploration_std if exploration_std is not None else 0.0,
+            device=fwd_std.device,
+            dtype=fwd_std.dtype,
+        ).clamp(min=0.0)
+
+        # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2:
+        # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly.
+        base_log_var = 2 * fwd_std.log()  # log(σ_base^2)
```
A reviewer commented on lines +1412 to +1414 with the following suggested change:

```diff
-        # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2:
-        # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly.
-        base_log_var = 2 * fwd_std.log()  # log(σ_base^2)
+        # If there is no positive exploration noise, keep the base diffusion std.
+        # This avoids unnecessary log operations and potential log(0) issues.
+        if exploration_std_t.eq(0).all():
+            return IsotropicGaussian(fwd_mean, fwd_std)
+        # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2:
+        # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly.
+        # Clamp fwd_std to a small positive value before taking the log to avoid
+        # numerical issues when fwd_std is extremely small or zero.
+        safe_fwd_std = fwd_std.clamp_min(1e-12)
+        base_log_var = 2 * safe_fwd_std.log()  # log(σ_base^2)
```
> @hyeok9855 this might have been the cause of the problem you had before in your code (learning slower), worth checking.

> @hyeok9855 actually, this was my bug!!
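For readers following this thread, here is a minimal standalone sketch of the variance-space combination under discussion, with the zero-exploration guard and the clamp from the suggestion folded in; the function name `combine_stds` and the `1e-12` floor are illustrative choices, not part of the library API.

```python
import torch


def combine_stds(base_std: torch.Tensor, exploration_std: torch.Tensor) -> torch.Tensor:
    """Return sqrt(base_std**2 + exploration_std**2) via logaddexp.

    Combining independent Gaussian noise sources in variance space keeps the
    result a valid standard deviation; logaddexp avoids under/overflow in the
    intermediate log-variances.
    """
    # No positive exploration noise: keep the base diffusion std unchanged.
    if exploration_std.eq(0).all():
        return base_std

    # Clamp before log() so a zero std does not produce a -inf log-variance.
    safe_base = base_std.clamp_min(1e-12)
    safe_expl = exploration_std.clamp_min(1e-12)
    # logaddexp(log σ_base², log σ_expl²) = log(σ_base² + σ_expl²)
    log_var = torch.logaddexp(2 * safe_base.log(), 2 * safe_expl.log())
    return torch.exp(0.5 * log_var)
```

For example, `combine_stds(torch.full((4, 1), 0.1), torch.tensor(0.05))` gives roughly 0.1118 per entry, i.e. sqrt(0.01 + 0.0025). A step-indexed schedule like the callable described in the docstring could be, say, `lambda step: 0.5 * (1.0 - step / num_discretization_steps)`, decaying the extra noise to zero by the final step.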
The second changed file adds forward-only and backward-only trajectory log-probability helpers:
```diff
@@ -11,7 +11,11 @@
 from gfn.estimators import Estimator
 from gfn.samplers import Sampler
 from gfn.states import States
-from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs
+from gfn.utils.prob_calculations import (
+    get_trajectory_pbs,
+    get_trajectory_pfs,
+    get_trajectory_pfs_and_pbs,
+)

 TrainingSampleType = TypeVar("TrainingSampleType", bound=Container)
```
```diff
@@ -343,6 +347,32 @@ def get_pfs_and_pbs(
             recalculate_all_logprobs,
         )

+    def trajectory_log_probs_forward(
+        self,
+        trajectories: Trajectories,
+        fill_value: float = 0.0,
+        recalculate_all_logprobs: bool = True,
+    ) -> torch.Tensor:
+        """Evaluates forward logprobs only for each trajectory in the batch."""
+        return get_trajectory_pfs(
+            self.pf,
+            trajectories,
+            fill_value=fill_value,
+            recalculate_all_logprobs=recalculate_all_logprobs,
+        )
+
+    def trajectory_log_probs_backward(
+        self,
+        trajectories: Trajectories,
+        fill_value: float = 0.0,
+    ) -> torch.Tensor:
+        """Evaluates backward logprobs only for each trajectory in the batch."""
+        return get_trajectory_pbs(
+            self.pb,
+            trajectories,
+            fill_value=fill_value,
+        )
+
     def get_scores(
         self,
         trajectories: Trajectories,
```

> (Author) Leaving these here because they might come in handy, but I don't think they're actually needed right now in this implementation.
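A rough usage sketch of the two new helpers follows; the `gflownet` and `trajectories` objects are hypothetical, and the step-major output layout and padding semantics are assumptions drawn from the docstrings rather than verified library behavior.

```python
# Hypothetical objects: a trajectory-based gflownet and a batch of sampled
# trajectories (e.g. obtained from a Sampler).
log_pf = gflownet.trajectory_log_probs_forward(trajectories)
log_pb = gflownet.trajectory_log_probs_backward(trajectories)

# Padded steps carry fill_value (0.0 by default), so summing over the step
# dimension yields per-trajectory log-probability totals, analogous to what
# get_pfs_and_pbs returns in a single call.
total_log_pf = log_pf.sum(dim=0)
total_log_pb = log_pb.sum(dim=0)
```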
The third changed file adds the new RTB loss to the trajectory-balance module:
```diff
@@ -3,6 +3,7 @@
 and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
 """

+import math
 from typing import cast

 import torch
@@ -16,6 +17,7 @@
     is_callable_exception_handler,
     warn_about_recalculating_logprobs,
 )
+from gfn.utils.prob_calculations import get_trajectory_pfs


 class TBGFlowNet(TrajectoryBasedGFlowNet):
```
```diff
@@ -132,6 +134,120 @@ def loss(
         return loss


+class RelativeTrajectoryBalanceGFlowNet(TrajectoryBasedGFlowNet):
+    r"""GFlowNet for the Relative Trajectory Balance (RTB) loss.
+
+    This objective matches a posterior sampler to a prior diffusion (or other
+    sequential) model by minimizing
+
+    .. math::
+
+        \left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau)
+        - \beta \log r(x_T)\right)^2,
+
+    where :math:`p_\theta` is a fixed prior process, :math:`p_\phi` is the
+    learnable posterior, :math:`r` is a positive reward/constraint on the
+    terminal state :math:`x_T`, and :math:`\log Z_\phi` is a learned scalar
+    normalizer.
+    """
+
+    def __init__(
+        self,
+        pf: Estimator,
+        prior_pf: Estimator,
+        *,
+        logZ: nn.Parameter | ScalarEstimator | None = None,
+        init_logZ: float = 0.0,
+        beta: float = 1.0,
+        log_reward_clip_min: float = -float("inf"),
+        debug: bool = False,
+    ):
+        """Initializes an RTB GFlowNet.
+
+        Args:
+            pf: Posterior forward policy estimator :math:`p_\\phi`.
+            prior_pf: Fixed prior forward policy estimator :math:`p_\\theta`.
+            logZ: Learnable log-partition parameter or ScalarEstimator for
+                conditional settings. Defaults to a scalar parameter.
+            init_logZ: Initial value for logZ if ``logZ`` is None.
+            beta: Optional scaling applied to the terminal log-reward.
+            log_reward_clip_min: If finite, clips terminal log-rewards.
+            debug: If True, enables extra checks at the cost of execution speed.
+        """
+        super().__init__(
+            pf=pf,
+            pb=None,
+            constant_pb=True,
+            log_reward_clip_min=log_reward_clip_min,
+        )
+        self.prior_pf = prior_pf
+        self.beta = torch.tensor(beta)
+        self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
+        self.debug = debug  # TODO: to be passed to base classes.
```
A review comment proposed dropping the TODO:

```diff
-        self.debug = debug  # TODO: to be passed to base classes.
+        self.debug = debug
```

> The commented-out line appears to be dead code that should be removed. If it's intended for reference, consider moving it to a comment explaining why the change was made rather than leaving commented code.
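To make the class docstring's objective concrete, here is a minimal sketch of the RTB residual, assuming per-trajectory log-probabilities for the posterior and prior have already been computed (for instance via helpers like `trajectory_log_probs_forward` above); this is illustrative, not the PR's actual `loss` implementation, and all names below are placeholders.

```python
import torch


def rtb_loss(
    log_pf_posterior: torch.Tensor,  # log p_phi(tau), shape (batch,)
    log_pf_prior: torch.Tensor,  # log p_theta(tau), shape (batch,)
    log_reward: torch.Tensor,  # log r(x_T), shape (batch,)
    log_Z: torch.Tensor,  # learned scalar normalizer, log Z_phi
    beta: float = 1.0,
) -> torch.Tensor:
    """Mean squared RTB residual:

    (log Z_phi + log p_phi(tau) - log p_theta(tau) - beta * log r(x_T))^2
    """
    residual = log_Z + log_pf_posterior - log_pf_prior - beta * log_reward
    return residual.pow(2).mean()
```

Because the prior is fixed, gradients flow only through the posterior's log-probabilities and the learned `log_Z`, which is what lets the objective match the posterior sampler to the reward-tilted prior.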