Skip to content
Draft
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
57d833d
Update cli_args.py
morgan-heisler Apr 20, 2026
6b60cc1
Update cli_args.py
morgan-heisler Apr 20, 2026
e4b7294
Update functional.py
morgan-heisler Apr 20, 2026
5b65e25
Update functional.py
morgan-heisler Apr 20, 2026
5d37fa0
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
0981cb5
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
b20e844
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
8545230
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
74b5952
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
5b2dbc7
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
646f47a
Update functional.py
morgan-heisler Apr 20, 2026
63f3b7e
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
24f1604
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
fe59591
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
c44a2cf
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
c1e9e81
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
2e094bf
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
d87ae4f
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
e91b6b3
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
5874f3f
Update functional.py
morgan-heisler Apr 20, 2026
0b78a08
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
9674738
Update areal/utils/functional/functional.py
morgan-heisler Apr 20, 2026
a050b21
Update functional.py
morgan-heisler Apr 20, 2026
a9595b6
Update functional.py
morgan-heisler Apr 20, 2026
1fa9387
Merge branch 'main' into two-stage-sampling
morgan-heisler Apr 21, 2026
b854527
files changes after running `pre-commit run --all-files`
morgan-heisler Apr 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,7 @@ class RejectionSamplingConfig:
For KL metrics, aggregation is arithmetic.
upper: Upper bound for filtering.
lower: Lower bound for filtering (optional).
token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp') (optional).
"""

level: str = field(
Expand Down Expand Up @@ -1309,6 +1310,27 @@ class RejectionSamplingConfig:
"For 'kl_k1' metric: can be used to filter negative KL estimates."
},
)
token_action: str | None = field(
default=None,
metadata={
"help": (
"Enables two-stage Geo-RS + Token-MIS/TIS mode. "
"Only valid when level='sequence' and metric='ratio'. "
"Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio "
"exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). "
"Stage 2: on accepted sequences, apply per-token correction using "
"this action — 'mask' (Token-MIS: zero tokens where per-token "
"ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to "
"[lower, upper]). "
"None disables Stage 2 (pure sequence-level Geo-RS only). "
"Experimental results (PR #1084) show that neither Geo-RS alone "
"nor Token-MIS/TIS alone is stable under severe off-policy drift; "
"the two-stage combination is necessary for both grad_norm and "
"approx_kl stability."
),
"choices": ["mask", "clamp"],
},
)

def __post_init__(self):
"""Validate configuration."""
Expand Down Expand Up @@ -1377,6 +1399,30 @@ def __post_init__(self):
UserWarning,
stacklevel=2,
)
# Validate two-stage (Geo-RS + Token-MIS/TIS) constraints.
if self.token_action is not None:
_VALID_TOKEN_ACTIONS = ("mask", "clamp")
if self.token_action not in _VALID_TOKEN_ACTIONS:
raise ValueError(
f"token_action must be one of {_VALID_TOKEN_ACTIONS} or None, "
f"got '{self.token_action}'"
)
if self.level != "sequence":
raise ValueError(
"token_action (two-stage Geo-RS + Token-MIS/TIS) requires "
f"level='sequence'. Got level='{self.level}'."
)
if self.metric != "ratio":
raise ValueError(
"token_action (two-stage Geo-RS + Token-MIS/TIS) requires "
f"metric='ratio'. Got metric='{self.metric}'."
)
if self.action != "mask":
raise ValueError(
"token_action (two-stage mode) requires action='mask' for "
"the sequence-level stage (hard Geo-RS rejection). "
f"Got action='{self.action}'."
)


@dataclass
Expand Down
53 changes: 53 additions & 0 deletions areal/utils/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,26 @@ def apply_rejection_sampling(
),
behave_imp_weight,
)
# ── Stage 2: Token-MIS/TIS (1D packed) ──────────────────────────────
if config.token_action is not None:
# behave_imp_weight holds per-token ratios.
# Shape: [total_tokens] in 1D packed format.
# token_ratio = behave_imp_weight
token_ratio = torch.exp(log_ratio)
Comment thread
morgan-heisler marked this conversation as resolved.
Outdated
if config.token_action == "mask":
token_oor = token_ratio > config.upper
if config.lower is not None:
token_oor = token_oor | (token_ratio < config.lower)
loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype)
behave_imp_weight = behave_imp_weight * (~token_oor).to(
behave_imp_weight.dtype
)
Comment thread
morgan-heisler marked this conversation as resolved.
Outdated
elif config.token_action == "clamp":
clamp_lower = config.lower if config.lower is not None else 0.0
behave_imp_weight = token_ratio.clamp(
min=clamp_lower, max=config.upper
)
# ── End Stage 2 ──────────────────────────────────────────────────────
else:
# 2D padded format
agg_values = log_ratio if _use_log_agg else metric
Expand Down Expand Up @@ -386,6 +406,39 @@ def apply_rejection_sampling(
),
behave_imp_weight,
)

# ── Stage 2: Token-MIS/TIS on Geo-RS-accepted sequences ─────────────
# Runs only in two-stage mode (config.token_action is not None).
# At this point, loss_mask already has Stage-1-rejected sequences
# zeroed out. We now apply per-token filtering on surviving tokens.
if config.token_action is not None:
# behave_imp_weight holds per-token ratios π_prox / π_behave.
# Shape: [batch, seq_len] in 2D padded format.
# token_ratio = behave_imp_weight
token_ratio = torch.exp(log_ratio)

if config.token_action == "mask":
# Token-MIS: zero out tokens where the per-token ratio exceeds
# upper, and optionally where it falls below lower.
# This suppresses tokens where the current policy has drifted
# far from the behavior policy at the token level.
token_oor = token_ratio > config.upper
if config.lower is not None:
token_oor = token_oor | (token_ratio < config.lower)
loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype)
behave_imp_weight = behave_imp_weight * (~token_oor).to(
behave_imp_weight.dtype
)

elif config.token_action == "clamp":
# Token-TIS: clamp the per-token importance ratio to
# [lower, upper]. All tokens remain in the gradient but
# their contribution is bounded.
clamp_lower = config.lower if config.lower is not None else 0.0
behave_imp_weight = token_ratio.clamp(
min=clamp_lower, max=config.upper
)
# ── End Stage 2 ──────────────────────────────────────────────────────
else:
# Token level
if config.action == "mask":
Expand Down
Loading