Merged
Changes from 5 commits
2 changes: 1 addition & 1 deletion src/gfn/containers/transitions.py
@@ -194,7 +194,7 @@ def log_rewards(self) -> torch.Tensor | None:
If not provided at initialization, log rewards are computed on demand for
terminating transitions.
"""
if self.is_backward:
if self.is_backward: # TODO: Why can't backward trajectories have log_rewards?
return None

if self._log_rewards is None:
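For context on the hunk above, here is a minimal sketch of what a lazily computed log-reward property of this kind typically looks like. The names (log_reward_fn, terminating_states) are assumptions for illustration, not the library's exact API:

import torch

class LazyLogRewards:
    """Illustrative sketch only: log rewards are undefined for backward transitions
    and computed on demand for terminating ones when not supplied at init."""

    def __init__(self, log_reward_fn, terminating_states: torch.Tensor,
                 is_backward: bool, log_rewards: torch.Tensor | None = None):
        self.log_reward_fn = log_reward_fn          # e.g. an environment's log_reward
        self.terminating_states = terminating_states
        self.is_backward = is_backward
        self._log_rewards = log_rewards             # optional precomputed cache

    @property
    def log_rewards(self) -> torch.Tensor | None:
        if self.is_backward:
            # No forward-terminating states, so no reward to attach
            # (this is the behaviour the TODO in the diff asks about).
            return None
        if self._log_rewards is None:
            # Computed on demand for terminating transitions.
            self._log_rewards = self.log_reward_fn(self.terminating_states)
        return self._log_rewards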
35 changes: 8 additions & 27 deletions src/gfn/gflownet/base.py
@@ -168,13 +168,15 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC):
pb: The backward policy estimator, or None if it can be ignored (e.g., the
gflownet DAG is a tree, and pb is therefore always 1).
constant_pb: Whether to ignore the backward policy estimator.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: Estimator,
pb: Estimator | None,
constant_pb: bool = False,
log_reward_clip_min: float = float("-inf"),
) -> None:
"""Initializes a PFBasedGFlowNet instance.

@@ -186,6 +188,7 @@ def __init__(
gflownet DAG is a tree, and pb is therefore always 1. Must be set
explicitly by user to ensure that pb is an Estimator except under this
special case.
log_reward_clip_min: If finite, clips log rewards to this value.

"""
super().__init__()
@@ -215,6 +218,7 @@ def __init__(
self.pf = pf
self.pb = pb
self.constant_pb = constant_pb
self.log_reward_clip_min = log_reward_clip_min

# Advisory: recurrent PF with non-recurrent PB is unusual
# (tree DAGs typically prefer pb=None with constant_pb=True).
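A note on the new log_reward_clip_min attribute hoisted into this base class: "if finite, clips log rewards to this value" amounts to a clamp_min applied only when the threshold is finite. A minimal standalone sketch of that behaviour (an illustrative helper, not the library's method):

import math
import torch

def clip_log_rewards(log_rewards: torch.Tensor, log_reward_clip_min: float) -> torch.Tensor:
    # clamp_min with -inf would be a no-op, so the finiteness check simply skips it.
    if math.isfinite(log_reward_clip_min):
        return log_rewards.clamp_min(log_reward_clip_min)
    return log_rewards

# Example: very small log rewards are floored to stabilize the loss.
clip_log_rewards(torch.tensor([-100.0, -2.0, 0.5]), -10.0)  # tensor([-10., -2., 0.5])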
@@ -288,7 +292,7 @@ def pf_pb_parameters(self) -> list[torch.Tensor]:
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]


class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories], ABC):
"""A GFlowNet that operates on complete trajectories.

Attributes:
@@ -297,32 +301,9 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
pb is therefore always 1.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: Estimator,
pb: Estimator | None,
constant_pb: bool = False,
) -> None:
"""Initializes a TrajectoryBasedGFlowNet instance.

Args:
Collaborator: I'm a bit confused about how you were able to get rid of this, but seems fine!

Collaborator (Author): It calls super().__init__(...) and does nothing else, so I could remove it.

pf: The forward policy estimator.
pb: The backward policy estimator, or None if the gflownet DAG is a tree,
and pb is therefore always 1.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1. Must be set
explicitly by user to ensure that pb is an Estimator except under this
special case.

"""
super().__init__(
pf,
pb,
constant_pb=constant_pb,
)
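The deletion above relies on a general Python rule rather than anything gfn-specific: a subclass without its own __init__ inherits the parent's constructor, so a constructor whose only statement is super().__init__(...) with the same arguments can be removed without changing behaviour. A toy illustration (not library code):

class Base:
    def __init__(self, pf, pb=None, constant_pb=False):
        self.pf, self.pb, self.constant_pb = pf, pb, constant_pb

class Child(Base):
    # No __init__ here: Child("pf", constant_pb=True) resolves to Base.__init__,
    # which is exactly what the deleted pass-through constructor was doing.
    pass

assert Child("pf", constant_pb=True).constant_pb is True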

def get_pfs_and_pbs(
self,
trajectories: Trajectories,
@@ -388,8 +369,9 @@ def get_scores(
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

log_rewards = trajectories.log_rewards
assert log_rewards is not None

if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

if torch.any(torch.isinf(total_log_pf_trajectories)):
@@ -399,7 +381,6 @@

assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
assert log_rewards is not None
return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards
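For readers skimming the hunk above, the quantity returned by get_scores is the per-trajectory log-ratio consumed by the trajectory-based losses defined in this file's subclasses: log P_F(tau) - log P_B(tau) - log R(x), where x is the trajectory's terminal state. A small numerical sketch with made-up tensors (illustrative values, not library calls):

import torch

# Summed log-probabilities and log-rewards for a batch of 3 trajectories.
total_log_pf = torch.tensor([-4.0, -3.5, -5.0])   # sum_t log P_F(s_{t+1} | s_t)
total_log_pb = torch.tensor([-3.0, -3.0, -4.0])   # sum_t log P_B(s_t | s_{t+1})
log_rewards = torch.tensor([-1.0, -0.5, -2.0])    # log R(x) at the terminal states

scores = total_log_pf - total_log_pb - log_rewards  # shape (3,), one score per trajectory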

def to_training_samples(self, trajectories: Trajectories) -> Trajectories:
118 changes: 63 additions & 55 deletions src/gfn/gflownet/detailed_balance.py
@@ -60,11 +60,9 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]):
logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
flow of the states.
forward_looking: Whether to use the forward-looking GFN loss.
log_reward_clip_min: If finite, clips log rewards to this value.
safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
to avoid numerical instability, otherwise uses -1e38.

Collaborator: Why did you remove this? It's useful.

Collaborator (Author):
  • log_reward_clip_min is in the PFBasedGFlowNet now.
  • safe_log_prob_min was not used anywhere.

constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
@@ -73,9 +71,8 @@ def __init__(
pb: Estimator | None,
logF: ScalarEstimator | ConditionalScalarEstimator,
forward_looking: bool = False,
log_reward_clip_min: float = -float("inf"),
safe_log_prob_min: bool = True,
constant_pb: bool = False,
log_reward_clip_min: float = -float("inf"),
) -> None:
"""Initializes a DBGFlowNet instance.

@@ -86,16 +83,16 @@ def __init__(
logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log
flow of the states.
forward_looking: Whether to use the forward-looking GFN loss.
log_reward_clip_min: If finite, clips log rewards to this value.
safe_log_prob_min: If True, uses -1e10 as the minimum log probability value
to avoid numerical instability, otherwise uses -1e38.
constant_pb: Whether to ignore the backward policy estimator, e.g., if the
gflownet DAG is a tree, and pb is therefore always 1. Must be set
explicitly by user to ensure that pb is an Estimator except under this
special case.
log_reward_clip_min: If finite, clips log rewards to this value.

"""
super().__init__(pf, pb, constant_pb=constant_pb)
super().__init__(
pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
)

# Disallow recurrent PF for transition-based DB
from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore
@@ -112,11 +109,6 @@ def __init__(

self.logF = logF
self.forward_looking = forward_looking
self.log_reward_clip_min = log_reward_clip_min
if safe_log_prob_min:
self.log_prob_min = -1e10
else:
self.log_prob_min = -1e38

def logF_named_parameters(self) -> dict[str, torch.Tensor]:
"""Returns a dictionary of named parameters containing 'logF' in their name.
@@ -191,14 +183,15 @@ def get_scores(
if len(states) == 0:
return torch.tensor(0.0, device=transitions.device)

# uncomment next line for debugging
# assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
check_compatibility(states, actions, transitions)
assert (
not transitions.states.is_sink_state.any()
), "Transition from sink state is not allowed. This is a bug."

log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
transitions, recalculate_all_logprobs
)
### Compute log_pf and log_pb
log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs)

### Compute log_F_s
# LogF is potentially a conditional computation.
if transitions.conditions is not None:
with has_conditions_exception_handler("logF", self.logF):
@@ -207,50 +200,65 @@ def get_scores(
with no_conditions_exception_handler("logF", self.logF):
log_F_s = self.logF(states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states)
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
log_F_s = log_F_s + log_rewards

preds = log_pf_actions + log_F_s

# uncomment next line for debugging
# assert transitions.next_states.is_sink_state.equal(transitions.is_terminating)

# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_terminating]
valid_transitions_is_terminating = transitions.is_terminating[
~transitions.states.is_sink_state
]
### Compute log_F_s_next
log_F_s_next = torch.zeros_like(log_F_s)
is_terminating = transitions.is_terminating
is_intermediate = ~is_terminating

if len(valid_next_states) == 0:
return torch.tensor(0.0, device=transitions.device)

# LogF is potentially a conditional computation.
# Assign log_F_s_next for intermediate next states
interm_next_states = transitions.next_states[is_intermediate]
# log_F is potentially a conditional computation.
if transitions.conditions is not None:
with has_conditions_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(
valid_next_states,
transitions.conditions[~transitions.is_terminating],
log_F_s_next[is_intermediate] = self.logF(
interm_next_states,
transitions.conditions[is_intermediate],
).squeeze(-1)
else:
with no_conditions_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)

log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~valid_transitions_is_terminating] = valid_log_F_s_next
assert transitions.log_rewards is not None
valid_transitions_log_rewards = transitions.log_rewards[
~transitions.states.is_sink_state
]
log_F_s_next[valid_transitions_is_terminating] = valid_transitions_log_rewards[
valid_transitions_is_terminating
]
targets = log_pb_actions + log_F_s_next
log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1)

scores = preds - targets
# Apply forward-looking if applicable
if self.forward_looking:
import warnings # type: ignore
Collaborator: logging should have typing

Collaborator (Author): What do you mean?

Collaborator: I'm not sure I understand either, but why is type: ignore required for an import?

Collaborator (Author): I thought it was required to avoid a Python error, but it turned out to be okay without it. Fixed.


warnings.warn(
"Rewards should be defined over edges in forward-looking settings. "
"The current implementation is a special case of this, where the edge "
"reward is defined as the difference between the reward of two states "
"that the edge connects. If your environment is not the case, "
"forward-looking may be inappropriate."
)

# Reward calculation can also be conditional.
if transitions.conditions is not None:
log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore
log_rewards_next = env.log_reward(
interm_next_states, transitions.conditions[is_intermediate] # type: ignore
)
else:
log_rewards_state = env.log_reward(states)
log_rewards_next = env.log_reward(interm_next_states)
if math.isfinite(self.log_reward_clip_min):
log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min)
log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min)

log_F_s = log_F_s + log_rewards_state
log_F_s_next[is_intermediate] = (
log_F_s_next[is_intermediate] + log_rewards_next
)

# Assign log_F_s_next for terminating transitions as log_rewards
log_rewards = transitions.log_rewards
assert log_rewards is not None
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
log_F_s_next[is_terminating] = log_rewards[is_terminating]

### Compute scores
preds = log_pf + log_F_s
targets = log_pb + log_F_s_next
scores = preds - targets
assert scores.shape == (transitions.n_transitions,)
return scores
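To make the refactored masks easier to follow: detailed balance matches F(s)·P_F(s'|s) against F(s')·P_B(s|s'), with log R(x) standing in for log F(s') when the transition terminates. A compact sketch of that score computation, assuming precomputed tensors of matching shapes (illustrative, not the library's internals):

import torch

def db_scores(log_pf, log_pb, log_F_s, log_F_s_next_interm, log_rewards, is_terminating):
    """log[F(s) P_F(s'|s)] - log[F(s') P_B(s|s')], with log R(x) for terminating edges.

    log_F_s_next_interm holds logF evaluated at the intermediate next states only.
    """
    log_F_s_next = torch.zeros_like(log_F_s)
    is_intermediate = ~is_terminating
    log_F_s_next[is_intermediate] = log_F_s_next_interm          # logF at s'
    log_F_s_next[is_terminating] = log_rewards[is_terminating]   # terminal: use log R(x)
    return (log_pf + log_F_s) - (log_pb + log_F_s_next)

The forward-looking branch in the hunk adds the state log-reward to both log F(s) and the intermediate log F(s'), which is exactly the "edge reward as a difference of state rewards" special case the new warning describes.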

5 changes: 3 additions & 2 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -102,15 +102,16 @@ def __init__(
special case.

"""
super().__init__(pf, pb, constant_pb=constant_pb)
super().__init__(
pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
)
assert any(
isinstance(logF, cls)
for cls in [ScalarEstimator, ConditionalScalarEstimator]
), "logF must be a ScalarEstimator or derived"
self.logF = logF
self.weighting = weighting
self.lamda = lamda
self.log_reward_clip_min = log_reward_clip_min
self.forward_looking = forward_looking

def logF_named_parameters(self) -> dict[str, torch.Tensor]:
32 changes: 11 additions & 21 deletions src/gfn/gflownet/trajectory_balance.py
@@ -35,8 +35,10 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
pb: The backward policy estimator, or None if the gflownet DAG is a tree, and
pb is therefore always 1.
logZ: A learnable parameter or a ScalarEstimator instance (for conditional GFNs).
constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb
is therefore always 1. Must be set explicitly by user to ensure that pb
is an Estimator except under this special case.
log_reward_clip_min: If finite, clips log rewards to this value.
constant_pb: Whether the gflownet DAG is a tree, and pb is therefore always 1.
"""

def __init__(
Expand All @@ -45,8 +47,8 @@ def __init__(
pb: Estimator | None,
logZ: nn.Parameter | ScalarEstimator | None = None,
init_logZ: float = 0.0,
log_reward_clip_min: float = -float("inf"),
constant_pb: bool = False,
log_reward_clip_min: float = -float("inf"),
):
"""Initializes a TBGFlowNet instance.

@@ -57,15 +59,16 @@ def __init__(
logZ: A learnable parameter or a ScalarEstimator instance (for
conditional GFNs).
init_logZ: The initial value for the logZ parameter (used if logZ is None).
log_reward_clip_min: If finite, clips log rewards to this value.
constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb
is therefore always 1. Must be set explicitly by user to ensure that pb
is an Estimator except under this special case.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
super().__init__(pf, pb, constant_pb=constant_pb)
super().__init__(
pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min
)

self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
self.log_reward_clip_min = log_reward_clip_min

def logz_named_parameters(self) -> dict[str, torch.Tensor]:
"""Returns a dictionary of named parameters containing 'logZ' in their name.
@@ -138,25 +141,12 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet):
Attributes:
pf: The forward policy estimator.
pb: The backward policy estimator.
constant_pb: Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb
is therefore always 1. Must be set explicitly by user to ensure that pb
is an Estimator except under this special case.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: Estimator,
pb: Estimator,
log_reward_clip_min: float = -float("inf"),
):
"""Initializes a LogPartitionVarianceGFlowNet instance.

Args:
pf: The forward policy estimator.
pb: The backward policy estimator.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
super().__init__(pf, pb)
self.log_reward_clip_min = log_reward_clip_min

def loss(
self,
env: Env,
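Both classes in this file consume the trajectory scores returned by get_scores in base.py; a simplified reminder of how they differ, using assumed tensor inputs rather than the actual estimators:

import torch

scores = torch.tensor([-0.3, 0.1, 0.4])  # log P_F(tau) - log P_B(tau) - log R(x), per trajectory

# TBGFlowNet: learn logZ so that logZ + score is driven to 0 for every trajectory.
logZ = torch.nn.Parameter(torch.tensor(0.0))
tb_loss = ((logZ + scores) ** 2).mean()

# LogPartitionVarianceGFlowNet: no explicit logZ; penalize the variance of the
# implied per-trajectory log-partition estimates instead.
lpv_loss = ((scores - scores.mean()) ** 2).mean()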
4 changes: 3 additions & 1 deletion tutorials/examples/train_hypergrid_gafn.py
@@ -142,7 +142,9 @@ def __init__(
flow_estimator: The flow estimator, required if use_edge_ri is True.
log_reward_clip_min: If finite, clips log rewards to this value.
"""
super().__init__(pf, pb, logZ, init_logZ, log_reward_clip_min)
super().__init__(
pf, pb, logZ, init_logZ, log_reward_clip_min=log_reward_clip_min
)
self.rnd = rnd
self.use_edge_ri = use_edge_ri
if use_edge_ri and flow_estimator is None:
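A note on why this call site switches to a keyword argument: the PR reorders the TBGFlowNet constructor so that constant_pb now precedes log_reward_clip_min, so the old fifth positional argument would silently bind to constant_pb. A minimal illustration of the failure mode, using hypothetical functions that mirror the old and new parameter orders:

def old_init(pf, pb, logZ=None, init_logZ=0.0, log_reward_clip_min=float("-inf"), constant_pb=False):
    return {"log_reward_clip_min": log_reward_clip_min, "constant_pb": constant_pb}

def new_init(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=float("-inf")):
    return {"log_reward_clip_min": log_reward_clip_min, "constant_pb": constant_pb}

# Old positional call: fine before the reorder, wrong after it.
old_init("pf", "pb", None, 0.0, -10.0)                      # clip min = -10.0
new_init("pf", "pb", None, 0.0, -10.0)                      # -10.0 lands in constant_pb!
new_init("pf", "pb", None, 0.0, log_reward_clip_min=-10.0)  # keyword keeps it correct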