From 57d833d2e8608869bae50ddbeff54991f512c02a Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:09:39 -0700 Subject: [PATCH 01/25] Update cli_args.py --- areal/api/cli_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index b0bcc4c75d..69c3be5e3d 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 + import argparse import json import os From 6b60cc165cb57918323f7566464cd5fdc7f7cd5a Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:10:04 -0700 Subject: [PATCH 02/25] Update cli_args.py --- areal/api/cli_args.py | 47 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 69c3be5e3d..2b416edd23 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - import argparse import json import os @@ -1242,6 +1241,7 @@ class RejectionSamplingConfig: For KL metrics, aggregation is arithmetic. upper: Upper bound for filtering. lower: Lower bound for filtering (optional). + token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp')(optional). """ level: str = field( @@ -1310,6 +1310,27 @@ class RejectionSamplingConfig: "For 'kl_k1' metric: can be used to filter negative KL estimates." }, ) + token_action: str | None = field( + default=None, + metadata={ + "help": ( + "Enables two-stage Geo-RS + Token-MIS/TIS mode. " + "Only valid when level='sequence' and metric='ratio'. " + "Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio " + "exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). 
" + "Stage 2: on accepted sequences, apply per-token correction using " + "this action — 'mask' (Token-MIS: zero tokens where per-token " + "ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to " + "[lower, upper]). " + "None disables Stage 2 (pure sequence-level Geo-RS only). " + "Experimental results (PR #1084) show that neither Geo-RS alone " + "nor Token-MIS/TIS alone is stable under severe off-policy drift; " + "the two-stage combination is necessary for both grad_norm and " + "approx_kl stability." + ), + "choices": ["mask", "clamp"], + }, + ) def __post_init__(self): """Validate configuration.""" @@ -1378,6 +1399,30 @@ def __post_init__(self): UserWarning, stacklevel=2, ) + # Validate two-stage (Geo-RS + Token-MIS/TIS) constraints. + if self.token_action is not None: + _VALID_TOKEN_ACTIONS = ("mask", "clamp") + if self.token_action not in _VALID_TOKEN_ACTIONS: + raise ValueError( + f"token_action must be one of {_VALID_TOKEN_ACTIONS} or None, " + f"got '{self.token_action}'" + ) + if self.level != "sequence": + raise ValueError( + "token_action (two-stage Geo-RS + Token-MIS/TIS) requires " + f"level='sequence'. Got level='{self.level}'." + ) + if self.metric != "ratio": + raise ValueError( + "token_action (two-stage Geo-RS + Token-MIS/TIS) requires " + f"metric='ratio'. Got metric='{self.metric}'." + ) + if self.action != "mask": + raise ValueError( + "token_action (two-stage mode) requires action='mask' for " + "the sequence-level stage (hard Geo-RS rejection). " + f"Got action='{self.action}'." 
+ ) @dataclass From e4b7294a614e21408b989f3ff9300883ce971f2c Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:19:40 -0700 Subject: [PATCH 03/25] Update functional.py --- areal/utils/functional/functional.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index a79aab6024..e88cb94619 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -386,6 +386,38 @@ def apply_rejection_sampling( ), behave_imp_weight, ) + + # ── Stage 2: Token-MIS/TIS on Geo-RS-accepted sequences ───────────── + # Runs only in two-stage mode (cfg.token_action is not None). + # At this point, loss_mask already has Stage-1-rejected sequences + # zeroed out. We now apply per-token filtering on surviving tokens. + if cfg.token_action is not None: + # behave_imp_weight holds per-token ratios π_prox / π_behave. + # Shape: [batch, seq_len] in 2D padded format. + token_ratio = behave_imp_weight + + if cfg.token_action == "mask": + # Token-MIS: zero out tokens where the per-token ratio exceeds + # upper, and optionally where it falls below lower. + # This suppresses tokens where the current policy has drifted + # far from the behavior policy at the token level. + token_oor = token_ratio > cfg.upper + if cfg.lower is not None: + token_oor = token_oor | (token_ratio < cfg.lower) + loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) + behave_imp_weight = behave_imp_weight * (~token_oor).to( + behave_imp_weight.dtype + ) + + elif cfg.token_action == "clamp": + # Token-TIS: clamp the per-token importance ratio to + # [lower, upper]. All tokens remain in the gradient but + # their contribution is bounded. 
+ clamp_lower = cfg.lower if cfg.lower is not None else 0.0 + behave_imp_weight = token_ratio.clamp( + min=clamp_lower, max=cfg.upper + ) + # ── End Stage 2 ────────────────────────────────────────────────────── else: # Token level if config.action == "mask": From 5b65e25df2d4bd8643a4d9a52f369663ebff05dc Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:22:48 -0700 Subject: [PATCH 04/25] Update functional.py --- areal/utils/functional/functional.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index e88cb94619..e5d4b64fb2 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -342,6 +342,25 @@ def apply_rejection_sampling( ), behave_imp_weight, ) + # ── Stage 2: Token-MIS/TIS (1D packed) ────────────────────────────── + if cfg.token_action is not None: + # behave_imp_weight holds per-token ratios. + # Shape: [total_tokens] in 1D packed format. 
+ token_ratio = behave_imp_weight + if cfg.token_action == "mask": + token_oor = token_ratio > cfg.upper + if cfg.lower is not None: + token_oor = token_oor | (token_ratio < cfg.lower) + loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) + behave_imp_weight = behave_imp_weight * (~token_oor).to( + behave_imp_weight.dtype + ) + elif cfg.token_action == "clamp": + clamp_lower = cfg.lower if cfg.lower is not None else 0.0 + behave_imp_weight = token_ratio.clamp( + min=clamp_lower, max=cfg.upper + ) + # ── End Stage 2 ────────────────────────────────────────────────────── else: # 2D padded format agg_values = log_ratio if _use_log_agg else metric From 5d37fa09550c74e6f90ef2203ad37fcb648d20bd Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:26:50 -0700 Subject: [PATCH 05/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 276 +++++++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 4c1dbb8c40..d60b14ff84 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -833,6 +833,282 @@ def test_invalid_metric_raises(self): with pytest.raises(ValueError, match="metric must be one of"): RejectionSamplingConfig(metric="invalid") + +class TestTwoStageRejectionSampling: + """Tests for two-stage Geo-RS + Token-MIS/TIS mode (from closed PR #1084). + + The two-stage pipeline: + Stage 1 — Geo-RS: reject sequences whose geometric-mean ratio > upper. + Stage 2 — Token-MIS/TIS: on accepted sequences, filter/clamp per-token. 
+ """ + + # ── Config validation ───────────────────────────────────────────────────── + + def test_token_action_requires_sequence_level(self): + """token_action must be combined with level='sequence'.""" + with pytest.raises(ValueError, match="level='sequence'"): + RejectionSamplingConfig( + level="token", + action="mask", + metric="ratio", + upper=2.0, + token_action="mask", + ) + + def test_token_action_requires_ratio_metric(self): + """token_action is only defined for metric='ratio'.""" + with pytest.raises(ValueError, match="metric='ratio'"): + RejectionSamplingConfig( + level="sequence", + action="mask", + metric="kl_k2", + upper=1.0, + token_action="mask", + ) + + def test_token_action_requires_action_mask_at_sequence_level(self): + """Sequence-level stage must use action='mask' (hard rejection only).""" + with pytest.raises(ValueError, match="action='mask'"): + RejectionSamplingConfig( + level="sequence", + action="clamp", # invalid for two-stage + metric="ratio", + upper=2.0, + token_action="mask", + ) + + def test_token_action_invalid_string(self): + """token_action must be 'mask', 'clamp', or None.""" + with pytest.raises(ValueError, match="token_action must be one of"): + RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + upper=2.0, + token_action="truncate", # typo / invalid choice + ) + + def test_valid_two_stage_mis_config(self): + """Geo-RS + Token-MIS config constructs without error.""" + cfg = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", + ) + assert cfg.token_action == "mask" + assert cfg.level == "sequence" + + def test_valid_two_stage_tis_config(self): + """Geo-RS + Token-TIS config constructs without error.""" + cfg = RejectionSamplingConfig( + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + lower=0.5, + token_action="clamp", + ) + assert cfg.token_action == "clamp" + assert cfg.lower == 0.5 + + 
# ── Functional tests — 2D padded format ────────────────────────────────── + + @staticmethod + def _batch_inputs(): + """ + Return a 2D padded batch with three sequences of length 4. + + Sequence 0: per-token ratio = 1.5 → geo-mean = 1.5 (accepted, upper=2.0) + Sequence 1: per-token ratio = 3.0 → geo-mean = 3.0 (rejected, > upper) + Sequence 2: per-token ratio = 0.8 → geo-mean = 0.8 (accepted) + """ + ratios = torch.tensor([ + [1.5, 1.5, 1.5, 1.5], + [3.0, 3.0, 3.0, 3.0], + [0.8, 0.8, 0.8, 0.8], + ]) + loss_mask = torch.ones(3, 4) + log_probs = torch.log(ratios) + old_log_probs = torch.zeros_like(log_probs) + return loss_mask, ratios, log_probs, old_log_probs + + def test_stage1_rejects_divergent_sequence(self): + """Stage 1 (Geo-RS) must fully zero-out the rejected sequence.""" + cfg = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=2.0, token_action="mask", + ) + loss_mask, ratios, log_probs, old_log_probs = self._batch_inputs() + new_mask, _ = apply_rejection_sampling( + cfg=cfg, + loss_mask=loss_mask, + behave_imp_weight=ratios, + log_probs=log_probs, + old_log_probs=old_log_probs, + ) + # Sequence 1 (geo-mean 3.0 > 2.0) must be fully masked. + assert new_mask[1].sum() == 0, "Rejected sequence must be fully zeroed" + # Sequences 0 and 2 are accepted and their token ratios ≤ upper → kept. + assert new_mask[0].sum() == 4 + assert new_mask[2].sum() == 4 + + def test_stage2_mis_filters_high_token_within_accepted_seq(self): + """ + Stage 2 (Token-MIS) filters individual high-ratio tokens inside + a sequence that was accepted by Geo-RS. 
+ """ + cfg = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=2.0, token_action="mask", + ) + # Seq 0: geo-mean ≈ exp(mean([0, 0, log(2.5), 0])) ≈ 1.26 → accepted by Geo-RS + # but token[2] = 2.5 > upper → masked by Token-MIS + # Seq 1: all ratios = 1.0 → accepted, all tokens kept + ratios = torch.tensor([ + [1.0, 1.0, 2.5, 1.0], + [1.0, 1.0, 1.0, 1.0], + ]) + loss_mask = torch.ones(2, 4) + log_probs = torch.log(ratios) + old_log_probs = torch.zeros_like(log_probs) + + new_mask, _ = apply_rejection_sampling( + cfg=cfg, + loss_mask=loss_mask, + behave_imp_weight=ratios, + log_probs=log_probs, + old_log_probs=old_log_probs, + ) + assert new_mask[0, 0] == 1 + assert new_mask[0, 1] == 1 + assert new_mask[0, 2] == 0, "Token-MIS must mask the 2.5-ratio token" + assert new_mask[0, 3] == 1 + assert new_mask[1].sum() == 4, "Clean sequence must be fully kept" + + def test_stage2_tis_clamps_token_weights_not_mask(self): + """ + Stage 2 (Token-TIS) clamps per-token weights but must NOT zero loss_mask. + All tokens continue to contribute to the gradient. + """ + cfg = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=2.0, lower=0.5, token_action="clamp", + ) + # Both sequences accepted by Geo-RS (geo-means ≤ 2.0). + ratios = torch.tensor([ + [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] + [0.8, 1.2, 1.5, 0.9], # all in range + ]) + loss_mask = torch.ones(2, 4) + log_probs = torch.log(ratios.clamp(min=1e-6)) + old_log_probs = torch.zeros_like(log_probs) + + new_mask, new_weight = apply_rejection_sampling( + cfg=cfg, + loss_mask=loss_mask, + behave_imp_weight=ratios, + log_probs=log_probs, + old_log_probs=old_log_probs, + ) + # loss_mask must be entirely unchanged — TIS never zeros tokens. + assert new_mask.sum() == 8, "Token-TIS must not zero any loss_mask tokens" + # Weights clamped to [0.5, 2.0]. 
+ assert new_weight[0, 0] == pytest.approx(0.5), "0.2 clamped to lower=0.5" + assert new_weight[0, 1] == pytest.approx(1.0), "1.0 unchanged" + assert new_weight[0, 2] == pytest.approx(1.8), "1.8 unchanged" + assert new_weight[0, 3] == pytest.approx(2.0), "3.5 clamped to upper=2.0" + assert new_weight[1].allclose(ratios[1]), "Seq 1 weights unchanged" + + def test_stage1_dominates_even_if_stage2_would_pass(self): + """ + Tokens in a Stage-1-rejected sequence must stay masked even if their + individual token ratio would have passed the Token-MIS threshold. + """ + cfg = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=2.0, token_action="mask", + ) + loss_mask = torch.ones(1, 4) + # geo-mean = 4.0 > 2.0 → Stage 1 rejects this sequence entirely. + ratios = torch.full((1, 4), 4.0) + log_probs = torch.log(ratios) + old_log_probs = torch.zeros_like(log_probs) + + new_mask, _ = apply_rejection_sampling( + cfg=cfg, + loss_mask=loss_mask, + behave_imp_weight=ratios, + log_probs=log_probs, + old_log_probs=old_log_probs, + ) + assert new_mask.sum() == 0, "Stage 1 rejection must dominate Stage 2" + + def test_none_token_action_identical_to_pure_sequence_geo_rs(self): + """ + token_action=None must produce results identical to the existing + level='sequence', action='mask' mode — no Stage 2 runs. 
+ """ + ratios = torch.tensor([ + [1.5, 1.5, 1.5, 1.5], + [3.0, 3.0, 3.0, 3.0], + [0.8, 0.8, 0.8, 0.8], + ]) + loss_mask = torch.ones(3, 4) + log_probs = torch.log(ratios) + old_log_probs = torch.zeros_like(log_probs) + + cfg_two_stage_off = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=2.0, token_action=None, + ) + cfg_original = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=2.0, + ) + + mask_off, w_off = apply_rejection_sampling( + cfg_two_stage_off, loss_mask.clone(), ratios.clone(), + log_probs, old_log_probs, + ) + mask_orig, w_orig = apply_rejection_sampling( + cfg_original, loss_mask.clone(), ratios.clone(), + log_probs, old_log_probs, + ) + + torch.testing.assert_close(mask_off, mask_orig) + torch.testing.assert_close(w_off, w_orig) + + def test_lower_bound_also_applied_in_token_mis(self): + """ + Token-MIS with a `lower` bound must also mask tokens whose ratio + falls below `lower` (policy has dropped sharply at that token). 
+ """ + cfg = RejectionSamplingConfig( + level="sequence", action="mask", metric="ratio", + agg="mean", upper=3.0, lower=0.5, token_action="mask", + ) + loss_mask = torch.ones(1, 4) + # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted + # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS + ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]]) + log_probs = torch.log(ratios) + old_log_probs = torch.zeros_like(log_probs) + + new_mask, _ = apply_rejection_sampling( + cfg=cfg, + loss_mask=loss_mask, + behave_imp_weight=ratios, + log_probs=log_probs, + old_log_probs=old_log_probs, + ) + assert new_mask[0, 0] == 0, "Token below lower bound must be masked" + assert new_mask[0, 1] == 1 + assert new_mask[0, 2] == 1 + assert new_mask[0, 3] == 1 def test_invalid_agg_raises(self): """Invalid agg should raise ValueError.""" with pytest.raises(ValueError, match="agg must be one of"): From 0981cb5e09f35014b9749e9c675e9eb50808fbf3 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:19:27 -0700 Subject: [PATCH 06/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 36 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index d60b14ff84..02af2db27a 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -890,7 +890,7 @@ def test_token_action_invalid_string(self): def test_valid_two_stage_mis_config(self): """Geo-RS + Token-MIS config constructs without error.""" - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", @@ -898,12 +898,12 @@ def test_valid_two_stage_mis_config(self): upper=2.0, token_action="mask", ) - assert cfg.token_action == "mask" - assert cfg.level == "sequence" + assert config.token_action == "mask" + assert config.level == "sequence" def 
test_valid_two_stage_tis_config(self): """Geo-RS + Token-TIS config constructs without error.""" - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", @@ -912,8 +912,8 @@ def test_valid_two_stage_tis_config(self): lower=0.5, token_action="clamp", ) - assert cfg.token_action == "clamp" - assert cfg.lower == 0.5 + assert config.token_action == "clamp" + assert config.lower == 0.5 # ── Functional tests — 2D padded format ────────────────────────────────── @@ -938,13 +938,13 @@ def _batch_inputs(): def test_stage1_rejects_divergent_sequence(self): """Stage 1 (Geo-RS) must fully zero-out the rejected sequence.""" - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", agg="mean", upper=2.0, token_action="mask", ) loss_mask, ratios, log_probs, old_log_probs = self._batch_inputs() new_mask, _ = apply_rejection_sampling( - cfg=cfg, + config=config, loss_mask=loss_mask, behave_imp_weight=ratios, log_probs=log_probs, @@ -961,7 +961,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): Stage 2 (Token-MIS) filters individual high-ratio tokens inside a sequence that was accepted by Geo-RS. """ - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", agg="mean", upper=2.0, token_action="mask", ) @@ -977,7 +977,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): old_log_probs = torch.zeros_like(log_probs) new_mask, _ = apply_rejection_sampling( - cfg=cfg, + config=config, loss_mask=loss_mask, behave_imp_weight=ratios, log_probs=log_probs, @@ -994,7 +994,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): Stage 2 (Token-TIS) clamps per-token weights but must NOT zero loss_mask. All tokens continue to contribute to the gradient. 
""" - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", agg="mean", upper=2.0, lower=0.5, token_action="clamp", ) @@ -1008,7 +1008,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): old_log_probs = torch.zeros_like(log_probs) new_mask, new_weight = apply_rejection_sampling( - cfg=cfg, + config=config, loss_mask=loss_mask, behave_imp_weight=ratios, log_probs=log_probs, @@ -1028,7 +1028,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): Tokens in a Stage-1-rejected sequence must stay masked even if their individual token ratio would have passed the Token-MIS threshold. """ - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", agg="mean", upper=2.0, token_action="mask", ) @@ -1039,7 +1039,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): old_log_probs = torch.zeros_like(log_probs) new_mask, _ = apply_rejection_sampling( - cfg=cfg, + config=config, loss_mask=loss_mask, behave_imp_weight=ratios, log_probs=log_probs, @@ -1071,11 +1071,11 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): ) mask_off, w_off = apply_rejection_sampling( - cfg_two_stage_off, loss_mask.clone(), ratios.clone(), + config=cfg_two_stage_off, loss_mask.clone(), ratios.clone(), log_probs, old_log_probs, ) mask_orig, w_orig = apply_rejection_sampling( - cfg_original, loss_mask.clone(), ratios.clone(), + config=cfg_original, loss_mask.clone(), ratios.clone(), log_probs, old_log_probs, ) @@ -1087,7 +1087,7 @@ def test_lower_bound_also_applied_in_token_mis(self): Token-MIS with a `lower` bound must also mask tokens whose ratio falls below `lower` (policy has dropped sharply at that token). 
""" - cfg = RejectionSamplingConfig( + config = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", agg="mean", upper=3.0, lower=0.5, token_action="mask", ) @@ -1099,7 +1099,7 @@ def test_lower_bound_also_applied_in_token_mis(self): old_log_probs = torch.zeros_like(log_probs) new_mask, _ = apply_rejection_sampling( - cfg=cfg, + config=config, loss_mask=loss_mask, behave_imp_weight=ratios, log_probs=log_probs, From b20e8444d845a7c5e37fe7fe86f9eb0598d75436 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:23:28 -0700 Subject: [PATCH 07/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 02af2db27a..e4c1b4a195 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -1071,11 +1071,11 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): ) mask_off, w_off = apply_rejection_sampling( - config=cfg_two_stage_off, loss_mask.clone(), ratios.clone(), + cfg_two_stage_off, loss_mask.clone(), ratios.clone(), log_probs, old_log_probs, ) mask_orig, w_orig = apply_rejection_sampling( - config=cfg_original, loss_mask.clone(), ratios.clone(), + cfg_original, loss_mask.clone(), ratios.clone(), log_probs, old_log_probs, ) From 854523003b48fa02d8bb7bb39c02eafe15ea99c8 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:37:08 -0700 Subject: [PATCH 08/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index e4c1b4a195..b957af6ed0 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -946,7 +946,7 @@ def 
test_stage1_rejects_divergent_sequence(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, - behave_imp_weight=ratios, + # behave_imp_weight=ratios, log_probs=log_probs, old_log_probs=old_log_probs, ) @@ -979,7 +979,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, - behave_imp_weight=ratios, + # behave_imp_weight=ratios, log_probs=log_probs, old_log_probs=old_log_probs, ) @@ -1010,7 +1010,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): new_mask, new_weight = apply_rejection_sampling( config=config, loss_mask=loss_mask, - behave_imp_weight=ratios, + # behave_imp_weight=ratios, log_probs=log_probs, old_log_probs=old_log_probs, ) @@ -1041,7 +1041,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, - behave_imp_weight=ratios, + # behave_imp_weight=ratios, log_probs=log_probs, old_log_probs=old_log_probs, ) @@ -1071,11 +1071,11 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): ) mask_off, w_off = apply_rejection_sampling( - cfg_two_stage_off, loss_mask.clone(), ratios.clone(), + cfg_two_stage_off, loss_mask.clone(), log_probs, old_log_probs, ) mask_orig, w_orig = apply_rejection_sampling( - cfg_original, loss_mask.clone(), ratios.clone(), + cfg_original, loss_mask.clone(), log_probs, old_log_probs, ) @@ -1101,7 +1101,7 @@ def test_lower_bound_also_applied_in_token_mis(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, - behave_imp_weight=ratios, + # behave_imp_weight=ratios, log_probs=log_probs, old_log_probs=old_log_probs, ) From 74b595277c16b2e4fca00ba231a40ef87cb37837 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:41:40 -0700 Subject: [PATCH 09/25] Update test_rejection_sampling.py --- 
tests/test_rejection_sampling.py | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index b957af6ed0..718fe800e1 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -932,9 +932,9 @@ def _batch_inputs(): [0.8, 0.8, 0.8, 0.8], ]) loss_mask = torch.ones(3, 4) - log_probs = torch.log(ratios) - old_log_probs = torch.zeros_like(log_probs) - return loss_mask, ratios, log_probs, old_log_probs + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) + return loss_mask, ratios, proximal_logprobs, old_logprobs def test_stage1_rejects_divergent_sequence(self): """Stage 1 (Geo-RS) must fully zero-out the rejected sequence.""" @@ -942,13 +942,13 @@ def test_stage1_rejects_divergent_sequence(self): level="sequence", action="mask", metric="ratio", agg="mean", upper=2.0, token_action="mask", ) - loss_mask, ratios, log_probs, old_log_probs = self._batch_inputs() + loss_mask, ratios, proximal_logprobs, old_logprobs = self._batch_inputs() new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, # behave_imp_weight=ratios, - log_probs=log_probs, - old_log_probs=old_log_probs, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, ) # Sequence 1 (geo-mean 3.0 > 2.0) must be fully masked. 
assert new_mask[1].sum() == 0, "Rejected sequence must be fully zeroed" @@ -973,15 +973,15 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): [1.0, 1.0, 1.0, 1.0], ]) loss_mask = torch.ones(2, 4) - log_probs = torch.log(ratios) - old_log_probs = torch.zeros_like(log_probs) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, # behave_imp_weight=ratios, - log_probs=log_probs, - old_log_probs=old_log_probs, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, ) assert new_mask[0, 0] == 1 assert new_mask[0, 1] == 1 @@ -1004,15 +1004,15 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): [0.8, 1.2, 1.5, 0.9], # all in range ]) loss_mask = torch.ones(2, 4) - log_probs = torch.log(ratios.clamp(min=1e-6)) - old_log_probs = torch.zeros_like(log_probs) + proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) + old_logprobs = torch.zeros_like(proximal_logprobs) new_mask, new_weight = apply_rejection_sampling( config=config, loss_mask=loss_mask, # behave_imp_weight=ratios, - log_probs=log_probs, - old_log_probs=old_log_probs, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, ) # loss_mask must be entirely unchanged — TIS never zeros tokens. assert new_mask.sum() == 8, "Token-TIS must not zero any loss_mask tokens" @@ -1035,15 +1035,15 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): loss_mask = torch.ones(1, 4) # geo-mean = 4.0 > 2.0 → Stage 1 rejects this sequence entirely. 
ratios = torch.full((1, 4), 4.0) - log_probs = torch.log(ratios) - old_log_probs = torch.zeros_like(log_probs) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, # behave_imp_weight=ratios, - log_probs=log_probs, - old_log_probs=old_log_probs, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, ) assert new_mask.sum() == 0, "Stage 1 rejection must dominate Stage 2" @@ -1058,8 +1058,8 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): [0.8, 0.8, 0.8, 0.8], ]) loss_mask = torch.ones(3, 4) - log_probs = torch.log(ratios) - old_log_probs = torch.zeros_like(log_probs) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) cfg_two_stage_off = RejectionSamplingConfig( level="sequence", action="mask", metric="ratio", @@ -1072,11 +1072,11 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): mask_off, w_off = apply_rejection_sampling( cfg_two_stage_off, loss_mask.clone(), - log_probs, old_log_probs, + proximal_logprobs, old_logprobs, ) mask_orig, w_orig = apply_rejection_sampling( cfg_original, loss_mask.clone(), - log_probs, old_log_probs, + proximal_logprobs, old_logprobs, ) torch.testing.assert_close(mask_off, mask_orig) @@ -1095,15 +1095,15 @@ def test_lower_bound_also_applied_in_token_mis(self): # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]]) - log_probs = torch.log(ratios) - old_log_probs = torch.zeros_like(log_probs) + proximal_logprobs = torch.log(ratios) + old_logprobs = torch.zeros_like(proximal_logprobs) new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, # behave_imp_weight=ratios, - log_probs=log_probs, - old_log_probs=old_log_probs, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, ) assert 
new_mask[0, 0] == 0, "Token below lower bound must be masked" assert new_mask[0, 1] == 1 From 5b2dbc79f332bdbae2edaae1b6077aa7287b77e0 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:47:19 -0700 Subject: [PATCH 10/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 718fe800e1..0c32e085aa 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -946,6 +946,7 @@ def test_stage1_rejects_divergent_sequence(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -979,6 +980,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1010,6 +1012,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): new_mask, new_weight = apply_rejection_sampling( config=config, loss_mask=loss_mask, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1041,6 +1044,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1071,12 +1075,10 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): ) mask_off, w_off = apply_rejection_sampling( - cfg_two_stage_off, loss_mask.clone(), - proximal_logprobs, old_logprobs, + proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_two_stage_off ) 
mask_orig, w_orig = apply_rejection_sampling( - cfg_original, loss_mask.clone(), - proximal_logprobs, old_logprobs, + proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_original ) torch.testing.assert_close(mask_off, mask_orig) @@ -1101,6 +1103,7 @@ def test_lower_bound_also_applied_in_token_mis(self): new_mask, _ = apply_rejection_sampling( config=config, loss_mask=loss_mask, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, From 646f47a9f4a190c295d186e980b33fb94db878f2 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:50:40 -0700 Subject: [PATCH 11/25] Update functional.py --- areal/utils/functional/functional.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index e5d4b64fb2..1f124855e2 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -343,22 +343,22 @@ def apply_rejection_sampling( behave_imp_weight, ) # ── Stage 2: Token-MIS/TIS (1D packed) ────────────────────────────── - if cfg.token_action is not None: + if config.token_action is not None: # behave_imp_weight holds per-token ratios. # Shape: [total_tokens] in 1D packed format. 
token_ratio = behave_imp_weight - if cfg.token_action == "mask": - token_oor = token_ratio > cfg.upper - if cfg.lower is not None: - token_oor = token_oor | (token_ratio < cfg.lower) + if config.token_action == "mask": + token_oor = token_ratio > config.upper + if config.lower is not None: + token_oor = token_oor | (token_ratio < config.lower) loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) behave_imp_weight = behave_imp_weight * (~token_oor).to( behave_imp_weight.dtype ) - elif cfg.token_action == "clamp": - clamp_lower = cfg.lower if cfg.lower is not None else 0.0 + elif config.token_action == "clamp": + clamp_lower = config.lower if config.lower is not None else 0.0 behave_imp_weight = token_ratio.clamp( - min=clamp_lower, max=cfg.upper + min=clamp_lower, max=config.upper ) # ── End Stage 2 ────────────────────────────────────────────────────── else: @@ -407,34 +407,34 @@ def apply_rejection_sampling( ) # ── Stage 2: Token-MIS/TIS on Geo-RS-accepted sequences ───────────── - # Runs only in two-stage mode (cfg.token_action is not None). + # Runs only in two-stage mode (config.token_action is not None). # At this point, loss_mask already has Stage-1-rejected sequences # zeroed out. We now apply per-token filtering on surviving tokens. - if cfg.token_action is not None: + if config.token_action is not None: # behave_imp_weight holds per-token ratios π_prox / π_behave. # Shape: [batch, seq_len] in 2D padded format. token_ratio = behave_imp_weight - if cfg.token_action == "mask": + if config.token_action == "mask": # Token-MIS: zero out tokens where the per-token ratio exceeds # upper, and optionally where it falls below lower. # This suppresses tokens where the current policy has drifted # far from the behavior policy at the token level. 
- token_oor = token_ratio > cfg.upper - if cfg.lower is not None: - token_oor = token_oor | (token_ratio < cfg.lower) + token_oor = token_ratio > config.upper + if config.lower is not None: + token_oor = token_oor | (token_ratio < config.lower) loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) behave_imp_weight = behave_imp_weight * (~token_oor).to( behave_imp_weight.dtype ) - elif cfg.token_action == "clamp": + elif config.token_action == "clamp": # Token-TIS: clamp the per-token importance ratio to # [lower, upper]. All tokens remain in the gradient but # their contribution is bounded. - clamp_lower = cfg.lower if cfg.lower is not None else 0.0 + clamp_lower = config.lower if config.lower is not None else 0.0 behave_imp_weight = token_ratio.clamp( - min=clamp_lower, max=cfg.upper + min=clamp_lower, max=config.upper ) # ── End Stage 2 ────────────────────────────────────────────────────── else: From 63f3b7e22546fbc95c0580658c49c30238ac822b Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:56:30 -0700 Subject: [PATCH 12/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 0c32e085aa..463b1b84c6 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -943,7 +943,7 @@ def test_stage1_rejects_divergent_sequence(self): agg="mean", upper=2.0, token_action="mask", ) loss_mask, ratios, proximal_logprobs, old_logprobs = self._batch_inputs() - new_mask, _ = apply_rejection_sampling( + result = apply_rejection_sampling( config=config, loss_mask=loss_mask, cu_seqlens=None, @@ -951,6 +951,7 @@ def test_stage1_rejects_divergent_sequence(self): proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) + new_mask = result.loss_mask # Sequence 1 (geo-mean 3.0 > 2.0) must be fully 
masked. assert new_mask[1].sum() == 0, "Rejected sequence must be fully zeroed" # Sequences 0 and 2 are accepted and their token ratios ≤ upper → kept. @@ -977,7 +978,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) - new_mask, _ = apply_rejection_sampling( + result = apply_rejection_sampling( config=config, loss_mask=loss_mask, cu_seqlens=None, @@ -985,6 +986,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) + new_mask = result.loss_mask assert new_mask[0, 0] == 1 assert new_mask[0, 1] == 1 assert new_mask[0, 2] == 0, "Token-MIS must mask the 2.5-ratio token" @@ -1009,7 +1011,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) old_logprobs = torch.zeros_like(proximal_logprobs) - new_mask, new_weight = apply_rejection_sampling( + new_mask, new_weight, new_filtered_fraction = apply_rejection_sampling( config=config, loss_mask=loss_mask, cu_seqlens=None, @@ -1041,7 +1043,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) - new_mask, _ = apply_rejection_sampling( + result = apply_rejection_sampling( config=config, loss_mask=loss_mask, cu_seqlens=None, @@ -1049,6 +1051,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) + new_mask=result.loss_mask assert new_mask.sum() == 0, "Stage 1 rejection must dominate Stage 2" def test_none_token_action_identical_to_pure_sequence_geo_rs(self): @@ -1074,10 +1077,10 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): agg="mean", upper=2.0, ) - mask_off, w_off = apply_rejection_sampling( + mask_off, w_off, filtered_fraction_off = apply_rejection_sampling( proximal_logprobs, 
old_logprobs, loss_mask.clone(), None, cfg_two_stage_off ) - mask_orig, w_orig = apply_rejection_sampling( + mask_orig, w_orig, filtered_fraction_orig = apply_rejection_sampling( proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_original ) @@ -1100,7 +1103,7 @@ def test_lower_bound_also_applied_in_token_mis(self): proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) - new_mask, _ = apply_rejection_sampling( + result = apply_rejection_sampling( config=config, loss_mask=loss_mask, cu_seqlens=None, @@ -1108,6 +1111,7 @@ def test_lower_bound_also_applied_in_token_mis(self): proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) + new_mask = result.loss_mask assert new_mask[0, 0] == 0, "Token below lower bound must be masked" assert new_mask[0, 1] == 1 assert new_mask[0, 2] == 1 From 24f16046cd29085554d86a4285a9f24e877fb2a2 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 12:59:50 -0700 Subject: [PATCH 13/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 463b1b84c6..969c6e2274 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -1011,7 +1011,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) old_logprobs = torch.zeros_like(proximal_logprobs) - new_mask, new_weight, new_filtered_fraction = apply_rejection_sampling( + result = apply_rejection_sampling( config=config, loss_mask=loss_mask, cu_seqlens=None, @@ -1019,6 +1019,8 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) + new_mask = result.loss_mask + new_weight = result.behave_imp_weight # loss_mask must be entirely unchanged — TIS never 
zeros tokens. assert new_mask.sum() == 8, "Token-TIS must not zero any loss_mask tokens" # Weights clamped to [0.5, 2.0]. @@ -1077,12 +1079,16 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): agg="mean", upper=2.0, ) - mask_off, w_off, filtered_fraction_off = apply_rejection_sampling( + result = apply_rejection_sampling( proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_two_stage_off ) - mask_orig, w_orig, filtered_fraction_orig = apply_rejection_sampling( + mask_off = result.loss_mask + w_off = result.behave_imp_weight + result= apply_rejection_sampling( proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_original ) + mask_orig = result.loss_mask + w_orig = result.behave_imp_weight torch.testing.assert_close(mask_off, mask_orig) torch.testing.assert_close(w_off, w_orig) From fe595910f480d24f2984c02848ca3b1878b52908 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 13:59:01 -0700 Subject: [PATCH 14/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 969c6e2274..9fc324364f 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -973,7 +973,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): ratios = torch.tensor([ [1.0, 1.0, 2.5, 1.0], [1.0, 1.0, 1.0, 1.0], - ]) + ], dtype=torch.int32)) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -981,7 +981,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): result = apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=None, + cu_seqlens=ratios, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1006,7 +1006,7 @@ def 
test_stage2_tis_clamps_token_weights_not_mask(self): ratios = torch.tensor([ [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] [0.8, 1.2, 1.5, 0.9], # all in range - ]) + ], dtype=torch.int32)) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -1014,7 +1014,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): result = apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=None, + cu_seqlens=ratios, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1105,14 +1105,14 @@ def test_lower_bound_also_applied_in_token_mis(self): loss_mask = torch.ones(1, 4) # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS - ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]]) + cu_seqlens = torch.tensor([[0.3, 1.0, 1.0, 1.0]], dtype=torch.int32) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) result = apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=None, + cu_seqlens=cu_seqlens, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, From c44a2cfe218a7f05522d27af6acd4ac373a5c4b5 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:00:03 -0700 Subject: [PATCH 15/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 9fc324364f..2caa74e2bb 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -973,7 +973,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): ratios = torch.tensor([ [1.0, 1.0, 2.5, 1.0], [1.0, 1.0, 1.0, 1.0], - ], dtype=torch.int32)) + ], dtype=torch.int32) 
loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) From c1e9e812be66de37995c59a42fcfaf49d3feda31 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:00:47 -0700 Subject: [PATCH 16/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 2caa74e2bb..9c8eab2495 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -1006,7 +1006,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): ratios = torch.tensor([ [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] [0.8, 1.2, 1.5, 0.9], # all in range - ], dtype=torch.int32)) + ], dtype=torch.int32) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) old_logprobs = torch.zeros_like(proximal_logprobs) From 2e094bf4c1ae71fdd5bdf3fa8cde31b58635b20e Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:02:09 -0700 Subject: [PATCH 17/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 9c8eab2495..ee862578e9 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -1105,14 +1105,14 @@ def test_lower_bound_also_applied_in_token_mis(self): loss_mask = torch.ones(1, 4) # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS - cu_seqlens = torch.tensor([[0.3, 1.0, 1.0, 1.0]], dtype=torch.int32) + ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]], dtype=torch.int32) proximal_logprobs = torch.log(ratios) old_logprobs = 
torch.zeros_like(proximal_logprobs) result = apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=cu_seqlens, + cu_seqlens=ratios, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, From d87ae4fd3214363988d9d3984a4bc503afc299f2 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:05:10 -0700 Subject: [PATCH 18/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index ee862578e9..eebeccd609 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -1105,7 +1105,7 @@ def test_lower_bound_also_applied_in_token_mis(self): loss_mask = torch.ones(1, 4) # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS - ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]], dtype=torch.int32) + ratios = torch.tensor([0.3, 1.0, 1.0, 1.0], dtype=torch.int32) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) From e91b6b3e02dd94913da27087a625878d3e35e4d5 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:06:44 -0700 Subject: [PATCH 19/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index eebeccd609..255a1ef002 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -1102,7 +1102,7 @@ def test_lower_bound_also_applied_in_token_mis(self): level="sequence", action="mask", metric="ratio", agg="mean", upper=3.0, lower=0.5, token_action="mask", ) - loss_mask = torch.ones(1, 4) + loss_mask = torch.ones(4) # 
Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS ratios = torch.tensor([0.3, 1.0, 1.0, 1.0], dtype=torch.int32) @@ -1118,10 +1118,10 @@ def test_lower_bound_also_applied_in_token_mis(self): old_logprobs=old_logprobs, ) new_mask = result.loss_mask - assert new_mask[0, 0] == 0, "Token below lower bound must be masked" - assert new_mask[0, 1] == 1 - assert new_mask[0, 2] == 1 - assert new_mask[0, 3] == 1 + assert new_mask[0] == 0, "Token below lower bound must be masked" + assert new_mask[1] == 1 + assert new_mask[2] == 1 + assert new_mask[3] == 1 def test_invalid_agg_raises(self): """Invalid agg should raise ValueError.""" with pytest.raises(ValueError, match="agg must be one of"): From 5874f3f720fda6cc3ce0952520b4c5a1dd6334fa Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:15:10 -0700 Subject: [PATCH 20/25] Update functional.py --- areal/utils/functional/functional.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index 1f124855e2..19bef8df07 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -346,7 +346,8 @@ def apply_rejection_sampling( if config.token_action is not None: # behave_imp_weight holds per-token ratios. # Shape: [total_tokens] in 1D packed format. - token_ratio = behave_imp_weight + # token_ratio = behave_imp_weight + token_ratio = torch.exp(log_ratio) if config.token_action == "mask": token_oor = token_ratio > config.upper if config.lower is not None: @@ -413,7 +414,8 @@ def apply_rejection_sampling( if config.token_action is not None: # behave_imp_weight holds per-token ratios π_prox / π_behave. # Shape: [batch, seq_len] in 2D padded format. 
- token_ratio = behave_imp_weight + # token_ratio = behave_imp_weight + token_ratio = torch.exp(log_ratio) if config.token_action == "mask": # Token-MIS: zero out tokens where the per-token ratio exceeds @@ -428,14 +430,14 @@ def apply_rejection_sampling( behave_imp_weight.dtype ) - elif config.token_action == "clamp": - # Token-TIS: clamp the per-token importance ratio to - # [lower, upper]. All tokens remain in the gradient but - # their contribution is bounded. - clamp_lower = config.lower if config.lower is not None else 0.0 - behave_imp_weight = token_ratio.clamp( - min=clamp_lower, max=config.upper - ) + elif config.token_action == "clamp": + # Token-TIS: clamp the per-token importance ratio to + # [lower, upper]. All tokens remain in the gradient but + # their contribution is bounded. + clamp_lower = config.lower if config.lower is not None else 0.0 + behave_imp_weight = token_ratio.clamp( + min=clamp_lower, max=config.upper + ) # ── End Stage 2 ────────────────────────────────────────────────────── else: # Token level From 0b78a084f4c8ac7fea9a72d67603cd5cfc4e890f Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:16:08 -0700 Subject: [PATCH 21/25] Update test_rejection_sampling.py --- tests/test_rejection_sampling.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index 255a1ef002..dd33b44a28 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -973,7 +973,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): ratios = torch.tensor([ [1.0, 1.0, 2.5, 1.0], [1.0, 1.0, 1.0, 1.0], - ], dtype=torch.int32) + ]) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -981,7 +981,7 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): result = 
apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=ratios, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1006,7 +1006,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): ratios = torch.tensor([ [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] [0.8, 1.2, 1.5, 0.9], # all in range - ], dtype=torch.int32) + ]) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -1014,7 +1014,7 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): result = apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=ratios, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, @@ -1102,26 +1102,26 @@ def test_lower_bound_also_applied_in_token_mis(self): level="sequence", action="mask", metric="ratio", agg="mean", upper=3.0, lower=0.5, token_action="mask", ) - loss_mask = torch.ones(4) + loss_mask = torch.ones(1,4) # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS - ratios = torch.tensor([0.3, 1.0, 1.0, 1.0], dtype=torch.int32) + ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]]) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) result = apply_rejection_sampling( config=config, loss_mask=loss_mask, - cu_seqlens=ratios, + cu_seqlens=None, # behave_imp_weight=ratios, proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) new_mask = result.loss_mask - assert new_mask[0] == 0, "Token below lower bound must be masked" - assert new_mask[1] == 1 - assert new_mask[2] == 1 - assert new_mask[3] == 1 + assert new_mask[0,0] == 0, "Token below lower bound must be masked" + assert new_mask[0,1] == 1 + assert new_mask[0,2] == 1 + assert new_mask[0,3] == 1 def test_invalid_agg_raises(self): """Invalid agg should 
raise ValueError.""" with pytest.raises(ValueError, match="agg must be one of"): From 9674738d3f1cae21ee9a2b99bf0ddad3aa9c6f06 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:45:17 -0700 Subject: [PATCH 22/25] Update areal/utils/functional/functional.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/utils/functional/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index 19bef8df07..f6602208f0 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -353,7 +353,7 @@ def apply_rejection_sampling( if config.lower is not None: token_oor = token_oor | (token_ratio < config.lower) loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) - behave_imp_weight = behave_imp_weight * (~token_oor).to( + behave_imp_weight = token_ratio * (~token_oor).to( behave_imp_weight.dtype ) elif config.token_action == "clamp": From a050b21b4cad1409f11e6a800b893670425313e9 Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:47:16 -0700 Subject: [PATCH 23/25] Update functional.py update 2D padded format as well --- areal/utils/functional/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index f6602208f0..28916bfacc 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -426,7 +426,7 @@ def apply_rejection_sampling( if config.lower is not None: token_oor = token_oor | (token_ratio < config.lower) loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype) - behave_imp_weight = behave_imp_weight * (~token_oor).to( + behave_imp_weight = token_ratio * (~token_oor).to( behave_imp_weight.dtype ) From 
a9595b66d134a1464b1dcb95101562aacbc9becb Mon Sep 17 00:00:00 2001 From: Morgan Heisler <135909098+morgan-heisler@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:10:03 -0700 Subject: [PATCH 24/25] Update functional.py add per_token_ratio --- areal/utils/functional/functional.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index 28916bfacc..01c78143db 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -250,6 +250,7 @@ def apply_rejection_sampling( behave_imp_weight = torch.exp(log_ratio) # Save original weight before any clamping, to compute clamped fraction later. original_weight = behave_imp_weight + per_token_ratio = behave_imp_weight.clone() # Step 4: Aggregate and filter # @@ -346,8 +347,7 @@ def apply_rejection_sampling( if config.token_action is not None: # behave_imp_weight holds per-token ratios. # Shape: [total_tokens] in 1D packed format. - # token_ratio = behave_imp_weight - token_ratio = torch.exp(log_ratio) + token_ratio = per_token_ratio if config.token_action == "mask": token_oor = token_ratio > config.upper if config.lower is not None: @@ -414,8 +414,7 @@ def apply_rejection_sampling( if config.token_action is not None: # behave_imp_weight holds per-token ratios π_prox / π_behave. # Shape: [batch, seq_len] in 2D padded format. 
- # token_ratio = behave_imp_weight - token_ratio = torch.exp(log_ratio) + token_ratio = per_token_ratio if config.token_action == "mask": # Token-MIS: zero out tokens where the per-token ratio exceeds From b8545274179605193f489d47375754bb12ef7ccb Mon Sep 17 00:00:00 2001 From: m00565736 Date: Tue, 21 Apr 2026 18:05:09 +0000 Subject: [PATCH 25/25] files changes after running `pre-commit run --all-files` --- areal/utils/functional/functional.py | 8 +- docs/en/cli_reference.md | 18 ++-- docs/zh/cli_reference.md | 18 ++-- tests/test_rejection_sampling.py | 124 +++++++++++++++++---------- 4 files changed, 105 insertions(+), 63 deletions(-) diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index 01c78143db..a52f5fcc53 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -406,7 +406,7 @@ def apply_rejection_sampling( ), behave_imp_weight, ) - + # ── Stage 2: Token-MIS/TIS on Geo-RS-accepted sequences ───────────── # Runs only in two-stage mode (config.token_action is not None). # At this point, loss_mask already has Stage-1-rejected sequences @@ -414,8 +414,8 @@ def apply_rejection_sampling( if config.token_action is not None: # behave_imp_weight holds per-token ratios π_prox / π_behave. # Shape: [batch, seq_len] in 2D padded format. - token_ratio = per_token_ratio - + token_ratio = per_token_ratio + if config.token_action == "mask": # Token-MIS: zero out tokens where the per-token ratio exceeds # upper, and optionally where it falls below lower. @@ -428,7 +428,7 @@ def apply_rejection_sampling( behave_imp_weight = token_ratio * (~token_oor).to( behave_imp_weight.dtype ) - + elif config.token_action == "clamp": # Token-TIS: clamp the per-token importance ratio to # [lower, upper]. 
All tokens remain in the gradient but diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 778b04972a..97ca19ef5e 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -1034,16 +1034,18 @@ Attributes: For KL metrics, aggregation is arithmetic. upper: Upper bound for filtering. lower: Lower bound for filtering (optional). + token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp')(optional). ``` -| Parameter | Type | Default | Description | -| --------- | ------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | -| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | -| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 
'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | -| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | -| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | -| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. 
| +| Parameter | Type | Default | Description | +| -------------- | -------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | +| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | +| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. 
**Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | +| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | +| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | +| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | +| `token_action` | string \| None | `None` | Enables two-stage Geo-RS + Token-MIS/TIS mode. Only valid when level='sequence' and metric='ratio'. Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). Stage 2: on accepted sequences, apply per-token correction using this action — 'mask' (Token-MIS: zero tokens where per-token ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to \[lower, upper\]). None disables Stage 2 (pure sequence-level Geo-RS only). Experimental results (PR #1084) show that neither Geo-RS alone nor Token-MIS/TIS alone is stable under severe off-policy drift; the two-stage combination is necessary for both grad_norm and approx_kl stability. 
**Choices:** `mask`, `clamp` | (section-scheduler)= diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index 0a71712686..17439e7a43 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -1032,16 +1032,18 @@ Attributes: For KL metrics, aggregation is arithmetic. upper: Upper bound for filtering. lower: Lower bound for filtering (optional). + token_action: Enables two-stage rejection sampling and importance sampling mode ('mask' or 'clamp') (optional). ``` -| Parameter | Type | Default | Description | -| --------- | ------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | -| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | -| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 
'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | -| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | -| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | -| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. 
| +| Parameter | Type | Default | Description | +| -------------- | -------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | +| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | +| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. 
**Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | +| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | +| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | +| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | +| `token_action` | string \| None | `None` | Enables two-stage Geo-RS + Token-MIS/TIS mode. Only valid when level='sequence' and metric='ratio'. Stage 1 (Geo-RS): sequences whose geometric-mean importance ratio exceeds `upper` are fully rejected (loss_mask zeroed for all tokens). Stage 2: on accepted sequences, apply per-token correction using this action — 'mask' (Token-MIS: zero tokens where per-token ratio > upper) or 'clamp' (Token-TIS: clamp per-token ratio to \[lower, upper\]). None disables Stage 2 (pure sequence-level Geo-RS only). Experimental results (PR #1084) show that neither Geo-RS alone nor Token-MIS/TIS alone is stable under severe off-policy drift; the two-stage combination is necessary for both grad_norm and approx_kl stability. 
**Choices:** `mask`, `clamp` | (section-scheduler)= diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py index dd33b44a28..9a081c6b07 100644 --- a/tests/test_rejection_sampling.py +++ b/tests/test_rejection_sampling.py @@ -926,11 +926,13 @@ def _batch_inputs(): Sequence 1: per-token ratio = 3.0 → geo-mean = 3.0 (rejected, > upper) Sequence 2: per-token ratio = 0.8 → geo-mean = 0.8 (accepted) """ - ratios = torch.tensor([ - [1.5, 1.5, 1.5, 1.5], - [3.0, 3.0, 3.0, 3.0], - [0.8, 0.8, 0.8, 0.8], - ]) + ratios = torch.tensor( + [ + [1.5, 1.5, 1.5, 1.5], + [3.0, 3.0, 3.0, 3.0], + [0.8, 0.8, 0.8, 0.8], + ] + ) loss_mask = torch.ones(3, 4) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -939,8 +941,12 @@ def _batch_inputs(): def test_stage1_rejects_divergent_sequence(self): """Stage 1 (Geo-RS) must fully zero-out the rejected sequence.""" config = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=2.0, token_action="mask", + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", ) loss_mask, ratios, proximal_logprobs, old_logprobs = self._batch_inputs() result = apply_rejection_sampling( @@ -964,16 +970,22 @@ def test_stage2_mis_filters_high_token_within_accepted_seq(self): a sequence that was accepted by Geo-RS. 
""" config = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=2.0, token_action="mask", + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", ) # Seq 0: geo-mean ≈ exp(mean([0, 0, log(2.5), 0])) ≈ 1.26 → accepted by Geo-RS # but token[2] = 2.5 > upper → masked by Token-MIS # Seq 1: all ratios = 1.0 → accepted, all tokens kept - ratios = torch.tensor([ - [1.0, 1.0, 2.5, 1.0], - [1.0, 1.0, 1.0, 1.0], - ]) + ratios = torch.tensor( + [ + [1.0, 1.0, 2.5, 1.0], + [1.0, 1.0, 1.0, 1.0], + ] + ) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -999,14 +1011,21 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): All tokens continue to contribute to the gradient. """ config = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=2.0, lower=0.5, token_action="clamp", + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + lower=0.5, + token_action="clamp", ) # Both sequences accepted by Geo-RS (geo-means ≤ 2.0). - ratios = torch.tensor([ - [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] - [0.8, 1.2, 1.5, 0.9], # all in range - ]) + ratios = torch.tensor( + [ + [0.2, 1.0, 1.8, 3.5], # tokens 0 and 3 out of [0.5, 2.0] + [0.8, 1.2, 1.5, 0.9], # all in range + ] + ) loss_mask = torch.ones(2, 4) proximal_logprobs = torch.log(ratios.clamp(min=1e-6)) old_logprobs = torch.zeros_like(proximal_logprobs) @@ -1024,10 +1043,10 @@ def test_stage2_tis_clamps_token_weights_not_mask(self): # loss_mask must be entirely unchanged — TIS never zeros tokens. assert new_mask.sum() == 8, "Token-TIS must not zero any loss_mask tokens" # Weights clamped to [0.5, 2.0]. 
- assert new_weight[0, 0] == pytest.approx(0.5), "0.2 clamped to lower=0.5" - assert new_weight[0, 1] == pytest.approx(1.0), "1.0 unchanged" - assert new_weight[0, 2] == pytest.approx(1.8), "1.8 unchanged" - assert new_weight[0, 3] == pytest.approx(2.0), "3.5 clamped to upper=2.0" + assert new_weight[0, 0] == pytest.approx(0.5), "0.2 clamped to lower=0.5" + assert new_weight[0, 1] == pytest.approx(1.0), "1.0 unchanged" + assert new_weight[0, 2] == pytest.approx(1.8), "1.8 unchanged" + assert new_weight[0, 3] == pytest.approx(2.0), "3.5 clamped to upper=2.0" assert new_weight[1].allclose(ratios[1]), "Seq 1 weights unchanged" def test_stage1_dominates_even_if_stage2_would_pass(self): @@ -1036,8 +1055,12 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): individual token ratio would have passed the Token-MIS threshold. """ config = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=2.0, token_action="mask", + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action="mask", ) loss_mask = torch.ones(1, 4) # geo-mean = 4.0 > 2.0 → Stage 1 rejects this sequence entirely. @@ -1053,7 +1076,7 @@ def test_stage1_dominates_even_if_stage2_would_pass(self): proximal_logprobs=proximal_logprobs, old_logprobs=old_logprobs, ) - new_mask=result.loss_mask + new_mask = result.loss_mask assert new_mask.sum() == 0, "Stage 1 rejection must dominate Stage 2" def test_none_token_action_identical_to_pure_sequence_geo_rs(self): @@ -1061,22 +1084,31 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): token_action=None must produce results identical to the existing level='sequence', action='mask' mode — no Stage 2 runs. 
""" - ratios = torch.tensor([ - [1.5, 1.5, 1.5, 1.5], - [3.0, 3.0, 3.0, 3.0], - [0.8, 0.8, 0.8, 0.8], - ]) + ratios = torch.tensor( + [ + [1.5, 1.5, 1.5, 1.5], + [3.0, 3.0, 3.0, 3.0], + [0.8, 0.8, 0.8, 0.8], + ] + ) loss_mask = torch.ones(3, 4) proximal_logprobs = torch.log(ratios) old_logprobs = torch.zeros_like(proximal_logprobs) cfg_two_stage_off = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=2.0, token_action=None, + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, + token_action=None, ) cfg_original = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=2.0, + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=2.0, ) result = apply_rejection_sampling( @@ -1084,7 +1116,7 @@ def test_none_token_action_identical_to_pure_sequence_geo_rs(self): ) mask_off = result.loss_mask w_off = result.behave_imp_weight - result= apply_rejection_sampling( + result = apply_rejection_sampling( proximal_logprobs, old_logprobs, loss_mask.clone(), None, cfg_original ) mask_orig = result.loss_mask @@ -1099,10 +1131,15 @@ def test_lower_bound_also_applied_in_token_mis(self): falls below `lower` (policy has dropped sharply at that token). 
""" config = RejectionSamplingConfig( - level="sequence", action="mask", metric="ratio", - agg="mean", upper=3.0, lower=0.5, token_action="mask", + level="sequence", + action="mask", + metric="ratio", + agg="mean", + upper=3.0, + lower=0.5, + token_action="mask", ) - loss_mask = torch.ones(1,4) + loss_mask = torch.ones(1, 4) # Seq geo-mean ≈ exp(mean(log([0.3, 1.0, 1.0, 1.0]))) ≈ 0.84 → accepted # but token[0] = 0.3 < lower=0.5 → masked by Token-MIS ratios = torch.tensor([[0.3, 1.0, 1.0, 1.0]]) @@ -1118,10 +1155,11 @@ def test_lower_bound_also_applied_in_token_mis(self): old_logprobs=old_logprobs, ) new_mask = result.loss_mask - assert new_mask[0,0] == 0, "Token below lower bound must be masked" - assert new_mask[0,1] == 1 - assert new_mask[0,2] == 1 - assert new_mask[0,3] == 1 + assert new_mask[0, 0] == 0, "Token below lower bound must be masked" + assert new_mask[0, 1] == 1 + assert new_mask[0, 2] == 1 + assert new_mask[0, 3] == 1 + def test_invalid_agg_raises(self): """Invalid agg should raise ValueError.""" with pytest.raises(ValueError, match="agg must be one of"):