diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index ab7eca9a..4671fc54 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -194,7 +194,7 @@ def log_rewards(self) -> torch.Tensor | None: If not provided at initialization, log rewards are computed on demand for terminating transitions. """ - if self.is_backward: + if self.is_backward: # TODO: Why can't backward trajectories have log_rewards? return None if self._log_rewards is None: diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 6441c3eb..5542019b 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -168,6 +168,7 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC): pb: The backward policy estimator, or None if it can be ignored (e.g., the gflownet DAG is a tree, and pb is therefore always 1). constant_pb: Whether to ignore the backward policy estimator. + log_reward_clip_min: If finite, clips log rewards to this value. """ def __init__( @@ -175,6 +176,7 @@ def __init__( pf: Estimator, pb: Estimator | None, constant_pb: bool = False, + log_reward_clip_min: float = float("-inf"), ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -186,6 +188,7 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + log_reward_clip_min: If finite, clips log rewards to this value. """ super().__init__() @@ -215,11 +218,12 @@ def __init__( self.pf = pf self.pb = pb self.constant_pb = constant_pb + self.log_reward_clip_min = log_reward_clip_min # Advisory: recurrent PF with non-recurrent PB is unusual # (tree DAGs typically prefer pb=None with constant_pb=True). # Import locally to avoid circular imports during module import time. - from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore + from gfn.estimators import RecurrentDiscretePolicyEstimator if isinstance(self.pf, RecurrentDiscretePolicyEstimator) and isinstance( self.pb, Estimator @@ -288,7 +292,7 @@ def pf_pb_parameters(self) -> list[torch.Tensor]: return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k] -class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): +class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories], ABC): """A GFlowNet that operates on complete trajectories. Attributes: @@ -297,32 +301,9 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): pb is therefore always 1. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. + log_reward_clip_min: If finite, clips log rewards to this value. """ - def __init__( - self, - pf: Estimator, - pb: Estimator | None, - constant_pb: bool = False, - ) -> None: - """Initializes a TrajectoryBasedGFlowNet instance. - - Args: - pf: The forward policy estimator. - pb: The backward policy estimator, or None if the gflownet DAG is a tree, - and pb is therefore always 1. - constant_pb: Whether to ignore the backward policy estimator, e.g., if the - gflownet DAG is a tree, and pb is therefore always 1. Must be set - explicitly by user to ensure that pb is an Estimator except under this - special case. 
- - """ - super().__init__( - pf, - pb, - constant_pb=constant_pb, - ) - def get_pfs_and_pbs( self, trajectories: Trajectories, @@ -388,8 +369,9 @@ def get_scores( total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) log_rewards = trajectories.log_rewards + assert log_rewards is not None - if math.isfinite(self.log_reward_clip_min) and log_rewards is not None: + if math.isfinite(self.log_reward_clip_min): log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) if torch.any(torch.isinf(total_log_pf_trajectories)): @@ -399,7 +381,6 @@ def get_scores( assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) - assert log_rewards is not None return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards def to_training_samples(self, trajectories: Trajectories) -> Trajectories: diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 7376574c..feff5462 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -60,11 +60,9 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]): logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. forward_looking: Whether to use the forward-looking GFN loss. - log_reward_clip_min: If finite, clips log rewards to this value. - safe_log_prob_min: If True, uses -1e10 as the minimum log probability value - to avoid numerical instability, otherwise uses -1e38. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. + log_reward_clip_min: If finite, clips log rewards to this value. """ def __init__( @@ -73,9 +71,8 @@ def __init__( pb: Estimator | None, logF: ScalarEstimator | ConditionalScalarEstimator, forward_looking: bool = False, - log_reward_clip_min: float = -float("inf"), - safe_log_prob_min: bool = True, constant_pb: bool = False, + log_reward_clip_min: float = -float("inf"), ) -> None: """Initializes a DBGFlowNet instance. @@ -86,19 +83,19 @@ def __init__( logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. forward_looking: Whether to use the forward-looking GFN loss. - log_reward_clip_min: If finite, clips log rewards to this value. - safe_log_prob_min: If True, uses -1e10 as the minimum log probability value - to avoid numerical instability, otherwise uses -1e38. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + log_reward_clip_min: If finite, clips log rewards to this value. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + ) # Disallow recurrent PF for transition-based DB - from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore + from gfn.estimators import RecurrentDiscretePolicyEstimator if isinstance(self.pf, RecurrentDiscretePolicyEstimator): raise TypeError( @@ -112,11 +109,6 @@ def __init__( self.logF = logF self.forward_looking = forward_looking - self.log_reward_clip_min = log_reward_clip_min - if safe_log_prob_min: - self.log_prob_min = -1e10 - else: - self.log_prob_min = -1e38 def logF_named_parameters(self) -> dict[str, torch.Tensor]: """Returns a dictionary of named parameters containing 'logF' in their name. 
@@ -191,14 +183,15 @@ def get_scores(
         if len(states) == 0:
             return torch.tensor(0.0, device=transitions.device)
 
-        # uncomment next line for debugging
-        # assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
         check_compatibility(states, actions, transitions)
+        assert (
+            not transitions.states.is_sink_state.any()
+        ), "Transition from sink state is not allowed. This is a bug."
 
-        log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
-            transitions, recalculate_all_logprobs
-        )
+        ### Compute log_pf and log_pb
+        log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs)
 
+        ### Compute log_F_s
         # LogF is potentially a conditional computation.
         if transitions.conditions is not None:
             with has_conditions_exception_handler("logF", self.logF):
@@ -207,50 +200,65 @@
             with no_conditions_exception_handler("logF", self.logF):
                 log_F_s = self.logF(states).squeeze(-1)
 
-        if self.forward_looking:
-            log_rewards = env.log_reward(states)
-            if math.isfinite(self.log_reward_clip_min):
-                log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
-            log_F_s = log_F_s + log_rewards
-
-        preds = log_pf_actions + log_F_s
-
-        # uncomment next line for debugging
-        # assert transitions.next_states.is_sink_state.equal(transitions.is_terminating)
-
-        # automatically removes invalid transitions (i.e. s_f -> s_f)
-        valid_next_states = transitions.next_states[~transitions.is_terminating]
-        valid_transitions_is_terminating = transitions.is_terminating[
-            ~transitions.states.is_sink_state
-        ]
+        ### Compute log_F_s_next
+        log_F_s_next = torch.zeros_like(log_F_s)
+        is_terminating = transitions.is_terminating
+        is_intermediate = ~is_terminating
 
-        if len(valid_next_states) == 0:
-            return torch.tensor(0.0, device=transitions.device)
-
-        # LogF is potentially a conditional computation.
+        # Assign log_F_s_next for intermediate next states
+        interm_next_states = transitions.next_states[is_intermediate]
+        # log_F is potentially a conditional computation.
         if transitions.conditions is not None:
             with has_conditions_exception_handler("logF", self.logF):
-                valid_log_F_s_next = self.logF(
-                    valid_next_states,
-                    transitions.conditions[~transitions.is_terminating],
+                log_F_s_next[is_intermediate] = self.logF(
+                    interm_next_states,
+                    transitions.conditions[is_intermediate],
                 ).squeeze(-1)
         else:
             with no_conditions_exception_handler("logF", self.logF):
-                valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)
-
-        log_F_s_next = torch.zeros_like(log_pb_actions)
-        log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
-        assert transitions.log_rewards is not None
-        valid_transitions_log_rewards = transitions.log_rewards[
-            ~transitions.states.is_sink_state
-        ]
-        log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
-            valid_transitions_is_terminating
-        ]
-        targets = log_pb_actions + log_F_s_next
+                log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1)
 
-        scores = preds - targets
+        # Apply forward-looking if applicable
+        if self.forward_looking:
+            import warnings
+
+            warnings.warn(
+                "Rewards should be defined over edges in forward-looking settings. "
+                "The current implementation is a special case of this, where the "
+                "edge reward is defined as the difference between the rewards of "
+                "the two states that the edge connects. If that is not the case "
+                "for your environment, forward-looking may be inappropriate."
+            )
+
+            # Reward calculation can also be conditional.
+ if transitions.conditions is not None: + log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore + log_rewards_next = env.log_reward( + interm_next_states, transitions.conditions[is_intermediate] # type: ignore + ) + else: + log_rewards_state = env.log_reward(states) + log_rewards_next = env.log_reward(interm_next_states) + if math.isfinite(self.log_reward_clip_min): + log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) + log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) + log_F_s = log_F_s + log_rewards_state + log_F_s_next[is_intermediate] = ( + log_F_s_next[is_intermediate] + log_rewards_next + ) + + # Assign log_F_s_next for terminating transitions as log_rewards + log_rewards = transitions.log_rewards + assert log_rewards is not None + if math.isfinite(self.log_reward_clip_min): + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) + log_F_s_next[is_terminating] = log_rewards[is_terminating] + + ### Compute scores + preds = log_pf + log_F_s + targets = log_pb + log_F_s_next + scores = preds - targets assert scores.shape == (transitions.n_transitions,) return scores diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index e09fb0c7..37296c61 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -102,7 +102,9 @@ def __init__( special case. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + ) assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] @@ -110,7 +112,6 @@ def __init__( self.logF = logF self.weighting = weighting self.lamda = lamda - self.log_reward_clip_min = log_reward_clip_min self.forward_looking = forward_looking def logF_named_parameters(self) -> dict[str, torch.Tensor]: diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 03c6a7a3..8e2d6e6b 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -35,8 +35,10 @@ class TBGFlowNet(TrajectoryBasedGFlowNet): pb: The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1. logZ: A learnable parameter or a ScalarEstimator instance (for conditional GFNs). + constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb + is therefore always 1. Must be set explicitly by user to ensure that pb + is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. - constant_pb: Whether the gflownet DAG is a tree, and pb is therefore always 1. """ def __init__( @@ -45,8 +47,8 @@ def __init__( pb: Estimator | None, logZ: nn.Parameter | ScalarEstimator | None = None, init_logZ: float = 0.0, - log_reward_clip_min: float = -float("inf"), constant_pb: bool = False, + log_reward_clip_min: float = -float("inf"), ): """Initializes a TBGFlowNet instance. @@ -57,15 +59,16 @@ def __init__( logZ: A learnable parameter or a ScalarEstimator instance (for conditional GFNs). init_logZ: The initial value for the logZ parameter (used if logZ is None). - log_reward_clip_min: If finite, clips log rewards to this value. constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. 
+ log_reward_clip_min: If finite, clips log rewards to this value. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + ) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) - self.log_reward_clip_min = log_reward_clip_min def logz_named_parameters(self) -> dict[str, torch.Tensor]: """Returns a dictionary of named parameters containing 'logZ' in their name. @@ -138,25 +141,12 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet): Attributes: pf: The forward policy estimator. pb: The backward policy estimator. + constant_pb: Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb + is therefore always 1. Must be set explicitly by user to ensure that pb + is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. """ - def __init__( - self, - pf: Estimator, - pb: Estimator, - log_reward_clip_min: float = -float("inf"), - ): - """Initializes a LogPartitionVarianceGFlowNet instance. - - Args: - pf: The forward policy estimator. - pb: The backward policy estimator. - log_reward_clip_min: If finite, clips log rewards to this value. - """ - super().__init__(pf, pb) - self.log_reward_clip_min = log_reward_clip_min - def loss( self, env: Env, diff --git a/tutorials/examples/train_hypergrid_gafn.py b/tutorials/examples/train_hypergrid_gafn.py index 8eb6835e..4df290d1 100644 --- a/tutorials/examples/train_hypergrid_gafn.py +++ b/tutorials/examples/train_hypergrid_gafn.py @@ -142,7 +142,9 @@ def __init__( flow_estimator: The flow estimator, required if use_edge_ri is True. log_reward_clip_min: If finite, clips log rewards to this value. """ - super().__init__(pf, pb, logZ, init_logZ, log_reward_clip_min) + super().__init__( + pf, pb, logZ, init_logZ, log_reward_clip_min=log_reward_clip_min + ) self.rnd = rnd self.use_edge_ri = use_edge_ri if use_edge_ri and flow_estimator is None: diff --git a/tutorials/examples/train_hypergrid_simple.py b/tutorials/examples/train_hypergrid_simple.py index cf3c96b0..10cd7a08 100644 --- a/tutorials/examples/train_hypergrid_simple.py +++ b/tutorials/examples/train_hypergrid_simple.py @@ -20,8 +20,8 @@ import torch from tqdm import tqdm -from gfn.estimators import DiscretePolicyEstimator -from gfn.gflownet import TBGFlowNet +from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator +from gfn.gflownet import DBGFlowNet, FMGFlowNet, TBGFlowNet from gfn.gym import HyperGrid from gfn.preprocessors import KHotPreprocessor from gfn.samplers import Sampler @@ -67,24 +67,53 @@ def main(args): ) else: module_PB = DiscreteUniform(output_dim=env.n_actions - 1) - pf_estimator = DiscretePolicyEstimator( - module_PF, env.n_actions, preprocessor=preprocessor, is_backward=False - ) - pb_estimator = DiscretePolicyEstimator( - module_PB, env.n_actions, preprocessor=preprocessor, is_backward=True - ) - gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) - # Feed pf to the sampler. - sampler = Sampler(estimator=pf_estimator) - - # Move the gflownet to the GPU. - gflownet = gflownet.to(device) + # Initialize the key components + # 1. Estimator(s) + # 2. GFlowNet + # 3. Optimizer + # 4. 
Sampler + if args.loss == "FM": + logF_estimator = DiscretePolicyEstimator( + module=module_PF, + n_actions=env.n_actions, + preprocessor=preprocessor, + ) + gflownet = FMGFlowNet(logF_estimator).to(device) + optimizer = torch.optim.Adam(gflownet.logF.parameters(), lr=args.lr) + sampler = Sampler(estimator=logF_estimator) - # Policy parameters have their own LR. Log Z gets dedicated learning rate - # (typically higher). - optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) - optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": args.lr_logz}) + else: + pf_estimator = DiscretePolicyEstimator( + module_PF, env.n_actions, preprocessor=preprocessor, is_backward=False + ) + pb_estimator = DiscretePolicyEstimator( + module_PB, env.n_actions, preprocessor=preprocessor, is_backward=True + ) + if args.loss == "DB": + logF_module = MLP( + input_dim=preprocessor.output_dim, + output_dim=1, + # trunk=module_PF.trunk, # FIXME: This raises an Error + ) + logF_estimator = ScalarEstimator( + module=logF_module, preprocessor=preprocessor + ) + gflownet = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator) + else: + gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + + gflownet = gflownet.to(device) + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + if isinstance(gflownet, DBGFlowNet): + optimizer.add_param_group( + {"params": gflownet.logF.parameters(), "lr": args.lr} + ) + else: # TBGFlowNet + optimizer.add_param_group( + {"params": gflownet.logz_parameters(), "lr": args.lr_logz} + ) + sampler = Sampler(estimator=pf_estimator) validation_info = {"l1_dist": float("inf")} visited_terminating_states = env.states_from_batch_shape((0,)) @@ -94,7 +123,7 @@ def main(args): trajectories = sampler.sample_trajectories( env, n=args.batch_size, - save_logprobs=True, + save_logprobs=False, save_estimator_outputs=False, epsilon=args.epsilon, ) @@ -103,7 +132,9 @@ def main(args): ) optimizer.zero_grad() - loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=False) + loss = gflownet.loss_from_trajectories( + env, trajectories, recalculate_all_logprobs=False + ) loss.backward() gflownet.assert_finite_gradients() @@ -138,11 +169,18 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + parser.add_argument( + "--loss", + type=str, + choices=["FM", "TB", "DB"], + default="TB", + help="Loss function to use", + ) parser.add_argument( "--ndim", type=int, default=2, help="Number of dimensions in the environment" ) parser.add_argument( - "--height", type=int, default=64, help="Height of the environment" + "--height", type=int, default=32, help="Height of the environment" ) parser.add_argument( "--R0",
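
# Closing illustration (hedged, self-contained): the core of the reworked
# DBGFlowNet.get_scores earlier in this diff is the masking pattern below.
# log F(s') is taken from the flow estimator for intermediate transitions and
# replaced by the terminal log-reward for terminating ones, with no filtering
# of "valid" subsets as in the previous implementation. All tensor values are
# made up for illustration only.
import torch

is_terminating = torch.tensor([False, True, False, True])
is_intermediate = ~is_terminating

# Placeholder logF(s') estimates for intermediate next states, and placeholder
# log R(x) values for terminating transitions.
log_F_estimates = torch.tensor([0.10, 0.00, -0.30, 0.00])
log_rewards = torch.tensor([0.00, -1.20, 0.00, -0.70])

log_F_s_next = torch.zeros_like(log_F_estimates)
log_F_s_next[is_intermediate] = log_F_estimates[is_intermediate]
log_F_s_next[is_terminating] = log_rewards[is_terminating]

# The detailed-balance residual is then
# (log_pf + log_F_s) - (log_pb + log_F_s_next), computed in one pass over all
# transitions. The revised tutorial above exercises this code path when run
# with --loss DB.
print(log_F_s_next)  # tensor([ 0.1000, -1.2000, -0.3000, -0.7000])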