63 changes: 55 additions & 8 deletions src/gfn/samplers.py
@@ -421,14 +421,28 @@ def local_search(
- A Trajectories object refined by local search
- A boolean tensor indicating which trajectories were updated
"""
# High-level outline:
# 0) Choose the backtrack length K per trajectory, either directly from
#    `back_steps` or via `back_ratio` of the current length. This determines
#    how much prefix to keep.
# 1) Sample backward trajectories from the terminal states using the backward
#    policy, then reverse them into forward-time to obtain prefix trajectories.
# 2) Extract the junction states at step `n_prevs = L - K - 1`, where L is each
#    trajectory's terminating index.
# 3) Reconstruct forward suffixes with the forward policy, starting from the
#    junction states.
# 4-5) Optionally compute per-step PF/PB log-probabilities for MH acceptance.
# 6) Splice prefix and suffix into candidate new trajectories.
# 7) Accept/reject (MH or greedy by reward); return the candidates and the
#    update mask.
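# A minimal splice sketch (hypothetical lengths: L = 5, K = 2, so the junction
# index is n_prevs = L - K - 1 = 2; primes mark resampled states):
#
#   prev:   s0 -> s1 -> s2 -> s3  -> s4  -> sf   (prefix kept: s0..s2)
#   recon:              s2 -> s3' -> s4' -> sf   (forward policy from s2)
#   new:    s0 -> s1 -> s2 -> s3' -> s4' -> sf   (spliced candidate)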
# TODO: Implement local search for GraphStates.
if issubclass(env.States, GraphStates):
raise NotImplementedError("Local search is not implemented for GraphStates.")

# Ensure PF/PB log-probabilities are computed when MH acceptance is requested.
save_logprobs = save_logprobs or use_metropolis_hastings

# Compute the per-trajectory backtrack length K, i.e. the number of backward
# steps in the sense of https://arxiv.org/abs/2202.01361. When specified via
# `back_ratio`, K is proportional to the previous trajectory length (e.g.,
# back_ratio=0.5 resamples roughly the latter half of each trajectory);
# otherwise the provided `back_steps` is clamped to valid bounds.
if back_steps is None:
assert (
back_ratio is not None and 0 < back_ratio <= 1
@@ -441,6 +455,7 @@ def local_search(
back_steps,
)

# 1) Backward sampling from terminal states (PolicyMixin-driven Sampler).
prev_trajectories = self.backward_sampler.sample_trajectories(
env,
states=trajectories.terminating_states,
@@ -454,16 +469,20 @@ def local_search(
# This is called `prev_trajectories` since they are the trajectories before
# the local search. The `new_trajectories` will be obtained by performing local
# search on them.
# Convert backward trajectories to forward-time ordering (s0 -> ... -> sf).
prev_trajectories = prev_trajectories.reverse_backward_trajectories()
assert prev_trajectories.log_rewards is not None

# 2) Derive junction positions and gather the junction states (one per
# trajectory).
n_prevs = prev_trajectories.terminating_idx - K - 1
junction_states_tsr = torch.gather(
prev_trajectories.states.tensor,
0,
(n_prevs).view(1, -1, 1).expand(-1, -1, *trajectories.states.state_shape),
).squeeze(0)
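# Shape sketch (assumed shapes): `prev_trajectories.states.tensor` is
# (T+1, B, *state_shape), and the index `n_prevs.view(1, -1, 1)` broadcasts to
# (1, B, *state_shape), so the gather picks, for each batch element b, the
# state at time n_prevs[b]; the squeeze then yields (B, *state_shape).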
# 3) Reconstruct forward suffixes starting from the junction states using
# the forward policy estimator owned by `self`.
recon_trajectories = super().sample_trajectories(
env,
states=env.states_from_tensor(junction_states_tsr),
@@ -474,6 +493,7 @@
)

# 4) Compute PF on the prefix and the reconstructed suffix (needed for MH
# acceptance or when log-probabilities are requested).
prev_trajectories_log_pf = (
get_trajectory_pfs(pf=self.estimator, trajectories=prev_trajectories)
if save_logprobs
@@ -484,6 +504,7 @@
if save_logprobs
else None
)
# 5) PB on prefix and reconstructed suffix (needed only for MH acceptance).
prev_trajectories_log_pb = (
get_trajectory_pbs(
pb=self.backward_sampler.estimator, trajectories=prev_trajectories
@@ -499,6 +520,7 @@ def local_search(
else None
)
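# Both helpers are assumed to return time-major per-step log-probabilities of
# shape (max_len, batch); the splicing step below applies the same masks to
# these tensors as to the states/actions, and the MH step sums over dim 0.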

# 6) Splice prefix and suffix into candidate trajectories (and aligned PF/PB).
(
new_trajectories,
new_trajectories_log_pf,
@@ -514,6 +536,9 @@ def local_search(
debug=debug,
)

# 7) Accept/reject. With MH, accept with probability:
#   min{1, R(x') p_B(x->s') p_F(s'->x') / [R(x) p_B(x'->s') p_F(s'->x)]}.
# Without MH, accept when the episodic reward improves (ties accepted).
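# Worked example (illustrative numbers): if log R(x') - log R(x) = 0.7 and the
# log p_B + log p_F terms sum to -0.2, then log_accept_ratio =
# clamp_max(0.7 - 0.2, 0) = 0, so the candidate is accepted with probability
# exp(0) = 1.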
if use_metropolis_hastings:
assert (
prev_trajectories_log_pb is not None
@@ -529,6 +554,8 @@ def local_search(
# = p_B(x->s')p_F(s'->x') / p_B(x'->s')p_F(s'->x)
# = p_B(x->s'->s0)p_F(s0->s'->x') / p_B(x'->s'->s0)p_F(s0->s'->x)
# = p_B(tau|x)p_F(tau') / p_B(tau'|x')p_F(tau)
# Combine the log-reward difference with the PF/PB log-prob sums and clamp
# at 0 (equivalently, taking min with 1 in probability space).
log_accept_ratio = torch.clamp_max(
new_trajectories.log_rewards
+ prev_trajectories_log_pb.sum(0)
@@ -601,6 +628,9 @@ def sample_trajectories(
The final trajectories container contains both the initial trajectories
and the improved trajectories from local search.
"""
# Roll out an initial batch with the forward policy, then perform
# `n_local_search_loops` rounds of refinement. Each round appends
# one candidate per original seed trajectory to the container.
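# For example (hypothetical numbers): with n = 4 seeds and
# n_local_search_loops = 2, the returned container holds 4 + 2 * 4 = 12
# trajectories; indices 0-3 are the initial batch, and each round appends its
# candidates as the next contiguous block of 4.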
trajectories = super().sample_trajectories(
env,
n,
@@ -614,10 +644,12 @@
if n is None:
n = int(trajectories.n_trajectories)

# Indices referring to the current seed trajectories within the container.
# Initially these are the first `n` entries (the initial batch).
search_indices = torch.arange(n, device=trajectories.states.device)

for it in range(n_local_search_loops):
# Run a single local-search refinement round on the current seeds.
ls_trajectories, is_updated = self.local_search(
env,
trajectories[search_indices],
@@ -629,8 +661,11 @@
use_metropolis_hastings,
**policy_kwargs,
)
# Append refined candidates; they occupy a new contiguous block at the end.
trajectories.extend(ls_trajectories)

# Map accepted seeds to the indices of the just-appended block so that
# the next round uses the latest accepted candidates as seeds.
last_indices = torch.arange(
n * (it + 1), n * (it + 2), device=trajectories.states.device
)
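# At it = 0, for instance, the block just appended by `extend` occupies
# indices n .. 2n - 1, so accepted seeds now point into that block.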
@@ -683,6 +718,10 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
trajectory segments. The debug mode compares the vectorized approach
with a for-loop implementation to ensure correctness.
"""
# Goal: splice each trajectory's prefix (from backward sampling, now forward-ordered)
# with its reconstructed suffix (from the forward policy), starting at the
# junction step `n_prevs[i]`. We mirror this splice on PF/PB per-step tensors
# when they are provided.
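# Length bookkeeping (toy example): if n_prevs[i] = 2 and the reconstructed
# suffix has 3 actions, spliced trajectory i has 2 + 3 = 5 actions and 6
# states, padded with `sf` up to the longest spliced trajectory in the batch.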
new_trajectories_log_pf = None
new_trajectories_log_pb = None

@@ -691,6 +730,8 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
env = prev_trajectories.env

# Determine per-trajectory prefix lengths (n_prevs) and suffix lengths
# (n_recons), plus their maxima, to size the output tensors.
max_n_prev = n_prevs.max()
n_recons = recon_trajectories.terminating_idx
max_n_recon = n_recons.max()
@@ -700,6 +741,10 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
max_traj_len = int(new_trajectories_dones.max().item())

# Build helper indices and masks over (time, batch). `prev_mask` selects the
# prefix region; `state_recon_mask`/`action_recon_mask` select the suffix
# region for states/actions respectively, while the `*_mask2` variants index
# into the recon tensors without the prefix offset.
idx = torch.arange(max_traj_len + 1).unsqueeze(1).expand(-1, bs).to(n_prevs)

prev_mask = idx < n_prevs
@@ -708,7 +753,7 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
action_recon_mask = (idx[:-1] >= n_prevs) * (idx[:-1] <= n_prevs + n_recons - 1)
action_recon_mask2 = idx[:max_n_recon] <= n_recons - 1
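# Toy example (single trajectory, n_prev = 2, n_recon = 3): actions 0-1 are
# taken from the prefix and actions 2-4 from the suffix; states 0-1 come from
# the prefix and states 2-5 from the suffix, whose first entry is the junction
# state itself.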

# Transpose to (batch, time, ...) for efficient advanced indexing.
prev_trajectories_states_tsr = prev_trajectories.states.tensor.transpose(0, 1)
prev_trajectories_actions_tsr = prev_trajectories.actions.tensor.transpose(0, 1)
recon_trajectories_states_tsr = recon_trajectories.states.tensor.transpose(0, 1)
@@ -721,16 +766,16 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
action_recon_mask = action_recon_mask.transpose(0, 1)
action_recon_mask2 = action_recon_mask2.transpose(0, 1)

# Prepare output state/action tensors in transposed shapes, initialized to
# sink/dummy values; the prefix and suffix segments are filled in below.
new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to(
prev_trajectories.states.tensor
)
new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to(
prev_trajectories.actions.tensor
)

# Assign the prefix segment (backtracked with the backward policy) from
# `prev_trajectories` using `prev_mask`.
prev_mask_truc = prev_mask[:, :max_n_prev]
new_trajectories_states_tsr[prev_mask] = prev_trajectories_states_tsr[
:, :max_n_prev
@@ -751,7 +796,7 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
new_trajectories_states_tsr = new_trajectories_states_tsr.transpose(0, 1)
new_trajectories_actions_tsr = new_trajectories_actions_tsr.transpose(0, 1)

# Similarly, splice PF/PB per-step log-probabilities if they were provided.
if (
prev_trajectories_log_pf is not None
and recon_trajectories_log_pf is not None
@@ -823,15 +868,15 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
for i in range(bs):
_n_prev = n_prevs[i]

# Backward part (prefix)
_new_trajectories_states_tsr[: _n_prev + 1, i] = (
prev_trajectories.states.tensor[: _n_prev + 1, i]
)
_new_trajectories_actions_tsr[:_n_prev, i] = (
prev_trajectories.actions.tensor[:_n_prev, i]
)

# Forward part (suffix)
_len_recon = recon_trajectories.terminating_idx[i]
_new_trajectories_states_tsr[
_n_prev + 1 : _n_prev + _len_recon + 1, i
@@ -876,6 +921,8 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
):
assert torch.all(_new_trajectories_log_pb == new_trajectories_log_pb)

# Materialize the spliced trajectories container (forward-time), carrying over
# the conditioning from the prefix and the episodic reward from the
# reconstructed suffix.
new_trajectories = Trajectories(
env=env,
states=env.states_from_tensor(new_trajectories_states_tsr),