Refactor Detailed Balance #432
Changes from all commits
00941d9
ac108cb
a86889a
7400b6a
dec1c9f
5f4ce7e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,13 +168,15 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC): | |
| pb: The backward policy estimator, or None if it can be ignored (e.g., the | ||
| gflownet DAG is a tree, and pb is therefore always 1). | ||
| constant_pb: Whether to ignore the backward policy estimator. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| pf: Estimator, | ||
| pb: Estimator | None, | ||
| constant_pb: bool = False, | ||
| log_reward_clip_min: float = float("-inf"), | ||
| ) -> None: | ||
| """Initializes a PFBasedGFlowNet instance. | ||
|
|
||
|
|
@@ -186,6 +188,7 @@ def __init__( | |
| gflownet DAG is a tree, and pb is therefore always 1. Must be set | ||
| explicitly by user to ensure that pb is an Estimator except under this | ||
| special case. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
|
|
||
| """ | ||
| super().__init__() | ||
|
|
@@ -215,11 +218,12 @@ def __init__( | |
| self.pf = pf | ||
| self.pb = pb | ||
| self.constant_pb = constant_pb | ||
| self.log_reward_clip_min = log_reward_clip_min | ||
|
|
||
| # Advisory: recurrent PF with non-recurrent PB is unusual | ||
| # (tree DAGs typically prefer pb=None with constant_pb=True). | ||
| # Import locally to avoid circular imports during module import time. | ||
| from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore | ||
| from gfn.estimators import RecurrentDiscretePolicyEstimator | ||
|
|
||
| if isinstance(self.pf, RecurrentDiscretePolicyEstimator) and isinstance( | ||
| self.pb, Estimator | ||
|
|
@@ -288,7 +292,7 @@ def pf_pb_parameters(self) -> list[torch.Tensor]: | |
| return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k] | ||
|
|
||
|
|
||
| class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): | ||
| class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories], ABC): | ||
| """A GFlowNet that operates on complete trajectories. | ||
|
|
||
| Attributes: | ||
|
|
@@ -297,32 +301,9 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): | |
| pb is therefore always 1. | ||
| constant_pb: Whether to ignore the backward policy estimator, e.g., if the | ||
| gflownet DAG is a tree, and pb is therefore always 1. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| pf: Estimator, | ||
| pb: Estimator | None, | ||
| constant_pb: bool = False, | ||
| ) -> None: | ||
| """Initializes a TrajectoryBasedGFlowNet instance. | ||
|
|
||
| Args: | ||
|
Collaborator
I'm a bit confused about how you were able to get rid of this, but seems fine!
Collaborator
Author
It calls |
||
| pf: The forward policy estimator. | ||
| pb: The backward policy estimator, or None if the gflownet DAG is a tree, | ||
| and pb is therefore always 1. | ||
| constant_pb: Whether to ignore the backward policy estimator, e.g., if the | ||
| gflownet DAG is a tree, and pb is therefore always 1. Must be set | ||
| explicitly by user to ensure that pb is an Estimator except under this | ||
| special case. | ||
|
|
||
| """ | ||
| super().__init__( | ||
| pf, | ||
| pb, | ||
| constant_pb=constant_pb, | ||
| ) | ||
|
|
||
| def get_pfs_and_pbs( | ||
| self, | ||
| trajectories: Trajectories, | ||
|
|
@@ -388,8 +369,9 @@ def get_scores( | |
| total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) | ||
|
|
||
| log_rewards = trajectories.log_rewards | ||
| assert log_rewards is not None | ||
|
|
||
| if math.isfinite(self.log_reward_clip_min) and log_rewards is not None: | ||
| if math.isfinite(self.log_reward_clip_min): | ||
| log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) | ||
|
|
||
| if torch.any(torch.isinf(total_log_pf_trajectories)): | ||
|
|
@@ -399,7 +381,6 @@ def get_scores( | |
|
|
||
| assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) | ||
| assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) | ||
| assert log_rewards is not None | ||
| return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards | ||
|
|
||
| def to_training_samples(self, trajectories: Trajectories) -> Trajectories: | ||
|
|
||
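The changes above move `log_reward_clip_min` up into `PFBasedGFlowNet` and drop the now-redundant `TrajectoryBasedGFlowNet.__init__`: a subclass with no `__init__` of its own simply resolves to the parent's constructor, which now accepts the extra argument. Below is a minimal sketch of that inherited-constructor behavior plus the finite-only clamping that `get_scores` now applies. The class names `Parent`/`Child`, the toy tensors, and the `clip` helper are illustrative stand-ins, not library code; only the clamping pattern and the parameter names are taken from the diff.

```python
import math
import torch

# Minimal sketch (not the library code): a parent that accepts
# log_reward_clip_min and a child with no __init__ of its own.
class Parent:
    def __init__(self, pf, pb, constant_pb=False, log_reward_clip_min=float("-inf")):
        self.pf, self.pb = pf, pb
        self.constant_pb = constant_pb
        self.log_reward_clip_min = log_reward_clip_min

class Child(Parent):
    # No __init__ here: Python falls back to Parent.__init__, which is why
    # the removed TrajectoryBasedGFlowNet.__init__ is no longer needed.
    def clip(self, log_rewards: torch.Tensor) -> torch.Tensor:
        # Same pattern as the new get_scores clipping: only clamp when finite.
        if math.isfinite(self.log_reward_clip_min):
            log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
        return log_rewards

child = Child(pf=None, pb=None, constant_pb=True, log_reward_clip_min=-100.0)
print(child.clip(torch.tensor([-500.0, -1.0])))  # tensor([-100.,   -1.])
```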
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,11 +60,9 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]): | |
| logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log | ||
| flow of the states. | ||
| forward_looking: Whether to use the forward-looking GFN loss. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
| safe_log_prob_min: If True, uses -1e10 as the minimum log probability value | ||
|
Collaborator
why did you remove this? it's useful.
Collaborator
Author
|
||
| to avoid numerical instability, otherwise uses -1e38. | ||
| constant_pb: Whether to ignore the backward policy estimator, e.g., if the | ||
| gflownet DAG is a tree, and pb is therefore always 1. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
| """ | ||
|
|
||
| def __init__( | ||
|
|
@@ -73,9 +71,8 @@ def __init__( | |
| pb: Estimator | None, | ||
| logF: ScalarEstimator | ConditionalScalarEstimator, | ||
| forward_looking: bool = False, | ||
| log_reward_clip_min: float = -float("inf"), | ||
| safe_log_prob_min: bool = True, | ||
| constant_pb: bool = False, | ||
| log_reward_clip_min: float = -float("inf"), | ||
| ) -> None: | ||
| """Initializes a DBGFlowNet instance. | ||
|
|
||
|
|
@@ -86,19 +83,19 @@ def __init__( | |
| logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log | ||
| flow of the states. | ||
| forward_looking: Whether to use the forward-looking GFN loss. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
| safe_log_prob_min: If True, uses -1e10 as the minimum log probability value | ||
| to avoid numerical instability, otherwise uses -1e38. | ||
| constant_pb: Whether to ignore the backward policy estimator, e.g., if the | ||
| gflownet DAG is a tree, and pb is therefore always 1. Must be set | ||
| explicitly by user to ensure that pb is an Estimator except under this | ||
| special case. | ||
| log_reward_clip_min: If finite, clips log rewards to this value. | ||
|
|
||
| """ | ||
| super().__init__(pf, pb, constant_pb=constant_pb) | ||
| super().__init__( | ||
| pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min | ||
| ) | ||
|
|
||
| # Disallow recurrent PF for transition-based DB | ||
| from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore | ||
| from gfn.estimators import RecurrentDiscretePolicyEstimator | ||
|
|
||
| if isinstance(self.pf, RecurrentDiscretePolicyEstimator): | ||
| raise TypeError( | ||
|
|
@@ -112,11 +109,6 @@ def __init__( | |
|
|
||
| self.logF = logF | ||
| self.forward_looking = forward_looking | ||
| self.log_reward_clip_min = log_reward_clip_min | ||
| if safe_log_prob_min: | ||
| self.log_prob_min = -1e10 | ||
| else: | ||
| self.log_prob_min = -1e38 | ||
|
|
||
| def logF_named_parameters(self) -> dict[str, torch.Tensor]: | ||
| """Returns a dictionary of named parameters containing 'logF' in their name. | ||
|
|
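This hunk removes the `safe_log_prob_min` flag (which selected a floor of `-1e10` or `-1e38` for log-probabilities, per the removed docstring) and forwards `log_reward_clip_min` to the base class instead of storing it locally. For a caller who still wants the old floor, the same effect can be had with a plain clamp before the values enter the loss. A sketch under that assumption; the tensor values are made up, and only the `-1e10` constant comes from the removed code.

```python
import torch

# Illustrative only: reproduce the effect of the removed safe_log_prob_min flag
# by flooring log-probabilities explicitly.
LOG_PROB_MIN = -1e10  # the value the old safe_log_prob_min=True selected

log_pf = torch.tensor([-0.5, float("-inf"), -2.0])
log_pf_safe = log_pf.clamp_min(LOG_PROB_MIN)  # -inf becomes -1e10
print(log_pf_safe)  # tensor([-5.0000e-01, -1.0000e+10, -2.0000e+00])
```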
@@ -191,14 +183,15 @@ def get_scores( | |
| if len(states) == 0: | ||
| return torch.tensor(0.0, device=transitions.device) | ||
|
|
||
| # uncomment next line for debugging | ||
| # assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy) | ||
| check_compatibility(states, actions, transitions) | ||
| assert ( | ||
| not transitions.states.is_sink_state.any() | ||
| ), "Transition from sink state is not allowed. This is a bug." | ||
|
|
||
| log_pf_actions, log_pb_actions = self.get_pfs_and_pbs( | ||
| transitions, recalculate_all_logprobs | ||
| ) | ||
| ### Compute log_pf and log_pb | ||
| log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs) | ||
|
|
||
| ### Compute log_F_s | ||
| # LogF is potentially a conditional computation. | ||
| if transitions.conditions is not None: | ||
| with has_conditions_exception_handler("logF", self.logF): | ||
|
|
@@ -207,50 +200,65 @@ def get_scores( | |
| with no_conditions_exception_handler("logF", self.logF): | ||
| log_F_s = self.logF(states).squeeze(-1) | ||
|
|
||
| if self.forward_looking: | ||
| log_rewards = env.log_reward(states) | ||
| if math.isfinite(self.log_reward_clip_min): | ||
| log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) | ||
| log_F_s = log_F_s + log_rewards | ||
|
|
||
| preds = log_pf_actions + log_F_s | ||
|
|
||
| # uncomment next line for debugging | ||
| # assert transitions.next_states.is_sink_state.equal(transitions.is_terminating) | ||
|
|
||
| # automatically removes invalid transitions (i.e. s_f -> s_f) | ||
| valid_next_states = transitions.next_states[~transitions.is_terminating] | ||
| valid_transitions_is_terminating = transitions.is_terminating[ | ||
| ~transitions.states.is_sink_state | ||
| ] | ||
| ### Compute log_F_s_next | ||
| log_F_s_next = torch.zeros_like(log_F_s) | ||
| is_terminating = transitions.is_terminating | ||
| is_intermediate = ~is_terminating | ||
|
|
||
| if len(valid_next_states) == 0: | ||
| return torch.tensor(0.0, device=transitions.device) | ||
|
|
||
| # LogF is potentially a conditional computation. | ||
| # Assign log_F_s_next for intermediate next states | ||
| interm_next_states = transitions.next_states[is_intermediate] | ||
| # log_F is potentially a conditional computation. | ||
| if transitions.conditions is not None: | ||
| with has_conditions_exception_handler("logF", self.logF): | ||
| valid_log_F_s_next = self.logF( | ||
| valid_next_states, | ||
| transitions.conditions[~transitions.is_terminating], | ||
| log_F_s_next[is_intermediate] = self.logF( | ||
| interm_next_states, | ||
| transitions.conditions[is_intermediate], | ||
| ).squeeze(-1) | ||
| else: | ||
| with no_conditions_exception_handler("logF", self.logF): | ||
| valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1) | ||
|
|
||
| log_F_s_next = torch.zeros_like(log_pb_actions) | ||
| log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next | ||
| assert transitions.log_rewards is not None | ||
| valid_transitions_log_rewards = transitions.log_rewards[ | ||
| ~transitions.states.is_sink_state | ||
| ] | ||
| log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[ | ||
| valid_transitions_is_terminating | ||
| ] | ||
| targets = log_pb_actions + log_F_s_next | ||
| log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1) | ||
|
|
||
| scores = preds - targets | ||
| # Apply forward-looking if applicable | ||
| if self.forward_looking: | ||
| import warnings | ||
|
|
||
| warnings.warn( | ||
| "Rewards should be defined over edges in forward-looking settings. " | ||
| "The current implementation is a special case of this, where the edge " | ||
| "reward is defined as the difference between the reward of two states " | ||
| "that the edge connects. If your environment is not the case, " | ||
| "forward-looking may be inappropriate." | ||
| ) | ||
|
|
||
| # Reward calculation can also be conditional. | ||
| if transitions.conditions is not None: | ||
| log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore | ||
| log_rewards_next = env.log_reward( | ||
| interm_next_states, transitions.conditions[is_intermediate] # type: ignore | ||
| ) | ||
| else: | ||
| log_rewards_state = env.log_reward(states) | ||
| log_rewards_next = env.log_reward(interm_next_states) | ||
| if math.isfinite(self.log_reward_clip_min): | ||
| log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) | ||
| log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) | ||
|
|
||
| log_F_s = log_F_s + log_rewards_state | ||
| log_F_s_next[is_intermediate] = ( | ||
| log_F_s_next[is_intermediate] + log_rewards_next | ||
| ) | ||
|
|
||
| # Assign log_F_s_next for terminating transitions as log_rewards | ||
| log_rewards = transitions.log_rewards | ||
| assert log_rewards is not None | ||
| if math.isfinite(self.log_reward_clip_min): | ||
| log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) | ||
| log_F_s_next[is_terminating] = log_rewards[is_terminating] | ||
|
|
||
| ### Compute scores | ||
| preds = log_pf + log_F_s | ||
| targets = log_pb + log_F_s_next | ||
| scores = preds - targets | ||
|
||
| assert scores.shape == (transitions.n_transitions,) | ||
| return scores | ||
|
|
||
|
|
||
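The rewritten `get_scores` assembles the detailed-balance residual as `(log P_F + log F(s)) - (log P_B + log F(s'))`, where `log F(s')` is the flow estimate for intermediate next states and the (clipped) log reward for terminating transitions. The sketch below reproduces that arithmetic on plain tensors; the mask and variable names mirror the diff, while every numeric value is made up for illustration.

```python
import torch

# Toy tensors standing in for one batch of transitions (values are made up).
log_pf = torch.tensor([-0.7, -1.2, -0.3])          # log P_F(s -> s')
log_pb = torch.tensor([-0.5, -0.9, -0.4])          # log P_B(s' -> s)
log_F_s = torch.tensor([0.1, -0.2, 0.3])           # log F(s)
is_terminating = torch.tensor([False, True, False])
log_rewards = torch.tensor([0.0, 1.5, 0.0])        # log R(x), used only at termination

# log F(s'): flow estimate for intermediate next states, log reward otherwise.
log_F_s_next = torch.zeros_like(log_F_s)
log_F_s_next[~is_terminating] = torch.tensor([0.05, 0.25])  # stand-in for logF(s')
log_F_s_next[is_terminating] = log_rewards[is_terminating]

# Detailed-balance residual, as in the refactored get_scores.
preds = log_pf + log_F_s
targets = log_pb + log_F_s_next
scores = preds - targets
print(scores)
```

In the forward-looking branch, the same residual is used after adding the (clipped) state log rewards to `log_F_s` and to the intermediate entries of `log_F_s_next`, which matches the edge-reward interpretation described in the new warning message.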