Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
57d833d
Update cli_args.py
morgan-heisler Apr 20, 2026
6b60cc1
Update cli_args.py
morgan-heisler Apr 20, 2026
e4b7294
Update functional.py
morgan-heisler Apr 20, 2026
5b65e25
Update functional.py
morgan-heisler Apr 20, 2026
5d37fa0
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
0981cb5
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
b20e844
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
8545230
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
74b5952
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
5b2dbc7
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
646f47a
Update functional.py
morgan-heisler Apr 20, 2026
63f3b7e
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
24f1604
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
fe59591
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
c44a2cf
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
c1e9e81
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
2e094bf
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
d87ae4f
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
e91b6b3
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
5874f3f
Update functional.py
morgan-heisler Apr 20, 2026
0b78a08
Update test_rejection_sampling.py
morgan-heisler Apr 20, 2026
9674738
Update areal/utils/functional/functional.py
morgan-heisler Apr 20, 2026
a050b21
Update functional.py
morgan-heisler Apr 20, 2026
a9595b6
Update functional.py
morgan-heisler Apr 20, 2026
1fa9387
Merge branch 'main' into two-stage-sampling
morgan-heisler Apr 21, 2026
b854527
files changes after running `pre-commit run --all-files`
morgan-heisler Apr 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,7 @@ class RejectionSamplingConfig:
For KL metrics, aggregation is arithmetic.
upper: Upper bound for filtering.
lower: Lower bound for filtering (optional).
token_action: Enables two-stage rejection sampling with a per-token importance-sampling correction ('mask' or 'clamp') (optional).
"""

level: str = field(
Expand Down Expand Up @@ -1309,6 +1310,27 @@ class RejectionSamplingConfig:
"For 'kl_k1' metric: can be used to filter negative KL estimates."
},
)
token_action: str | None = field(
default=None,
metadata={
"help": (
"Enables two-stage Geo-RS + Token-MIS/TIS mode. "
"Only valid when level='sequence' and metric='ratio'. "
"Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio "
"exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). "
"Stage 2: on accepted sequences, apply per-token correction using "
"this action — 'mask' (Token-MIS: zero tokens where per-token "
"ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to "
"[lower, upper]). "
"None disables Stage 2 (pure sequence-level Geo-RS only). "
"Experimental results (PR #1084) show that neither Geo-RS alone "
"nor Token-MIS/TIS alone is stable under severe off-policy drift; "
"the two-stage combination is necessary for both grad_norm and "
"approx_kl stability."
),
"choices": ["mask", "clamp"],
},
)

def __post_init__(self):
"""Validate configuration."""
Expand Down Expand Up @@ -1377,6 +1399,30 @@ def __post_init__(self):
UserWarning,
stacklevel=2,
)
# Validate two-stage (Geo-RS + Token-MIS/TIS) constraints.
if self.token_action is not None:
_VALID_TOKEN_ACTIONS = ("mask", "clamp")
if self.token_action not in _VALID_TOKEN_ACTIONS:
raise ValueError(
f"token_action must be one of {_VALID_TOKEN_ACTIONS} or None, "
f"got '{self.token_action}'"
)
if self.level != "sequence":
raise ValueError(
"token_action (two-stage Geo-RS + Token-MIS/TIS) requires "
f"level='sequence'. Got level='{self.level}'."
)
if self.metric != "ratio":
raise ValueError(
"token_action (two-stage Geo-RS + Token-MIS/TIS) requires "
f"metric='ratio'. Got metric='{self.metric}'."
)
if self.action != "mask":
raise ValueError(
"token_action (two-stage mode) requires action='mask' for "
"the sequence-level stage (hard Geo-RS rejection). "
f"Got action='{self.action}'."
)


@dataclass
Expand Down
52 changes: 52 additions & 0 deletions areal/utils/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def apply_rejection_sampling(
behave_imp_weight = torch.exp(log_ratio)
# Save original weight before any clamping, to compute clamped fraction later.
original_weight = behave_imp_weight
per_token_ratio = behave_imp_weight.clone()

# Step 4: Aggregate and filter
#
Expand Down Expand Up @@ -342,6 +343,25 @@ def apply_rejection_sampling(
),
behave_imp_weight,
)
# ── Stage 2: Token-MIS/TIS (1D packed) ──────────────────────────────
if config.token_action is not None:
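# per_token_ratio holds the unaggregated per-token ratios, cloned before
# Step 4 (behave_imp_weight may already have been modified by Stage 1).
# Shape: [total_tokens] in 1D packed format.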
token_ratio = per_token_ratio
if config.token_action == "mask":
token_oor = token_ratio > config.upper
if config.lower is not None:
token_oor = token_oor | (token_ratio < config.lower)
loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype)
behave_imp_weight = token_ratio * (~token_oor).to(
behave_imp_weight.dtype
)
elif config.token_action == "clamp":
clamp_lower = config.lower if config.lower is not None else 0.0
behave_imp_weight = token_ratio.clamp(
min=clamp_lower, max=config.upper
)
# ── End Stage 2 ──────────────────────────────────────────────────────
else:
# 2D padded format
agg_values = log_ratio if _use_log_agg else metric
Expand Down Expand Up @@ -386,6 +406,38 @@ def apply_rejection_sampling(
),
behave_imp_weight,
)

# ── Stage 2: Token-MIS/TIS on Geo-RS-accepted sequences ─────────────
# Runs only in two-stage mode (config.token_action is not None).
# At this point, loss_mask already has Stage-1-rejected sequences
# zeroed out. We now apply per-token filtering on surviving tokens.
if config.token_action is not None:
# per_token_ratio holds the per-token ratios π_prox / π_behave, cloned
# before Step 4 (behave_imp_weight may already have been modified by Stage 1).
# Shape: [batch, seq_len] in 2D padded format.
token_ratio = per_token_ratio

if config.token_action == "mask":
# Token-MIS: zero out tokens where the per-token ratio exceeds
# upper, and optionally where it falls below lower.
# This suppresses tokens where the current policy has drifted
# far from the behavior policy at the token level.
token_oor = token_ratio > config.upper
if config.lower is not None:
token_oor = token_oor | (token_ratio < config.lower)
loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype)
behave_imp_weight = token_ratio * (~token_oor).to(
behave_imp_weight.dtype
)

elif config.token_action == "clamp":
# Token-TIS: clamp the per-token importance ratio to
# [lower, upper]. All tokens remain in the gradient but
# their contribution is bounded.
clamp_lower = config.lower if config.lower is not None else 0.0
behave_imp_weight = token_ratio.clamp(
min=clamp_lower, max=config.upper
)
# ── End Stage 2 ──────────────────────────────────────────────────────
else:
# Token level
if config.action == "mask":
Expand Down
18 changes: 10 additions & 8 deletions docs/en/cli_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1034,16 +1034,18 @@ Attributes:
For KL metrics, aggregation is arithmetic.
upper: Upper bound for filtering.
lower: Lower bound for filtering (optional).
token_action: Enables two-stage rejection sampling with a per-token importance-sampling correction ('mask' or 'clamp') (optional).
```

| Parameter | Type | Default | Description |
| --------- | ------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` |
| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` |
| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` |
| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` |
| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. |
| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. |
| Parameter | Type | Default | Description |
| -------------- | -------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` |
| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` |
| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` |
| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` |
| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. |
| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. |
| `token_action` | string \| None | `None` | Enables two-stage Geo-RS + Token-MIS/TIS mode. Only valid when level='sequence' and metric='ratio'. Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). Stage 2: on accepted sequences, apply per-token correction using this action — 'mask' (Token-MIS: zero tokens where per-token ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to \[lower, upper\]). None disables Stage 2 (pure sequence-level Geo-RS only). Experimental results (PR #1084) show that neither Geo-RS alone nor Token-MIS/TIS alone is stable under severe off-policy drift; the two-stage combination is necessary for both grad_norm and approx_kl stability. **Choices:** `mask`, `clamp` |

(section-scheduler)=

Expand Down
Loading