Merged
2 changes: 1 addition & 1 deletion src/gfn/containers/transitions.py
@@ -194,7 +194,7 @@ def log_rewards(self) -> torch.Tensor | None:
If not provided at initialization, log rewards are computed on demand for
terminating transitions.
"""
if self.is_backward:
if self.is_backward: # TODO: Why can't backward trajectories have log_rewards?
return None

if self._log_rewards is None:
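As the docstring above notes, when log rewards are not supplied at initialization they are computed on demand for terminating transitions only. A minimal standalone sketch of that fallback follows; the function and its arguments (env, states, is_terminating, cached) are illustrative placeholders rather than the container's actual attributes, and the -inf fill for non-terminating entries is an assumption.

import torch

def log_rewards_on_demand(env, states, is_terminating, cached=None):
    # Return the cached tensor if log rewards were provided at initialization.
    if cached is not None:
        return cached
    # Otherwise evaluate the environment's log reward only for terminating
    # transitions; other entries are filled with -inf (assumed convention).
    out = torch.full((len(states),), -float("inf"))
    out[is_terminating] = env.log_reward(states[is_terminating])
    return out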
37 changes: 9 additions & 28 deletions src/gfn/gflownet/base.py
@@ -168,13 +168,15 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC):
pb: The backward policy estimator, or None if it can be ignored (e.g., the
gflownet DAG is a tree, and pb is therefore always 1).
constant_pb: Whether to ignore the backward policy estimator.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: Estimator,
pb: Estimator | None,
constant_pb: bool = False,
log_reward_clip_min: float = float("-inf"),
) -> None:
"""Initializes a PFBasedGFlowNet instance.

@@ -186,6 +188,7 @@ def __init__(
gflownet DAG is a tree, and pb is therefore always 1. Must be set
explicitly by user to ensure that pb is an Estimator except under this
special case.
log_reward_clip_min: If finite, clips log rewards to this value.

"""
super().__init__()
@@ -215,11 +218,12 @@ def __init__(
self.pf = pf
self.pb = pb
self.constant_pb = constant_pb
self.log_reward_clip_min = log_reward_clip_min

# Advisory: recurrent PF with non-recurrent PB is unusual
# (tree DAGs typically prefer pb=None with constant_pb=True).
# Import locally to avoid circular imports during module import time.
from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore
from gfn.estimators import RecurrentDiscretePolicyEstimator

if isinstance(self.pf, RecurrentDiscretePolicyEstimator) and isinstance(
self.pb, Estimator
@@ -288,7 +292,7 @@ def pf_pb_parameters(self) -> list[torch.Tensor]:
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]


class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories], ABC):
"""A GFlowNet that operates on complete trajectories.

Attributes:
@@ -297,32 +301,9 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
pb is therefore always 1.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: Estimator,
pb: Estimator | None,
constant_pb: bool = False,
) -> None:
"""Initializes a TrajectoryBasedGFlowNet instance.

Args:
Collaborator:
I'm a bit confused about how you were able to get rid of this, but seems fine!

Collaborator (Author):
It calls super().__init__(...) and does nothing else, so I could remove it.

pf: The forward policy estimator.
pb: The backward policy estimator, or None if the gflownet DAG is a tree,
and pb is therefore always 1.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1. Must be set
explicitly by user to ensure that pb is an Estimator except under this
special case.

"""
super().__init__(
pf,
pb,
constant_pb=constant_pb,
)

def get_pfs_and_pbs(
self,
trajectories: Trajectories,
@@ -388,8 +369,9 @@ def get_scores(
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

log_rewards = trajectories.log_rewards
assert log_rewards is not None

if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

if torch.any(torch.isinf(total_log_pf_trajectories)):
@@ -399,7 +381,6 @@ def get_scores(

assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
assert log_rewards is not None
return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards

def to_training_samples(self, trajectories: Trajectories) -> Trajectories:
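The review exchange above concerns deleting TrajectoryBasedGFlowNet.__init__, which only forwarded its arguments to super().__init__. A small self-contained sketch of why such a constructor is removable: when a subclass defines no __init__, Python resolves the call to the base class through the MRO, so the arguments still reach the base unchanged. The classes below are toy stand-ins, not the library's.

class Base:
    def __init__(self, pf, pb, constant_pb=False, log_reward_clip_min=float("-inf")):
        self.pf, self.pb = pf, pb
        self.constant_pb = constant_pb
        self.log_reward_clip_min = log_reward_clip_min

class Sub(Base):
    # No __init__ here: Base.__init__ is used directly via the MRO.
    pass

obj = Sub("pf", "pb", constant_pb=True, log_reward_clip_min=-100.0)
assert obj.log_reward_clip_min == -100.0 and obj.constant_pb is True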
120 changes: 64 additions & 56 deletions src/gfn/gflownet/detailed_balance.py
@@ -60,11 +60,9 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
flow of the states.
forward_looking: Whether to use the forward-looking GFN loss.
log_reward_clip_min: If finite, clips log rewards to this value.
safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
to avoid numerical instability, otherwise uses -1e38.

Collaborator:
why did you remove this? it's useful.

Collaborator (Author):
  • log_reward_clip_min is in the PFBasedGFlowNet now
  • safe_log_prob_min was not used anywhere.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
@@ -73,9 +71,8 @@ def __init__(
pb: Estimator | None,
logF: ScalarEstimator | ConditionalScalarEstimator,
forward_looking: bool = False,
log_reward_clip_min: float = -float("inf"),
safe_log_prob_min: bool = True,
constant_pb: bool = False,
log_reward_clip_min: float = -float("inf"),
) -> None:
"""Initializes a DBGFlowNet instance.

@@ -86,19 +83,19 @@ def __init__(
logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
flow of the states.
forward_looking: Whether to use the forward-looking GFN loss.
log_reward_clip_min: If finite, clips log rewards to this value.
safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
to avoid numerical instability, otherwise uses -1e38.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1. Must be set
explicitly by user to ensure that pb is an Estimator except under this
special case.
log_reward_clip_min: If finite, clips log rewards to this value.

"""
super().__init__(pf, pb, constant_pb=constant_pb)
super().__init__(
pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
)

# Disallow recurrent PF for transition-based DB
from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore
from gfn.estimators import RecurrentDiscretePolicyEstimator

if isinstance(self.pf, RecurrentDiscretePolicyEstimator):
raise TypeError(
@@ -112,11 +109,6 @@ def __init__(

self.logF = logF
self.forward_looking = forward_looking
self.log_reward_clip_min = log_reward_clip_min
if safe_log_prob_min:
self.log_prob_min = -1e10
else:
self.log_prob_min = -1e38

def logF_named_parameters(self) -> dict[str, torch.Tensor]:
"""Returns a dictionary of named parameters containing 'logF' in their name.
@@ -191,14 +183,15 @@ def get_scores(
if len(states) == 0:
return torch.tensor(0.0, device=transitions.device)

# uncomment next line for debugging
# assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
check_compatibility(states, actions, transitions)
assert (
not transitions.states.is_sink_state.any()
), "Transition from sink state is not allowed. This is a bug."

log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
transitions, recalculate_all_logprobs
)
### Compute log_pf and log_pb
log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs)

### Compute log_F_s
# LogF is potentially a conditional computation.
if transitions.conditions is not None:
with has_conditions_exception_handler("logF", self.logF):
@@ -207,50 +200,65 @@ def get_scores(
with no_conditions_exception_handler("logF", self.logF):
log_F_s = self.logF(states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states)
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
log_F_s = log_F_s + log_rewards

preds = log_pf_actions + log_F_s

# uncomment next line for debugging
# assert transitions.next_states.is_sink_state.equal(transitions.is_terminating)

# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_terminating]
valid_transitions_is_terminating = transitions.is_terminating[
~transitions.states.is_sink_state
]
### Compute log_F_s_next
log_F_s_next = torch.zeros_like(log_F_s)
is_terminating = transitions.is_terminating
is_intermediate = ~is_terminating

if len(valid_next_states) == 0:
return torch.tensor(0.0, device=transitions.device)

# LogF is potentially a conditional computation.
# Assign log_F_s_next for intermediate next states
interm_next_states = transitions.next_states[is_intermediate]
# log_F is potentially a conditional computation.
if transitions.conditions is not None:
with has_conditions_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(
valid_next_states,
transitions.conditions[~transitions.is_terminating],
log_F_s_next[is_intermediate] = self.logF(
interm_next_states,
transitions.conditions[is_intermediate],
).squeeze(-1)
else:
with no_conditions_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)

log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
assert transitions.log_rewards is not None
valid_transitions_log_rewards = transitions.log_rewards[
~transitions.states.is_sink_state
]
log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
valid_transitions_is_terminating
]
targets = log_pb_actions + log_F_s_next
log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1)

scores = preds - targets
# Apply forward-looking if applicable
if self.forward_looking:
import warnings

warnings.warn(
"Rewards should be defined over edges in forward-looking settings. "
"The current implementation is a special case of this, where the edge "
"reward is defined as the difference between the reward of two states "
"that the edge connects. If your environment is not the case, "
"forward-looking may be inappropriate."
)

# Reward calculation can also be conditional.
if transitions.conditions is not None:
log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore
log_rewards_next = env.log_reward(
interm_next_states, transitions.conditions[is_intermediate] # type: ignore
)
else:
log_rewards_state = env.log_reward(states)
log_rewards_next = env.log_reward(interm_next_states)
if math.isfinite(self.log_reward_clip_min):
log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min)
log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min)

log_F_s = log_F_s + log_rewards_state
log_F_s_next[is_intermediate] = (
log_F_s_next[is_intermediate] + log_rewards_next
)

# Assign log_F_s_next for terminating transitions as log_rewards
log_rewards = transitions.log_rewards
assert log_rewards is not None
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
log_F_s_next[is_terminating] = log_rewards[is_terminating]

### Compute scores
preds = log_pf + log_F_s
targets = log_pb + log_F_s_next
scores = preds - targets
assert scores.shape == (transitions.n_transitions,)
return scores

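A small numeric check of the identity the forward-looking warning alludes to: adding log R(s) to log F(s) and log R(s') to log F(s') shifts the detailed-balance residual of an intermediate transition by log R(s) - log R(s'), i.e. by the (negative) edge reward implied by the two state rewards. The values below are arbitrary; only the algebra matters.

import torch

log_pf, log_pb = torch.tensor(-0.7), torch.tensor(-1.1)
log_F_s, log_F_s_next = torch.tensor(0.2), torch.tensor(0.5)
log_R_s, log_R_s_next = torch.tensor(-2.0), torch.tensor(-1.5)

plain = (log_pf + log_F_s) - (log_pb + log_F_s_next)
forward_looking = (log_pf + log_F_s + log_R_s) - (log_pb + log_F_s_next + log_R_s_next)
# The residual changes by exactly the reward difference of the two states.
assert torch.allclose(forward_looking - plain, log_R_s - log_R_s_next)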
5 changes: 3 additions & 2 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -102,15 +102,16 @@ def __init__(
special case.

"""
super().__init__(pf, pb, constant_pb=constant_pb)
super().__init__(
pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
)
assert any(
isinstance(logF, cls)
for cls in [ScalarEstimator, ConditionalScalarEstimator]
), "logF must be a ScalarEstimator or derived"
self.logF = logF
self.weighting = weighting
self.lamda = lamda
self.log_reward_clip_min = log_reward_clip_min
self.forward_looking = forward_looking

def logF_named_parameters(self) -> dict[str, torch.Tensor]:
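Since log_reward_clip_min is now stored by PFBasedGFlowNet and forwarded here via super().__init__, clipping is handled in one place: when the value is finite, log rewards are floored with clamp_min, and the default -inf leaves them untouched. A standalone illustration of that behavior follows; the helper function is not part of the library.

import math
import torch

def clip_log_rewards(log_rewards: torch.Tensor, log_reward_clip_min: float) -> torch.Tensor:
    # Mirrors the pattern used in get_scores: clamp only if the minimum is finite.
    if math.isfinite(log_reward_clip_min):
        return log_rewards.clamp_min(log_reward_clip_min)
    return log_rewards

lr = torch.tensor([-50.0, -3.0, 0.0])
assert torch.equal(clip_log_rewards(lr, float("-inf")), lr)  # no-op by default
assert torch.equal(clip_log_rewards(lr, -10.0), torch.tensor([-10.0, -3.0, 0.0]))  # floor applied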
32 changes: 11 additions & 21 deletions src/gfn/gflownet/trajectory_balance.py
@@ -35,8 +35,10 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
pb is therefore always 1.
logZ: A learnable parameter or a ScalarEstimator instance (for conditional GFNs).
constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb
is therefore always 1. Must be set explicitly by user to ensure that pb
is an Estimator except under this special case.
log_reward_clip_min: If finite, clips log rewards to this value.
constant_pb: Whether the gflownet DAG is a tree, and pb is therefore always 1.
"""

def __init__(
@@ -45,8 +47,8 @@ def __init__(
pb: Estimator | None,
logZ: nn.Parameter | ScalarEstimator | None = None,
init_logZ: float = 0.0,
log_reward_clip_min: float = -float("inf"),
constant_pb: bool = False,
log_reward_clip_min: float = -float("inf"),
):
"""Initializes a TBGFlowNet instance.

@@ -57,15 +59,16 @@ def __init__(
logZ: A learnable parameter or a ScalarEstimator instance (for
conditional GFNs).
init_logZ: The initial value for the logZ parameter (used if logZ is None).
log_reward_clip_min: If finite, clips log rewards to this value.
constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb
is therefore always 1. Must be set explicitly by user to ensure that pb
is an Estimator except under this special case.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
super().__init__(pf, pb, constant_pb=constant_pb)
super().__init__(
pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
)

self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
self.log_reward_clip_min = log_reward_clip_min

def logz_named_parameters(self) -> dict[str, torch.Tensor]:
"""Returns a dictionary of named parameters containing 'logZ' in their name.
@@ -138,25 +141,12 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet):
Attributes:
pf: The forward policy estimator.
pb: The backward policy estimator.
constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb
is therefore always 1. Must be set explicitly by user to ensure that pb
is an Estimator except under this special case.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: Estimator,
pb: Estimator,
log_reward_clip_min: float = -float("inf"),
):
"""Initializes a LogPartitionVarianceGFlowNet instance.

Args:
pf: The forward policy estimator.
pb: The backward policy estimator.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
super().__init__(pf, pb)
self.log_reward_clip_min = log_reward_clip_min

def loss(
self,
env: Env,
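The removed LogPartitionVarianceGFlowNet.__init__ only assigned log_reward_clip_min, which the base class now handles. The loss itself is not part of this diff; as a rough illustration of the idea behind the class name (an assumption, not code from the library), each trajectory's score from get_scores implies an estimate of -log Z, and the loss penalizes the variance of those estimates, which is why no explicit logZ parameter is needed.

import torch

total_log_pf = torch.tensor([-3.1, -2.4, -5.0])  # per-trajectory sums of log P_F
total_log_pb = torch.tensor([-2.0, -1.9, -3.2])  # per-trajectory sums of log P_B
log_rewards = torch.tensor([-1.0, -0.5, -2.5])

scores = total_log_pf - total_log_pb - log_rewards  # as returned by get_scores
loss = ((scores - scores.mean()) ** 2).mean()        # variance of the implied estimates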
4 changes: 3 additions & 1 deletion tutorials/examples/train_hypergrid_gafn.py
@@ -142,7 +142,9 @@ def __init__(
flow_estimator: The flow estimator, required if use_edge_ri is True.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
super().__init__(pf, pb, logZ, init_logZ, log_reward_clip_min)
super().__init__(
pf, pb, logZ, init_logZ, log_reward_clip_min=log_reward_clip_min
)
self.rnd = rnd
self.use_edge_ri = use_edge_ri
if use_edge_ri and flow_estimator is None:
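The tutorial now passes log_reward_clip_min by keyword because constant_pb precedes it in the updated TBGFlowNet signature, so the old fifth positional argument would silently bind to constant_pb. A toy reproduction of that pitfall; old_init and new_init are stand-ins that only mirror the parameter reordering.

def old_init(pf, pb, logZ=None, init_logZ=0.0, log_reward_clip_min=float("-inf"), constant_pb=False):
    return log_reward_clip_min, constant_pb

def new_init(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=float("-inf")):
    return log_reward_clip_min, constant_pb

# The positional call that was fine before now lands in constant_pb:
assert old_init(None, None, None, 0.0, -100.0) == (-100.0, False)
assert new_init(None, None, None, 0.0, -100.0) == (float("-inf"), -100.0)
# Passing the keyword, as the tutorial does above, stays unambiguous:
assert new_init(None, None, None, 0.0, log_reward_clip_min=-100.0) == (-100.0, False)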