diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 58b763bdb5..af238760d8 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,6 +3,7 @@ import argparse import json import os +import warnings from dataclasses import MISSING as dataclass_missing from dataclasses import asdict, dataclass, field, fields from enum import Enum @@ -1132,10 +1133,52 @@ class TrainEngineConfig: metadata={"help": "peft method type. Only LoRA is supported for now."}, ) - # Tree training + # Tree training (str, not Literal: OmegaConf.structured rejects Literal here) + tree_training_mode: str = field( + default="disabled", + metadata={ + "help": ( + "Tree training mode. " + "'sparse' enables tree training with Flex Attention module (flex attention), " + "'dta' enables Dynamic Tree Attention (dynamic tree training), " + "'disabled' disables tree training." + ), + "choices": ["disabled", "sparse", "dta"], + }, + ) enable_tree_training: bool = field( default=False, - metadata={"help": "Enable tree training with flex attention module."}, + metadata={ + "help": ( + "[DEPRECATED] Use tree_training_mode instead. " + "enable_tree_training=True maps to tree_training_mode='sparse'. " + "If both are set, tree_training_mode takes precedence." + ) + }, + ) + dta_block_size: int = field( + default=2048, + metadata={ + "help": ( + "Block size for Dynamic Tree Attention. " + "Set to -1 to disable block-size limit. " + "Only effective when tree_training_mode='dta'." + ) + }, + ) + packing_algorithm: str = field( + default="ffd", + metadata={ + "help": ( + "Trajectory packing across data-parallel ranks during distributed rollout " + "(``redistribute_trajectories``). " + "'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order " + "n_tree_tokens. " + "Not to be confused with ``mb_spec.packing_algorithm``, which only " + "controls micro-batch formation (ffd/kk) during training." 
+ ), + "choices": ["ffd", "kk", "dta"], + }, ) # Scheduling @@ -1208,6 +1251,45 @@ def __post_init__(self): "memory_efficient_load is for loading pretrained weights on CPU, " "but init_from_scratch creates a model without loading any weights." ) + valid_tree_modes = {"disabled", "sparse", "dta"} + if self.tree_training_mode not in valid_tree_modes: + raise ValueError( + f"tree_training_mode must be one of {valid_tree_modes}, got '{self.tree_training_mode}'" + ) + valid_rollout_packing = {"ffd", "kk", "dta"} + if self.packing_algorithm not in valid_rollout_packing: + raise ValueError( + f"packing_algorithm (rollout) must be one of {valid_rollout_packing}, " + f"got '{self.packing_algorithm}'" + ) + if self.tree_training_mode == "dta": + if self.dta_block_size == 0 or self.dta_block_size < -1: + raise ValueError( + f"dta_block_size must be -1 or a positive integer when tree_training_mode='dta', got {self.dta_block_size}." + ) + + if self.enable_tree_training: + warnings.warn( + "`enable_tree_training` is deprecated and will be removed in a future version. " + "Use `tree_training_mode='sparse'` instead.", + FutureWarning, + stacklevel=2, + ) + if self.tree_training_mode != "disabled": + warnings.warn( + f"`tree_training_mode` is already set to '{self.tree_training_mode}', " + "`enable_tree_training=True` is ignored.", + FutureWarning, + stacklevel=2, + ) + else: + self.tree_training_mode = "sparse" + warnings.warn( + "`tree_training_mode` is overridden to 'sparse' from deprecated " + "`enable_tree_training=True`.", + FutureWarning, + stacklevel=2, + ) if self._version not in ("v1", "v2"): raise ValueError( f"_version must be either 'v1' or 'v2', got '{self._version}'" @@ -1576,6 +1658,22 @@ def __post_init__(self): "Please set `actor.use_decoupled_loss=false` in your configuration." 
) + if self.packing_algorithm == "dta": + for norm_name in ["adv_norm", "reward_norm"]: + norm_config = getattr(self, norm_name) + if norm_config is not None: + if ( + norm_config.mean_level == "group" + or norm_config.std_level == "group" + ): + raise ValueError( + f"{norm_name} uses 'group' level normalization, which is incompatible " + "with packing_algorithm='dta'. DTA requires sequence-level independence, " + "but 'group' normalization relies on contiguous group slices. Please use " + "'batch' level normalization or set packing_algorithm='ffd'. " + "(Group-level support for DTA will be provided in a future release.)" + ) + super().__post_init__() diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index c22667d9d0..55a843dd3d 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -260,9 +260,14 @@ def __init__(self, config: TrainEngineConfig): self.dp_rank: int self.is_offload: bool = False + self.tree_training_mode: str = self.config.tree_training_mode + if self.tree_training_mode == "dta": + raise ValueError( + "tree_training_mode='dta' is only supported by ArchonEngine. " + "Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'." + ) self._offload_depth: int = 0 self._per_layer_optim_wrapper: PerLayerOptimWrapper | None = None - self.enable_tree_training: bool = self.config.enable_tree_training @classmethod def from_pretrained( @@ -383,7 +388,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # Create device model self._create_device_model() - if self.enable_tree_training and self.parallel_helper.sp_size > 1: + if self.tree_training_mode == "sparse" and self.parallel_helper.sp_size > 1: raise ValueError( "Tree training currently cannot be enabled with sp_size > 1." 
) @@ -394,7 +399,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): shard_vision_across_sp=self.config.fsdp.shard_vision_across_sp, ) # Monkey patch: replace attention's forward() with tree attention. - patch_fsdp_for_tree_training(enable=self.enable_tree_training) + patch_fsdp_for_tree_training(enable=self.tree_training_mode == "sparse") if self.config.use_lora: self._apply_peft_wrapper() @@ -732,7 +737,7 @@ def forward_backward_batch( # module_fsdp.py reads these keys from the **kwargs that transformers # forwards through. tree_attn_keys: list[str] = [] - if self.enable_tree_training and ctx.trie_node is not None: + if self.tree_training_mode == "sparse" and ctx.trie_node is not None: padded_size = mb_item.padded_to_length assert padded_size is not None tree_kwargs = build_tree_attn_kwargs( @@ -880,8 +885,8 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None: self.forward_backward_batch(mb_list, process_output, forward_only=True) # Step 4: Aggregate and reorder outputs - if self.enable_tree_training: - result = merge_packed_tree_results(outputs, batch_size) + if self.tree_training_mode == "sparse": + return merge_packed_tree_results(outputs, batch_size) else: result = reorder_and_pad_outputs( outputs, output_seqlens, mb_list, aggregate_fn @@ -1589,7 +1594,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: input_ = input_.copy() # Tree training path - if self.enable_tree_training: + if self.tree_training_mode == "sparse": mb_list = build_packed_tree_batch( input_, mb_spec=self.config.mb_spec, @@ -1855,12 +1860,12 @@ def _compute_logprobs_and_loss( if local_weight == 0: return logits.mean() * 0.0 - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." 
) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": # Handle dummy trie (empty tree for DP synchronization) # When trie has no sequences, return zero loss with grad connection if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids: @@ -1918,12 +1923,12 @@ def _compute_forward_result( ctx: FSDPTrainContext, ) -> torch.Tensor | dict[int, torch.Tensor]: """Compute forward output (logprobs or values).""" - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": result = _gather_packed_tree_logprobs( logits, ctx.trie_node, diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 6e22e2cb1c..abf814fa43 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -189,8 +189,13 @@ def __init__(self, config: TrainEngineConfig): self.seed: int = 0 self.own_global_group: bool = False self.is_offload: bool = False + self.tree_training_mode: str = self.config.tree_training_mode + if self.tree_training_mode == "dta": + raise ValueError( + "tree_training_mode='dta' is only supported by ArchonEngine. " + "Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'." 
+ ) self._offload_depth: int = 0 - self.enable_tree_training: bool = self.config.enable_tree_training # FP8 configuration self.fp8_config = self.mcore_config.fp8_config self.enable_fp8: bool = self.fp8_config is not None @@ -323,7 +328,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.tokenizer = load_hf_tokenizer(self.config.path) with patch_bridge_for_tree_training( - self.enable_tree_training and self.bridge_cls == "mbridge" + self.tree_training_mode == "sparse" and self.bridge_cls == "mbridge" ): self.bridge = self._build_hf_mcore_bridge() @@ -807,7 +812,7 @@ def forward_step(batch_iter, model): # save_for_backward() which can only save torch.Tensor objects; # BlockMask is recreated inside PytorchFlexAttention.forward(). tree_attn_keys: list[str] = [] - if self.enable_tree_training: + if self.tree_training_mode == "sparse": trie_node = mb_input.padded_mb.get("trie_node", None) # Ensure trie_node is also in orig_mb for _compute_logprobs_and_loss if trie_node is not None and "trie_node" not in mb_input.orig_mb: @@ -1029,7 +1034,7 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: # Step 4: Aggregate, reorder, and broadcast outputs res = None if mpu.is_pipeline_last_stage(): - if self.enable_tree_training: + if self.tree_training_mode == "sparse": res = merge_packed_tree_results(outputs, batch_size) else: res = reorder_and_pad_outputs( @@ -1812,7 +1817,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: pp_size = self.parallel_strategy.pipeline_parallel_size cp_size = self.parallel_strategy.context_parallel_size tp_size = self.parallel_strategy.tensor_parallel_size - if self.enable_tree_training: + if self.tree_training_mode == "sparse": assert cp_size == 1, ( "Context parallelism is not supported in tree training." 
) @@ -1922,12 +1927,12 @@ def _compute_logprobs_and_loss( if local_weight == 0: return output.mean() * 0.0 - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": # Handle dummy trie (empty tree for DP synchronization) # When trie has no sequences, return zero loss with grad connection trie_node = inputs.get("trie_node") @@ -1987,12 +1992,12 @@ def _compute_forward_result( output: torch.Tensor, inputs: dict[str, Any], ) -> torch.Tensor | dict[int, torch.Tensor]: - if self.config.is_critic and self.enable_tree_training: + if self.config.is_critic and self.tree_training_mode == "sparse": raise NotImplementedError( "Tree training with critic model is not supported yet." ) if not self.config.is_critic: - if self.enable_tree_training: + if self.tree_training_mode == "sparse": logprobs = _gather_packed_tree_logprobs( output, inputs["trie_node"], diff --git a/areal/experimental/archon/__init__.py b/areal/experimental/archon/__init__.py new file mode 100644 index 0000000000..79d8c395eb --- /dev/null +++ b/areal/experimental/archon/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Archon experimental testing helpers exposed under `areal.experimental`.""" diff --git a/areal/experimental/archon/torchrun/__init__.py b/areal/experimental/archon/torchrun/__init__.py new file mode 100644 index 0000000000..f3fa8d13ab --- /dev/null +++ b/areal/experimental/archon/torchrun/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Torchrun helpers for Archon experimental tests.""" diff --git a/areal/experimental/archon/utils.py b/areal/experimental/archon/utils.py new file mode 100644 index 0000000000..1a1acc039a --- /dev/null +++ b/areal/experimental/archon/utils.py @@ -0,0 +1,8 @@ +# 
SPDX-License-Identifier: Apache-2.0 + +"""Small Archon utility helpers shared by experimental test runners.""" + + +def strip_wrapper_prefixes(name: str) -> str: + """Drop wrapper-generated path segments from parameter names.""" + return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "") diff --git a/areal/experimental/dta/dp.py b/areal/experimental/dta/dp.py new file mode 100644 index 0000000000..97f3ac6c54 --- /dev/null +++ b/areal/experimental/dta/dp.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +# This code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/data_parallel.py. +# Special thanks to Yuchen Yang for significant contributions to the load-balanced data parallel partitioning algorithm. +from types import SimpleNamespace + +from areal.experimental.dta.token_trie import TokenTrie +from areal.experimental.dta.trie import CompressedTrie, _get_stats, _get_subtrie + + +def LB_by_n_tokens(token_seqs, K): + bins = [[] for _ in range(K)] + bin_lens = [0] * K + seq_indices = sorted(range(len(token_seqs)), key=lambda i: -len(token_seqs[i])) + for i in seq_indices: + min_bin = min(range(K), key=lambda j: bin_lens[j]) + bins[min_bin].append(i) + bin_lens[min_bin] += len(token_seqs[i]) + return bins + + +def pred_time( + compressed_trie, time_model, mode: str, block_size: int | None = None +) -> float: + if mode == "forward": + _, lens, lcp_lens = compressed_trie.get_order_forward() + elif mode == "backward": + _, lens, lcp_lens = compressed_trie.get_order_backward() + else: + raise ValueError(f"Unsupported mode: {mode}") + + stats = _get_stats(lens, lcp_lens, mode, block_size) + return time_model.pred(stats) + + +def get_original_bins( + token_trie: TokenTrie, leaf_bins: list[list[int]] +) -> list[list[int]]: + bins = [[] for _ in range(len(leaf_bins))] + for bucket_idx, leaf_bucket in enumerate(leaf_bins): + for leaf_idx in leaf_bucket: + attach_lists = 
token_trie.attach_lists[leaf_idx] + for attach, _ in attach_lists: + original_seq_idx = attach["_sequence_batch_id"] + bins[bucket_idx].append(original_seq_idx) + return bins + + +def LB_by_TM(token_seqs, time_model, config: SimpleNamespace): + token_trie = TokenTrie(token_seqs) + n_leaf_seqs = len(token_trie.inputs) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + + K = config.K + leaf_bins = [[] for _ in range(K)] + bin_times = [0.0] * K + + for i in range(n_leaf_seqs): + min_bin = min(range(K), key=lambda j: bin_times[j]) + leaf_bins[min_bin].append(i) + bin_compressed_trie = _get_subtrie(compressed_trie, leaf_bins[min_bin]) + bin_times[min_bin] = pred_time( + bin_compressed_trie, time_model, config.mode, config.block_size + ) + + bins = get_original_bins(token_trie, leaf_bins) + return bins + + +def try_divide( + compressed_trie, + n_seqs, + config: SimpleNamespace, + divL, + divR, + time_model, + cost_limit: float, +) -> list[list[int]] | None: + K = config.K + divs = [] + + start = 0 + while start < n_seqs: + divs.append(start) + if len(divs) > K: + break + L = max(divL[len(divs)] - 1, start) + R = divR[len(divs)] - 1 + while L < R: + mid = (L + R + 1) // 2 + cur_subtrie = _get_subtrie(compressed_trie, set(range(start, mid + 1))) + est_time = pred_time( + cur_subtrie, time_model, config.mode, config.block_size + ) + if est_time <= cost_limit: + L = mid + else: + R = mid - 1 + start = L + 1 + + return divs + + +def LB_by_DFS_and_TM(token_seqs, time_model, config: SimpleNamespace): + token_trie = TokenTrie(token_seqs) + n_leaf_seqs = len(token_trie.inputs) + K = config.K + if n_leaf_seqs == 0: + return [[] for _ in range(K)] + + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + + R = float(pred_time(compressed_trie, time_model, config.mode, config.block_size)) + L = R / K + eps = R * 1e-4 + + divL = [0] * (K + 1) + # Maintain a valid initial partition boundary so K==1 (L==R) does not + # skip the search and 
accidentally produce empty bins. + divR = [0] + [n_leaf_seqs] * K + + while R - L > eps: + mid = (L + R) / 2.0 + divs = try_divide( + compressed_trie, n_leaf_seqs, config, divL, divR, time_model, mid + ) + if len(divs) <= K: + R = mid + divR[: len(divs)] = divs + else: + L = mid + eps + divL = divs[: K + 1] + + leaf_bins = [list(range(divR[i], divR[i + 1])) for i in range(K)] + bins = get_original_bins(token_trie, leaf_bins) + return bins + + +# -------- Test -------- + + +def eval(token_seqs, bins, time_model, config: SimpleNamespace): + total_time = 0.0 + max_time = 0.0 + for bucket in bins: + token_trie = TokenTrie([token_seqs[i] for i in bucket]) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + bucket_pred_time = pred_time( + compressed_trie, time_model, config.mode, config.block_size + ) + total_time += bucket_pred_time + max_time = max(max_time, bucket_pred_time) + return total_time, max_time diff --git a/areal/experimental/dta/dta_engine.py b/areal/experimental/dta/dta_engine.py new file mode 100644 index 0000000000..cfc646aba9 --- /dev/null +++ b/areal/experimental/dta/dta_engine.py @@ -0,0 +1,786 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/tree_training_engine.py. +# Special thanks to Yuchen Yang for outstanding contributions to core DTA algorithms +# and optimizations, including chunked backpropagation and cut tail features. 
+ +from bisect import bisect_left, bisect_right +from math import ceil +from typing import NoReturn + +import torch +import torch.nn.functional as F +from transformers.cache_utils import DynamicCache + +from areal.utils.functional import gather_logprobs, gather_logprobs_entropy +from areal.utils.logging import getLogger + +NO_BLOCK_SIZE_LIMIT = int(1e9) + + +def _get_forkpos(lens, lcp_lens, block_size: int | None) -> list: + """ + Compute all fork positions that DTAEngine's stack must track. + + Fork positions are token indices where: + 1) Sequences diverge (longest common prefix boundaries) + 2) Block boundaries for long sequences to reduce memory usage + + Returns a sorted list of unique fork positions. + """ + + forkpos_list = [] + + # 1. Fork positions induced by branching (LCP boundaries) + for lcp in lcp_lens: + if lcp > 0: + forkpos_list.append(lcp - 1) + + # 2. Fork positions induced by block segmentation + if block_size is not None: + for i in range(len(lens)): + start = 0 if i == len(lcp_lens) else lcp_lens[i] + end = lens[i] + + pop_len = end - start + n_blocks = ceil(pop_len / block_size) + block_size_actual = ceil(pop_len / n_blocks) + + for b in range(n_blocks): + pop_start = max(end - (b + 1) * block_size_actual, start) + if pop_start > 0: + forkpos_list.append(pop_start - 1) + + forkpos_list = list(set(forkpos_list)) + forkpos_list.sort() + + return forkpos_list + + +class DTAEngine: + """ + Engine for backward computation over sequences with shared prefixes. + + DTAEngine stores only necessary KV caches, logits at fork + positions, log-probs, and entropy to efficiently compute gradients + for multiple sequences while saving memory. + + Supports block-wise popping to reduce GPU memory peak. + """ + + def __init__( + self, + model_config, + device, + dtype: torch.dtype, + max_seq_len: int, + forward_only: bool = False, + is_critic: bool = False, + ): + """ + Initialize DTAEngine with model config, device and buffer sizes. 
+ + Buffers for tokens, logprobs, entropy and KV caches are preallocated + to max_seq_len. + """ + self.model = None + self.device = device + self.dtype = dtype + self.max_seq_len = max_seq_len + self.is_critic = is_critic + + # ------------------------------------------------------------------------ + # Initialize static stack buffers + # ------------------------------------------------------------------------ + self.cur_len = 0 + + # Token buffer + self.tokens = torch.zeros((max_seq_len), device=self.device, dtype=torch.long) + + if self.is_critic: + # Value buffer for critic + self.values = torch.zeros( + (max_seq_len), device=self.device, dtype=torch.float32 + ) + if not forward_only: + self.grad_values = torch.zeros( + (max_seq_len), device=self.device, dtype=dtype + ) + else: + # Entropy buffer + if not forward_only: + self.entropy = torch.zeros( + (max_seq_len), device=self.device, dtype=torch.float32 + ) + self.grad_entropy = torch.zeros( + (max_seq_len), device=self.device, dtype=dtype + ) + + # Logprob buffer + self.logprobs = torch.zeros( + (max_seq_len), device=self.device, dtype=torch.float32 + ) + if not forward_only: + self.grad_logprobs = torch.zeros( + (max_seq_len), device=self.device, dtype=dtype + ) + + # Fork position logits buffer (store logits only at fork positions, others are None) + self.forkpos_list = [] # List of all fork positions + self.forkpos_logits: list[torch.Tensor | None] = [ + None + ] * max_seq_len # Logits at fork positions for computing logprobs + if not forward_only: + self.grad_forkpos_logits: list[torch.Tensor | None] = [ + None + ] * max_seq_len # Gradients of logits at fork positions + + # Attachments buffer + self.attachs = [] # List of sequences retained in the stack, including (attachments, length) + + # KV cache buffers + self.n_layers = model_config.num_hidden_layers + n_kv_heads = model_config.num_key_value_heads + # Compatible with Qwen2.5 and Qwen3 series + head_dim = ( + model_config.head_dim + if 
hasattr(model_config, "head_dim") + else model_config.hidden_size // model_config.num_attention_heads + ) + + kv_buffer_shape = (1, n_kv_heads, max_seq_len, head_dim) + + self.kv_cache = ( + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + ) + + if not forward_only: + self.grad_kv = ( + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + [ + torch.zeros(kv_buffer_shape, device=self.device, dtype=dtype) + for _ in range(self.n_layers) + ], + ) + + self.ret_logprobs = [] + + self._dta_log = getLogger("DTA") + + def _dta_fail(self, message: str) -> NoReturn: + text = f"[DTA] {message}" + self._dta_log.error("%s", text) + raise RuntimeError(text) + + def get_forkpos(self, start: int, end: int) -> list[int]: + """ + Yield fork positions within the interval [start, end). + + Uses binary search on precomputed forkpos_list. + """ + + left = bisect_left(self.forkpos_list, start) + right = bisect_right(self.forkpos_list, end - 1) + yield from self.forkpos_list[left:right] + + @torch.no_grad() + def push_forward_only( + self, + new_tokens: torch.LongTensor, + attach_list: list[tuple[dict, int]], + ): + """ + Push new tokens into the stack with their attachments. + + Builds cache (KV, logprobs) up to cache_len. + Updates logprobs for the previous token. + + Used in inference mode only. 
+ """ + + B = new_tokens.numel() + if self.cur_len + B > self.max_seq_len: + self._dta_fail( + "Exceeds max_seq_len: " + f"cur_len={self.cur_len}, new_tokens={B}, max={self.max_seq_len}" + ) + if B == 0: + for attachment, length in attach_list: + seq_id = attachment["_sequence_batch_id"] + if length == 0: + self.returns[seq_id] = torch.empty( + 0, device=self.device, dtype=torch.float32 + ) + else: + logprobs = self.logprobs[: length - 1] + self.returns[seq_id] = logprobs.clone() + return + + start, end = self.cur_len, self.cur_len + B + + # ------------------------------------------------------------- + # 1. Build prefix cache from existing KV + # ------------------------------------------------------------- + prefix_cache = DynamicCache() + for layer_idx in range(self.n_layers): + prefix_cache.update( + self.kv_cache[0][layer_idx][:, :, :start, :], + self.kv_cache[1][layer_idx][:, :, :start, :], + layer_idx=layer_idx, + ) + + # ------------------------------------------------------------- + # 2. Forward + # ------------------------------------------------------------- + out = self.model( + new_tokens.unsqueeze(0), + past_key_values=prefix_cache, + use_cache=True, + ) + + # Compute logprobs and entropy for new tokens + logits = out.logits # [1, B, vocab] or [1, B, 1] + + # ------------------------------------------------------------- + # 3. 
Write tokens, computed logprobs/values, and KV cache into stack + # ------------------------------------------------------------- + + # Write tokens into stack + self.tokens[start:end] = new_tokens + + # Write KV cache into stack + new_cache = out.past_key_values + for layer_idx, layer in enumerate(new_cache.layers): + self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[ + :, :, start:end, : + ] + self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[ + :, :, start:end, : + ] + + if self.is_critic: + values = logits.squeeze(0).squeeze(-1) + self.values[start:end] = values + + # ------------------------------------------------------------- + # 4. Store values for sequences ending in attach_list + # ------------------------------------------------------------- + for attachment, length in attach_list: + seq_id = attachment["_sequence_batch_id"] + if length == 0: + self.returns[seq_id] = torch.empty( + 0, device=self.device, dtype=torch.float32 + ) + continue + self.returns[seq_id] = self.values[:length].clone() + else: + logprobs = gather_logprobs( + logits=logits, + labels=new_tokens[1:].unsqueeze(0), + ) + + # Write logprobs into stack + self.logprobs[start : end - 1] = logprobs.squeeze(0) + # Fill the logprob of the first token using self.forkpos_logits[start] + if start > 0: + pre_logits = self.forkpos_logits[start - 1].float() + first_token = new_tokens[0].item() + pre_logprob = F.log_softmax(pre_logits, dim=-1)[first_token].item() + self.logprobs[start - 1] = pre_logprob + + # Write logits into stack (fork positions only) + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + self.forkpos_logits[i] = logits[0, i - start].detach().clone() + + # ------------------------------------------------------------- + # 4. 
Store logprobs for sequences ending in attach_list + # ------------------------------------------------------------- + for attachment, length in attach_list: + seq_id = attachment["_sequence_batch_id"] + if length == 0: + self.returns[seq_id] = torch.empty( + 0, device=self.device, dtype=torch.float32 + ) + continue + logprobs = self.logprobs[: length - 1] + self.returns[seq_id] = logprobs.clone() + + self.cur_len += B + + def build_cache(self, start: int, end: int): + """ + Build KV cache, logprobs and entropy for tokens in [start, end). + Uses the existing prefix cache [0, start). + """ + + # Build prefix cache from existing KV + prefix_cache = DynamicCache() + for layer_idx in range(self.n_layers): + prefix_cache.update( + self.kv_cache[0][layer_idx][:, :, :start, :], + self.kv_cache[1][layer_idx][:, :, :start, :], + layer_idx=layer_idx, + ) + + # Forward pass to compute new KV + out = self.model( + self.tokens[start:end].unsqueeze(0), + past_key_values=prefix_cache, + use_cache=True, + ) + + # Compute logprobs & entropy for new tokens + logits = out.logits # [1, B, vocab] or [1, B, 1] + + # Write new KV cache into stack + new_cache = out.past_key_values + for layer_idx, layer in enumerate(new_cache.layers): + self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[ + :, :, start:end, : + ] + self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[ + :, :, start:end, : + ] + + if self.is_critic: + values = logits.squeeze(0).squeeze(-1) + self.values[start:end] = values + else: + logprobs, entropy = gather_logprobs_entropy( + logits=logits, + labels=self.tokens[start + 1 : end].unsqueeze(0), + ) + self.logprobs[start : end - 1] = logprobs.squeeze(0) + self.entropy[start:end] = entropy.squeeze(0) + + # Write logits into stack (fork positions only) + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + self.forkpos_logits[i] = logits[0, i - start].detach().clone() + + @torch.no_grad() + def push( + self, + new_tokens: 
torch.LongTensor, + attachs: list[tuple[dict, int]], + cache_len: int, + ): + """ + Push new tokens into the stack with their attachments. + + Builds cache (KV, logprobs, entropy) up to cache_len. + Updates logprobs for the previous token. + """ + + B = new_tokens.numel() + if self.cur_len + B > self.max_seq_len: + self._dta_fail( + "Exceeds max_seq_len: " + f"cur_len={self.cur_len}, new_tokens={B}, max={self.max_seq_len}" + ) + + start, end = self.cur_len, self.cur_len + B + + # Add attachments + for attachment, length in attachs: + self.attachs.append((attachment, length)) + + # Write tokens + self.tokens[start:end] = new_tokens + + # Build prefix cache (KV & logprobs/entropy) if needed + if start < cache_len: + self.build_cache(start, cache_len) + + # Update the previous token's logprob. + if not self.is_critic and start > 0: + pre_logits = self.forkpos_logits[start - 1].float() + first_token = new_tokens[0].item() + pre_logprob = F.log_softmax(pre_logits, dim=-1)[first_token].item() + self.logprobs[start - 1] = pre_logprob + + self.cur_len = end + + def pop(self, start: int, loss_fn) -> float: + """ + Pop tokens from position `start` to the current end. + + Computes gradients for the popped tokens and accumulates them + into the stack's KV, logprobs, entropy, and fork position logits buffers. + + Args: + start: The starting token index to pop from. + loss_fn: Callable that computes the loss for a sequence segment. + + Returns: + The total loss computed over sequences ending within the popped segment. + """ + if not (0 <= start < self.cur_len): + self._dta_fail(f"Invalid pop start: start={start}, cur_len={self.cur_len}") + + end = self.cur_len + _ = end - start + + tokens_to_pop = self.tokens[start:end] + + # --------------------------------------------------------------------------------- + # 1. 
Gather prefix KV (with requires_grad=True) + # --------------------------------------------------------------------------------- + prefix_cache = DynamicCache() + prefix_kv = [] + + for layer_idx in range(self.n_layers): + k = ( + self.kv_cache[0][layer_idx][:, :, :start, :] + .detach() + .requires_grad_(True) + ) + v = ( + self.kv_cache[1][layer_idx][:, :, :start, :] + .detach() + .requires_grad_(True) + ) + prefix_cache.update(k, v, layer_idx=layer_idx) + prefix_kv.append((k, v)) + + # --------------------------------------------------------------------------------- + # 2. Forward pass on tokens_to_pop (builds computation graph) + # --------------------------------------------------------------------------------- + out = self.model( + tokens_to_pop.unsqueeze(0), past_key_values=prefix_cache, use_cache=True + ) + + logits = out.logits + block_cache = out.past_key_values + + # --------------------------------------------------------------------------------- + # 3. Compute suffix logprobs & entropy or values + # --------------------------------------------------------------------------------- + if self.is_critic: + suf_values = logits.squeeze(0).squeeze(-1) + else: + suf_logprobs, suf_entropy = gather_logprobs_entropy( + logits=logits, labels=tokens_to_pop[1:].unsqueeze(0) + ) + suf_entropy = suf_entropy.squeeze(0) + suf_logprobs = suf_logprobs.squeeze(0) + + # Compute logprob for connection to previous token if exists + if start > 0: + mid_logits = ( + self.forkpos_logits[start - 1].float().detach().requires_grad_(True) + ) + mid_label = self.tokens[start].item() + mid_logprob = F.log_softmax(mid_logits, dim=-1)[mid_label].unsqueeze(0) + + # --------------------------------------------------------------------------------- + # 4. 
Compute loss for sequences ending in this block + # --------------------------------------------------------------------------------- + + # Gather attachs for sequences ending in this block + attachs_in_block = [ + (att, length) for att, length in self.attachs if start < length <= end + ] + + if attachs_in_block: + if self.is_critic: + if start > 0: + pre_values = self.values[:start].detach().requires_grad_(True) + values = torch.cat([pre_values, suf_values], dim=0) + else: + values = suf_values + + # Compute loss + loss = 0.0 + for attachment, length in attachs_in_block: + if length == 0: + continue + loss += loss_fn(values[:length], attachment) + else: + # Concatenate full logprobs and entropy, with requires_grad=True + if start > 0: + pre_entropy = self.entropy[:start].detach().requires_grad_(True) + entropys = torch.cat([pre_entropy, suf_entropy], dim=0) + if start > 1: + pre_logprobs = ( + self.logprobs[: start - 1].detach().requires_grad_(True) + ) + logprobs = torch.cat( + [pre_logprobs, mid_logprob, suf_logprobs], dim=0 + ) + else: + logprobs = torch.cat([mid_logprob, suf_logprobs], dim=0) + else: + entropys = suf_entropy + logprobs = suf_logprobs + + # Compute loss + loss = 0.0 + for attachment, length in attachs_in_block: + if length == 0: + continue + loss += loss_fn( + logprobs[: length - 1], entropys[:length], attachment + ) + + # --------------------------------------------------------------------------------- + # 5. 
Backward with gradient injection from popped tokens + # (to KV, logprobs, entropy, forkpos-logits) + # --------------------------------------------------------------------------------- + roots, grads = [], [] + + # Loss gradient + if attachs_in_block: + roots.append(loss) + grads.append(torch.tensor(1.0, device=self.device, dtype=loss.dtype)) + + # KV gradients from popped tokens + for layer_idx, layer in enumerate(block_cache.layers): + k = layer.keys[:, :, start:end, :] + v = layer.values[:, :, start:end, :] + roots.extend([k, v]) + grads.extend( + [ + self.grad_kv[0][layer_idx][:, :, start:end, :], + self.grad_kv[1][layer_idx][:, :, start:end, :], + ] + ) + + if self.is_critic: + roots.append(suf_values) + grads.append(self.grad_values[start:end]) + else: + # Logprobs & entropy gradients from popped tokens + roots.extend([suf_logprobs, suf_entropy]) + grads.extend( + [self.grad_logprobs[start : end - 1], self.grad_entropy[start:end]] + ) + if start > 0: + roots.append(mid_logprob) + grad_mid_logprob = self.grad_logprobs[start - 1].unsqueeze(0) + grads.append(grad_mid_logprob) + + # Fork position logits gradients + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + if self.grad_forkpos_logits[i] is not None: + fork_logits = logits[0, i - start] + roots.append(fork_logits) + grads.append(self.grad_forkpos_logits[i]) + + # roots: loss, (KV, logprobs, entropy, forkpos logits) in tokens_to_pop + torch.autograd.backward(roots, grads) + + # --------------------------------------------------------------------------------- + # 6. 
Accumulate gradients to prefix cache (KV, logprobs, entropy, forkpos-logits) + # --------------------------------------------------------------------------------- + + # gradients to prefix KV + for layer_idx, (k, v) in enumerate(prefix_kv): + if k.grad is not None: + self.grad_kv[0][layer_idx][:, :, :start, :] += k.grad + if v.grad is not None: + self.grad_kv[1][layer_idx][:, :, :start, :] += v.grad + + if start > 0: + if self.is_critic: + if attachs_in_block and pre_values.grad is not None: + self.grad_values[:start] += pre_values.grad + else: + # gradients to forkpos logits + if mid_logits.grad is not None: + if self.grad_forkpos_logits[start - 1] is None: + self.grad_forkpos_logits[start - 1] = mid_logits.grad.clone() + else: + self.grad_forkpos_logits[start - 1] += mid_logits.grad + if attachs_in_block: + # gradients to prefix logprobs & entropy + if pre_entropy.grad is not None: + self.grad_entropy[:start] += pre_entropy.grad + if start > 1 and pre_logprobs.grad is not None: + self.grad_logprobs[: start - 1] += pre_logprobs.grad + + # --------------------------------------------------------------------------------- + # 7. 
Cleanup: truncate and clear buffers + # --------------------------------------------------------------------------------- + + self.attachs = [ + (att, length) for att, length in self.attachs if length <= start + ] + + for layer_idx in range(self.n_layers): + self.grad_kv[0][layer_idx][:, :, start:end, :].zero_() + self.grad_kv[1][layer_idx][:, :, start:end, :].zero_() + + if self.is_critic: + self.grad_values[start:end].zero_() + else: + self.grad_logprobs[0 if start == 0 else start - 1 : end - 1].zero_() + self.grad_entropy[start:end].zero_() + + forkpos_slice = self.get_forkpos(start, end) + for i in forkpos_slice: + self.forkpos_logits[i] = None + self.grad_forkpos_logits[i] = None + + self.cur_len = start + + return loss.item() if attachs_in_block else 0.0 + + def pop_byblock(self, start: int, block_size: int, loss_fn) -> float: + """ + Pop tokens from [start, cur_len) in blocks to reduce peak GPU memory usage. + + Tokens are popped in reverse block order, calling `pop()` on each block. + + Args: + start: The starting token index to pop from. + block_size: Maximum block size for each pop to control memory usage. + loss_fn: Callable to compute loss for a sequence segment. + + Returns: + Total loss over all popped blocks. + """ + end = self.cur_len + length = end - start + n_blocks = ceil(length / block_size) + block_size_actual = ceil(length / n_blocks) + + loss = 0.0 + for b in range(n_blocks): + pop_start = max(end - (b + 1) * block_size_actual, start) + loss += self.pop(pop_start, loss_fn) + + return loss + + @torch.no_grad() + def forward(self, model, token_trie): + """ + Perform backward pass over all sequences in a TokenTrie. + Compute logprobs for each sequence. + The sequence ID is identified by attachment['_sequence_batch_id'], which TokenTrie automatically adds. + + Args: + token_trie: TokenTrie containing input sequences and attachs. + + Returns: + List of logprob tensors for each sequence in the TokenTrie. 
+ """ + + self.model = model + self.returns = [None] * token_trie.n_sequences + + inputs, attach_lists, lcp_lens = ( + token_trie.inputs, + token_trie.attach_lists, + token_trie.lcp_lens, + ) + + if not self.is_critic: + self.forkpos_list = _get_forkpos(None, lcp_lens, None) + + for i in range(len(inputs)): + input_ids = inputs[i].to(self.device) + attach_list = attach_lists[i] + _ = input_ids.size(0) + + # Pop diverged branch from previous sequence + if i > 0: + self.cur_len = lcp_lens[i - 1] + + # Push new tokens + new_tokens = input_ids[self.cur_len :] + + self.push_forward_only(new_tokens, attach_list) + + self.cur_len = 0 + if not self.is_critic: + self.forkpos_logits = [None] * self.max_seq_len # Clear forkpos_logits + + return self.returns + + def backward( + self, model, token_trie, loss_fn, block_size: int, cut_f1_tail: bool = True + ) -> float: + """ + Perform backward pass over all sequences in a TokenTrie. + + Processes sequences in lexicographic order, popping diverged + branches (block-wise) and pushing new tokens. + + Args: + token_trie: TokenTrie containing input sequences and attachs. + block_size: Maximum block size for popping to control GPU memory. + Use -1 for no block-size limit. + loss_fn: Callable to compute per-sequence loss. + cut_f1_tail: Whether to cut the tail of the first forward. + Returns: + Total loss accumulated over all sequences. 
+ """ + + self.model = model + if block_size == -1: + block_size = NO_BLOCK_SIZE_LIMIT + + total_loss = 0.0 + + inputs, attach_lists, lcp_lens = ( + token_trie.inputs, + token_trie.attach_lists, + token_trie.lcp_lens, + ) + + # Precompute fork positions and block boundaries + lens = [ids.size(0) for ids in inputs] + if not self.is_critic: + self.forkpos_list = _get_forkpos(lens, lcp_lens, block_size) + + # Process each sequence + for i in range(len(inputs)): + input_ids = inputs[i].to(self.device) + attach_list = attach_lists[i] + _ = input_ids.size(0) + + # Pop diverged branch from previous sequence + if i > 0: + lcp = lcp_lens[i - 1] + if lcp < self.cur_len: + total_loss += self.pop_byblock(lcp, block_size, loss_fn) + + # Push new tokens + new_tokens = input_ids[self.cur_len :] + + # Determine cache length to build (optimize for next pop) + lcp_next = lcp_lens[i] if i < len(inputs) - 1 else 0 + B = new_tokens.numel() + next_pop_len = self.cur_len + B - lcp_next + + if next_pop_len > block_size: + n_blocks = ceil(next_pop_len / block_size) + block_size_actual = ceil(next_pop_len / n_blocks) + cache_len = max(self.cur_len + B - block_size_actual, lcp_next) + else: + cache_len = lcp_next + + if not cut_f1_tail: + cache_len = self.cur_len + B + + self.push(new_tokens, attach_list, cache_len) + + # Final pop for remaining tokens + if self.cur_len > 0: + total_loss += self.pop_byblock(0, block_size, loss_fn) + + return total_loss diff --git a/areal/experimental/dta/token_trie.py b/areal/experimental/dta/token_trie.py new file mode 100644 index 0000000000..3d143e8c08 --- /dev/null +++ b/areal/experimental/dta/token_trie.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/token_trie.py. +# Special thanks to Yuchen Yang for outstanding contributions to the optimized DFS order. 
+ +import torch + +from areal.experimental.dta.trie import CompressedTrie, _get_stats + + +def _lcp_torch(a: torch.Tensor, b: torch.Tensor) -> int: + """Compute the length of the longest common prefix of two 1D tensors.""" + L = min(a.numel(), b.numel()) + eq = a[:L] == b[:L] + return L if eq.all() else int((~eq).to(torch.int32).argmax().item()) + + +def _leafization(input_ids: list[torch.LongTensor], attachs: list[dict]): + """ + Args: + input_ids: List of token tensors, sorted in lexicographic order. + attachs: List of dicts, each storing loss-related config for one token tensor. + + Merge fully overlapping prefixes and compute the `lcp_lens` list. + """ + + # Compute adjacent LCP lengths and validate lexicographic ordering. + lcp_lens = [] + for i in range(len(input_ids) - 1): + seq_L, seq_R = input_ids[i], input_ids[i + 1] + lcp = _lcp_torch(seq_L, seq_R) + L = min(seq_L.numel(), seq_R.numel()) + if lcp < L and seq_L[lcp] > seq_R[lcp]: + raise ValueError("input_ids not sorted in lexicographic order.") + lcp_lens.append(lcp) + + # Merge fully overlapping prefixes by keeping only the longest sequence. 
+ input_ids_leafed = [] + attach_lists = [] + lcp_lens_leafed = [] + + fork = -1 + for i in range(len(input_ids)): + if i == len(input_ids) - 1 or lcp_lens[i] < min( + input_ids[i].numel(), input_ids[i + 1].numel() + ): + input_ids_leafed.append(input_ids[i]) + if i < len(input_ids) - 1: + lcp_lens_leafed.append(lcp_lens[i]) + attach_list = [] + for k in range(fork + 1, i + 1): + attach_list.append((attachs[k], input_ids[k].numel())) + attach_lists.append(attach_list) + fork = i + + return input_ids_leafed, attach_lists, lcp_lens_leafed + + +class TokenTrie: + def __init__( + self, + inputs: list[torch.LongTensor], + attachs: list[dict] | None = None, + sorted: bool = False, + ): + if attachs is not None: + if len(inputs) != len(attachs): + raise ValueError("Length of inputs and attachs must match.") + else: + attachs = [{} for _ in range(len(inputs))] + + # Attach the original sequence index to each attachment dict. + for seq_id in range(len(inputs)): + attachs[seq_id]["_sequence_batch_id"] = seq_id + + # -------- sort by lexicographical order of input_ids -------- + if not sorted: + pairs = list(zip(inputs, attachs)) + pairs.sort(key=lambda x: x[0].tolist()) + inputs_sorted, attachs_sorted = [p[0] for p in pairs], [p[1] for p in pairs] + else: + inputs_sorted, attachs_sorted = inputs, attachs + + # -------- leafization -------- + self.inputs, self.attach_lists, self.lcp_lens = _leafization( + inputs_sorted, attachs_sorted + ) + self.lens = [len(ids) for ids in self.inputs] + + # -------- stats -------- + self.n_sequences = len(inputs) + self.n_tokens = sum(len(ids) for ids in inputs) + + def get_stats(self, mode: str, block_size: int | None = None): + stats = _get_stats(self.lens, self.lcp_lens, mode, block_size) + stats["n_sequences"] = self.n_sequences + stats["n_tokens"] = self.n_tokens + return stats + + def permute(self, order): + self.inputs = [self.inputs[i] for i in order] + self.attach_lists = [self.attach_lists[i] for i in order] + self.lens = 
[self.lens[i] for i in order] + self.lcp_lens = [ + _lcp_torch(self.inputs[i], self.inputs[i + 1]) + for i in range(len(self.inputs) - 1) + ] + + def forward_permute(self): + compressed_trie = CompressedTrie(self.lens, self.lcp_lens) + order, _, _ = compressed_trie.get_order_forward() + self.permute(order) + + def backward_permute(self): + compressed_trie = CompressedTrie(self.lens, self.lcp_lens) + order, _, _ = compressed_trie.get_order_backward() + self.permute(order) + + def random_permute(self): + compressed_trie = CompressedTrie(self.lens, self.lcp_lens) + order = compressed_trie.get_order_random() + self.permute(order) diff --git a/areal/experimental/dta/tree_time_model.py b/areal/experimental/dta/tree_time_model.py new file mode 100644 index 0000000000..d39c856a5f --- /dev/null +++ b/areal/experimental/dta/tree_time_model.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/tree_time_model.py. 
+import numpy as np +from scipy.optimize import nnls + + +class TreeTimeModel: + MIN_N_DATA_POINTS = 16 + MAX_N_DATA_POINTS = 1024 + + def __init__(self): + # T = c_0 * n_leaf_sequences + c_1 * n_tree_tokens + c_2 * n_f1_tokens + c_3 * sum_prefix_len + c_4 * sum_depth + self.coeffs = None + self.data = [] + + def fit(self): + X, Y = [], [] + for stats in self.data: + # X.append([0, stats["n_tree_tokens"], 0, 0, 0]) + X.append( + [ + stats["n_leaf_sequences"], + stats["n_tree_tokens"], + stats.get("n_f1_tokens", 0), + stats["sum_prefix_len"], + stats["sum_depth"], + ] + ) + Y.append(stats["time"]) + + X, Y = np.array(X), np.array(Y) + self.coeffs, _ = nnls(X, Y) + + T_pred = X @ self.coeffs + mse = np.mean((T_pred - Y) ** 2) + return mse + + def add_data(self, data): + self.data.extend(data) + if len(self.data) > self.MAX_N_DATA_POINTS: + self.data = self.data[-self.MAX_N_DATA_POINTS :] + if len(self.data) >= self.MIN_N_DATA_POINTS: + self.fit() + + def pred(self, stats): + if self.coeffs is None: + return stats["n_tree_tokens"] + return ( + self.coeffs[0] * stats["n_leaf_sequences"] + + self.coeffs[1] * stats["n_tree_tokens"] + + self.coeffs[2] * stats.get("n_f1_tokens", 0) + + self.coeffs[3] * stats["sum_prefix_len"] + + self.coeffs[4] * stats["sum_depth"] + ) diff --git a/areal/experimental/dta/trie.py b/areal/experimental/dta/trie.py new file mode 100644 index 0000000000..68968f8494 --- /dev/null +++ b/areal/experimental/dta/trie.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The following code is adapted with minor modifications from +# https://github.com/Whisper-6/DynamicTreeAttn/blob/main/trie.py. 
+import random +from dataclasses import dataclass, field +from math import ceil + + +def _get_stats( + lens: list[int], lcp_lens: list[int], mode: str, block_size: int | None = None +) -> dict: + n_tree_tokens = sum(lens) - sum(lcp_lens) + sum_depth = 0 + for i in range(len(lens)): + start = lcp_lens[i - 1] if i > 0 else 0 + end = lens[i] + sum_depth += (start + end - 1) * (end - start) // 2 + + if mode == "forward": + sum_prefix_len = sum(lcp_lens) + + return { + "n_leaf_sequences": len(lens), + "n_tree_tokens": n_tree_tokens, + "sum_prefix_len": sum_prefix_len, + "sum_depth": sum_depth, + } + + elif mode == "backward": + sum_prefix_len = 0 + n_f1_tokens = 0 + for i in range(len(lens)): + start = lcp_lens[i] if i < len(lcp_lens) else 0 + end = lens[i] + pop_len = end - start + f1_start = lcp_lens[i - 1] if i > 0 else 0 + + if block_size is None or pop_len <= block_size: + f1_end = start + sum_prefix_len += start + else: + n_blocks = ceil(pop_len / block_size) + block_size_actual = ceil(pop_len / n_blocks) + f1_end = end - block_size_actual + for b in range(n_blocks): + pop_start = max(end - (b + 1) * block_size_actual, start) + sum_prefix_len += pop_start + + n_f1_tokens += max(f1_end - f1_start, 0) + + return { + "n_leaf_sequences": len(lens), + "n_tree_tokens": n_tree_tokens, + "sum_prefix_len": sum_prefix_len, + "sum_depth": sum_depth, + "n_f1_tokens": n_f1_tokens, + } + + else: + raise ValueError(f"Unsupported mode: {mode}") + + +@dataclass(slots=True) +class CTNode: + """Node in the compressed trie.""" + + depth: int = 0 # Depth of this node. + seq_id: int = -1 # Sequence index; -1 indicates an internal node. + chain_tail_depth: int = 0 # Tail depth of the prioritized chain. + child_ids: list[int] = field(default_factory=list) # IDs of child nodes. + + +class CompressedTrie: + """Compressed trie used to plan traversal order.""" + + def __init__(self, lens: list[int], lcp_lens: list[int]): + """ + Initialize the compressed trie. 
+ + Args: + lens: Length of each sequence, sorted in lexicographic order. + lcp_lens: LCP length between adjacent sequences, where + len(lcp_lens) == max(len(lens) - 1, 0). An empty `lens` + produces a degenerate trie that contains only the root node. + """ + expected_lcp = max(len(lens) - 1, 0) + if len(lcp_lens) != expected_lcp: + raise ValueError( + f"len(lcp_lens) must be {expected_lcp}, got {len(lcp_lens)}" + ) + + self.nodes: list[CTNode] = [] # Stores all trie nodes. + self._build(lens, lcp_lens) + + self.lca_depth = None + self.order = None + self.lens = None + self.lcp_lens = None + + def _new_node(self, depth: int, seq_id: int = -1) -> int: + """Create a new node and return its ID.""" + self.nodes.append(CTNode(depth=depth, seq_id=seq_id)) + return len(self.nodes) - 1 + + def _build(self, lens: list[int], lcp_lens: list[int]): + """Build the compressed trie.""" + + n_seqs = len(lens) + # Create the root node. + root_id = self._new_node(depth=0, seq_id=-1) + stack = [(root_id, 0)] # Stack entries are (node_id, depth). + nodes = self.nodes + + for seq_id in range(n_seqs): + len_i = lens[seq_id] + lcp = lcp_lens[seq_id - 1] if seq_id > 0 else 0 + + if len(stack) >= 2: + while stack[-2][1] > lcp: + # Pop a child node and connect it to its parent. + child_id = stack.pop()[0] + parent_id = stack[-1][0] + nodes[parent_id].child_ids.append(child_id) + + child_id = stack.pop()[0] + if stack[-1][1] < lcp: + lcp_node_id = self._new_node(depth=lcp, seq_id=-1) + stack.append((lcp_node_id, lcp)) + parent_id = stack[-1][0] + nodes[parent_id].child_ids.append(child_id) + else: + if stack[-1][1] < lcp: + lcp_node_id = self._new_node(depth=lcp, seq_id=-1) + stack.append((lcp_node_id, lcp)) + + # Create a new leaf node. 
+ parent_id = stack[-1][0] + cur_node_id = self._new_node(depth=len_i, seq_id=seq_id) + stack.append((cur_node_id, len_i)) + + while len(stack) >= 2: + child_id = stack.pop()[0] + parent_id = stack[-1][0] + nodes[parent_id].child_ids.append(child_id) + + def dfs_chain(self, node_id: int, child_order_func) -> int: + """Compute `chain_tail_depth` for each node.""" + node = self.nodes[node_id] + + # Leaf node. + if node.seq_id != -1: + node.chain_tail_depth = node.depth + return + + for child_id in node.child_ids: + self.dfs_chain(child_id, child_order_func) + + child_ids = child_order_func(node_id) + if not child_ids: + # Only reachable for the root of an empty trie. The value never + # propagates anywhere since the subtree carries no leaves. + node.chain_tail_depth = node.depth + return + node.chain_tail_depth = self.nodes[child_ids[0]].chain_tail_depth + + def dfs_get_lens(self, node_id: int, seq_set: set[int]): + node = self.nodes[node_id] + + if node.seq_id != -1: + if node.seq_id in seq_set: + self.lens.append(node.depth) + self.lcp_lens.append(self.lca_depth) + self.lca_depth = node.depth + return + + for child_id in node.child_ids: + self.lca_depth = min(self.lca_depth, node.depth) + self.dfs_get_lens(child_id, seq_set) + + def get_lens(self, seq_set: set[int]): + self.lens = [] + self.lcp_lens = [] + self.lca_depth = 0 + self.dfs_get_lens(0, seq_set) + return self.lens, self.lcp_lens[1:] + + def dfs_get_order(self, node_id: int, child_order_func): + node = self.nodes[node_id] + + # Leaf node: record the sequence index. + if node.seq_id != -1: + self.order.append(node.seq_id) + self.lens.append(node.depth) + self.lcp_lens.append(self.lca_depth) + self.lca_depth = node.depth + return + + # Get child traversal order from the given strategy. + child_ids = child_order_func(node_id) + + # Recursively traverse children. 
+ for child_id in child_ids: + self.lca_depth = min(self.lca_depth, node.depth) + self.dfs_get_order(child_id, child_order_func) + + def _get_child_order_forward(self, node_id: int) -> list[int]: + node = self.nodes[node_id] + return sorted( + node.child_ids, key=lambda child_id: self.nodes[child_id].chain_tail_depth + ) + + def _get_child_order_backward(self, node_id: int) -> list[int]: + node = self.nodes[node_id] + return sorted( + node.child_ids, + key=lambda child_id: ( + 1 if self.nodes[child_id].child_ids else 0, + self.nodes[child_id].chain_tail_depth, + ), + ) + + def _get_child_order_random( + self, node_id: int, seed: int | None = None + ) -> list[int]: + node = self.nodes[node_id] + child_ids = node.child_ids.copy() + + if seed is not None: + local_random = random.Random(seed) + local_random.shuffle(child_ids) + else: + random.shuffle(child_ids) + + return child_ids + + def get_order(self, child_order_func): + """Get sequence order from DFS with a custom child-order strategy.""" + self.dfs_chain(0, child_order_func) + self.order = [] + self.lens = [] + self.lcp_lens = [] + self.lca_depth = 0 + self.dfs_get_order(0, child_order_func) + + def get_order_forward(self): + """Get sequence order from DFS using main-Ld-priority traversal.""" + self.get_order(self._get_child_order_forward) + return self.order, self.lens, self.lcp_lens[1:] + + def get_order_backward(self): + """Get sequence order from DFS for backward-style pop traversal.""" + self.get_order(self._get_child_order_backward) + return self.order[::-1], self.lens[::-1], self.lcp_lens[1:][::-1] + + def get_order_random(self, seed: int | None = None): + """Get sequence order from DFS after randomizing child edges.""" + self.get_order(lambda node_id: self._get_child_order_random(node_id, seed)) + return self.order + + +def _get_subtrie(trie, seq_set: set[int]) -> CompressedTrie: + lens, lcp_lens = trie.get_lens(seq_set) + return CompressedTrie(lens, lcp_lens) + + +# -------- Test -------- + + +def 
test_compressed_trie(): + lens1 = [5, 4, 3, 2] + lcp_lens1 = [3, 2, 1] + + trie1 = CompressedTrie(lens1, lcp_lens1) + + order, lens, lcp_lens = trie1.get_order_forward() + print(order, lens, lcp_lens) + + order, lens, lcp_lens = trie1.get_order_backward() + print(order, lens, lcp_lens) + + order, lens, lcp_lens = trie1.get_order_random() + print(order, lens, lcp_lens) + + +if __name__ == "__main__": + test_compressed_trie() diff --git a/areal/experimental/dta/wrapper.py b/areal/experimental/dta/wrapper.py new file mode 100644 index 0000000000..470d25ec66 --- /dev/null +++ b/areal/experimental/dta/wrapper.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Protocol + +import torch +from transformers import PretrainedConfig +from transformers.cache_utils import DynamicCache + +from areal.experimental.dta.dta_engine import DTAEngine +from areal.experimental.dta.token_trie import TokenTrie + + +class KVCacheModel(Protocol): + """Structural contract for DTA-compatible models.""" + + def forward( + self, + tokens: torch.LongTensor, + past_key_values: DynamicCache | None = None, + use_cache: bool = True, + ) -> SimpleNamespace: ... 
+ + +class DTAWrapper: + """Engine-agnostic facade for DTA forward/backward paths.""" + + def __init__( + self, + model: KVCacheModel, + model_config: PretrainedConfig, + device: torch.device, + dtype: torch.dtype, + max_seq_len: int, + block_size: int, + is_critic: bool = False, + ) -> None: + self.model = model + self.device = device + self.block_size = block_size + self.is_critic = is_critic + self._engine = DTAEngine( + model_config=model_config, + device=device, + dtype=dtype, + max_seq_len=max_seq_len, + is_critic=is_critic, + ) + + @torch.no_grad() + def run_forward(self, mb_list: Any) -> torch.Tensor: + input_ids_list = self._extract_input_ids_list_from_mb_list(mb_list) + max_seq_len = max((ids.numel() for ids in input_ids_list), default=0) + input_data = [{} for _ in input_ids_list] + trie = TokenTrie(input_ids_list, input_data, sorted=False) + trie.forward_permute() + + output = self._engine.forward(model=self.model, token_trie=trie) + batch_size = len(output) + if batch_size == 0: + return torch.zeros((0, 0), dtype=torch.float32, device=self.device) + output_padded = torch.zeros( + (batch_size, max_seq_len), + dtype=output[0].dtype, + device=output[0].device, + ) + for i, seq in enumerate(output): + seq_len = seq.shape[0] + output_padded[i, :seq_len] = seq + return output_padded + + @staticmethod + def _extract_input_ids(mb_input: dict[str, Any]) -> torch.Tensor: + if "input_ids" not in mb_input: + raise ValueError("DTA expects `input_ids` in micro-batch input.") + input_ids = mb_input["input_ids"] + if not torch.is_tensor(input_ids) or input_ids.ndim != 1: + raise ValueError( + "DTA expects packed 1D `input_ids` in micro-batch input, " + f"got {type(input_ids)} with ndim=" + f"{getattr(input_ids, 'ndim', 'N/A')}." 
+ ) + return input_ids + + def _extract_input_ids_list_from_mb_list(self, mb_list: Any) -> list[torch.Tensor]: + input_ids_list: list[torch.Tensor] = [] + for mb_item in mb_list: + input_ids_list.append(self._extract_input_ids(mb_item.orig_mb)) + return input_ids_list + + def run_backward_with_scaled_loss( + self, + mb_list: Any, + prepare_mb_inputs_fn: Any, + loss_fn: Any, + loss_weight_fn: Any, + total_loss_weight: torch.Tensor, + block_size: int | None = None, + ) -> dict[str, float]: + input_ids_list = self._extract_input_ids_list_from_mb_list(mb_list) + per_seq_input_data: list[dict[str, Any]] = [] + for idx, mb_item in enumerate(mb_list): + _, ctx = prepare_mb_inputs_fn(mb_item) + mb_input = ctx.mb_input + # Keep backward input source aligned with forward input source. + self._extract_input_ids(mb_input) + if mb_input["input_ids"].shape != input_ids_list[idx].shape: + raise ValueError( + "DTA expects `ctx.mb_input['input_ids']` to align with " + "`mb_item.orig_mb['input_ids']` for each micro-batch." + ) + loss_scale = loss_weight_fn(ctx.mb_input) / total_loss_weight + if isinstance(loss_scale, torch.Tensor): + loss_scale = loss_scale.item() + per_seq_input_data.append({"original": mb_input, "scale": loss_scale}) + + if self.is_critic: + + def scaled_loss_fn( + values: torch.Tensor, + seq_input_data: dict[str, Any], + **extra_kwargs: Any, + ) -> torch.Tensor: + loss_val = loss_fn( + values, + seq_input_data["original"], + **extra_kwargs, + ) + return loss_val * seq_input_data["scale"] + else: + + def scaled_loss_fn( + logprobs: torch.Tensor, + entropy: torch.Tensor, + seq_input_data: dict[str, Any], + **extra_kwargs: Any, + ) -> torch.Tensor: + # Keep current behavior: DTA engine expects one extra position. 
+ logprobs = torch.cat([logprobs, logprobs.new_zeros(1)], dim=0) + loss_val = loss_fn( + logprobs, + entropy, + seq_input_data["original"], + **extra_kwargs, + ) + return loss_val * seq_input_data["scale"] + + trie = TokenTrie(input_ids_list, per_seq_input_data, sorted=False) + trie.backward_permute() + + total_loss = self._engine.backward( + model=self.model, + token_trie=trie, + block_size=block_size or self.block_size, + loss_fn=scaled_loss_fn, + ) + return {"dta_loss": float(total_loss)} diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 98e2645a2e..21098d09ec 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -45,6 +45,7 @@ reorder_and_pad_outputs, ) from areal.engine.fsdp_utils.grad import fsdp2_clip_grad_norm +from areal.experimental.dta.wrapper import DTAWrapper from areal.experimental.engine.archon_checkpoint import ( load_from_dcp, load_model_from_hf, @@ -65,6 +66,12 @@ update_weights_from_disk, update_weights_from_distributed, ) +from areal.experimental.engine.archon_zero1 import ( + all_reduce_zero1_gradients, + create_zero1_optimizer, + parallelize_fn_zero1, + zero1_clip_grad_norm, +) from areal.experimental.models.archon import ( ArchonParallelDims, BaseStateDictAdapter, @@ -151,8 +158,14 @@ def __init__(self, config: TrainEngineConfig): # Configuration (immutable after init) self.config = config self.optimizer_config = config.optimizer - self.enable_tree_training = config.enable_tree_training - + self.tree_training_mode = config.tree_training_mode + if self.tree_training_mode == "dta" and config.gradient_checkpointing: + raise ValueError( + "ArchonEngine: gradient_checkpointing=True is incompatible with " + "tree_training_mode='dta'. Disable gradient_checkpointing for DTA." 
+ ) + if self.tree_training_mode == "dta": + self.dta_wrapper: DTAWrapper # Model Configuration (loaded during __init__) self.model_config: PretrainedConfig = AutoConfig.from_pretrained( pretrained_model_name_or_path=self.config.path, @@ -343,7 +356,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): config=self.config, parallel_dims=self.parallel_dims, model_config=self.model_config, - enable_tree_training=self.enable_tree_training, + tree_training_mode=self.tree_training_mode, logger=self.logger, ) @@ -383,6 +396,19 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): has_last_stage=self.pp_has_last_stage, ) + if self.tree_training_mode == "dta": + dta_dtype = getattr(torch, self.config.dtype) + self.dta_wrapper = DTAWrapper( + model=self.model, + model_config=self.model_config, + device=self.device, + dtype=dta_dtype, + max_seq_len=self.config.mb_spec.max_tokens_per_mb, + block_size=self.config.dta_block_size, + is_critic=self.config.is_critic, + ) + self.logger.info(f"DTA Wrapper created on device {self.device}") + self._initialized = True @property @@ -482,16 +508,23 @@ def optimizer_step(self): assert self.optimizer_config is not None assert self.lr_scheduler is not None - grad_norm = fsdp2_clip_grad_norm( - self._get_all_parameters(), - max_norm=self.optimizer_config.gradient_clipping, - fsdp_group=self.data_parallel_group, - tp_group=self._tp_group, - pp_group=self.parallel_dims.get_group("pp") - if self.parallel_dims.pp_enabled - else None, - offload_params=self.config.archon.offload_params, - ) + if self.tree_training_mode == "dta": + grad_norm = zero1_clip_grad_norm( + self._get_all_parameters(), + max_norm=self.optimizer_config.gradient_clipping, + dp_group=self.data_parallel_group, + ) + else: + grad_norm = fsdp2_clip_grad_norm( + self._get_all_parameters(), + max_norm=self.optimizer_config.gradient_clipping, + fsdp_group=self.data_parallel_group, + tp_group=self._tp_group, + 
pp_group=self.parallel_dims.get_group("pp") + if self.parallel_dims.pp_enabled + else None, + offload_params=self.config.archon.offload_params, + ) if not math.isfinite(grad_norm): self.optimizer_zero_grad() @@ -527,6 +560,7 @@ def train_batch( input_: list[dict[str, Any]] | dict[str, Any], loss_fn: Callable[..., torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + return_loss: bool = False, ) -> dict[str, float]: """Train on a batch of data.""" assert self._initialized @@ -540,11 +574,36 @@ def train_batch( mb_list, loss_weight_fn, self.data_parallel_group ) + if self.tree_training_mode == "dta": + # ========== DTA Path ========== + self.logger.info("tree_training_mode='dta' in train_batch") + self.logger.info(f"total_loss_weight: {total_loss_weight}") + dta_stats = self.dta_wrapper.run_backward_with_scaled_loss( + mb_list=mb_list, + prepare_mb_inputs_fn=self._prepare_mb_inputs, + loss_fn=loss_fn, + loss_weight_fn=loss_weight_fn, + total_loss_weight=total_loss_weight, + block_size=self.config.dta_block_size, + ) + self.logger.info(f"DTA backward stats: {dta_stats}") + all_reduce_zero1_gradients( + self._get_all_parameters(), + dp_group=self.data_parallel_group, + ) + result = self.optimizer_step() + if return_loss: + dta_loss = float(dta_stats.get("dta_loss", float("nan"))) + result["loss"] = dta_loss + return result + + losses: list[torch.Tensor] = [] + def process_output( logits: torch.Tensor, ctx_dict: dict[str, Any] ) -> torch.Tensor: ctx = ArchonTrainContext(**ctx_dict) - return self._compute_logprobs_and_loss( + loss = self._compute_logprobs_and_loss( logits, ctx, loss_fn, @@ -552,12 +611,28 @@ def process_output( total_loss_weight, loss_multiplier=self.data_parallel_world_size, ) + if return_loss: + losses.append(loss.detach()) + return loss self.forward_backward_batch(mb_list, process_output, forward_only=False) - stats = self.optimizer_step() - stats["num_micro_batches"] = len(mb_list.mbs) - return stats + result = 
self.optimizer_step() + result["num_micro_batches"] = len(mb_list.mbs) + if return_loss: + if losses: + # Non-DTA path stores per-microbatch scaled loss: + # loss_i * (w_i / W_total) * dp_world_size + # Summing over microbatches then dividing by dp_world_size aligns + # with DTA's returned objective: + # sum_i loss_i * (w_i / W_total) + local_loss = float(torch.stack(losses).sum().item()) / float( + self.data_parallel_world_size + ) + else: + local_loss = float("nan") + result["loss"] = local_loss + return result @torch.no_grad() def eval_batch( @@ -630,20 +705,31 @@ def forward_batch( batch_size = len(output_seqlens) mb_list = self._prepare_mb_list(input_batched).to(self.device) + if self.tree_training_mode == "dta": + self.logger.info("tree_training_mode='dta' in forward_batch") + dta_out = self.dta_wrapper.run_forward(mb_list=mb_list) + # DTA outputs already follow forward micro-batch order. + # Align seqlens to the same order and let reorder_and_pad_outputs + # restore original batch order via backward_indices. 
+ seqlens_fwd = [output_seqlens[i] for i in mb_list.forward_indices] + outputs = [ + dta_out[i, : int(seqlens_fwd[i])] for i in range(len(seqlens_fwd)) + ] + else: - def process_output( - logits: torch.Tensor, ctx_dict: dict[str, Any] - ) -> torch.Tensor: - ctx = ArchonTrainContext(**ctx_dict) - return self._compute_forward_result(logits, ctx) + def process_output( + logits: torch.Tensor, ctx_dict: dict[str, Any] + ) -> torch.Tensor: + ctx = ArchonTrainContext(**ctx_dict) + return self._compute_forward_result(logits, ctx) - outputs = self.forward_backward_batch( - mb_list, process_output, forward_only=True - ) + outputs = self.forward_backward_batch( + mb_list, process_output, forward_only=True + ) if self.pp_has_last_stage: assert outputs is not None - if self.enable_tree_training: + if self.tree_training_mode == "sparse": res = merge_packed_tree_results(outputs, batch_size) else: res = reorder_and_pad_outputs( @@ -948,6 +1034,11 @@ def _apply_parallelism( enable_compile: bool, ) -> None: """Apply parallelism using parallelize_fn.""" + if self.tree_training_mode == "dta": + self.model = parallelize_fn_zero1(self.model) + self.model_parts = [self.model] + return + self.spec.parallelize_fn( model=self.model, parallel_dims=self.parallel_dims, @@ -971,7 +1062,7 @@ def _prepare_mb_inputs( # Tree training: labels are derived from trie structure, not torch.roll. # (Tree input_ids is 1D packed format, so roll would be wrong anyway.) 
- if self.enable_tree_training: + if self.tree_training_mode == "sparse": assert trie_node is not None ctx = ArchonTrainContext( mb_input=mb_item.orig_mb, @@ -1095,7 +1186,7 @@ def _create_model_structure(self) -> nn.Module: """Create model structure on meta device without loading weights.""" # Use tree attention type when tree training is enabled attn_type = self.config.archon.attn_type - if self.enable_tree_training: + if self.tree_training_mode == "sparse": if attn_type != "tree": self.logger.warning( f"Tree training enabled, overriding attn_type '{self.config.archon.attn_type}' -> 'tree'" @@ -1140,6 +1231,22 @@ def _materialize_and_load_weights(self): for model in self.model_parts: model.init_weights() + # DTA-only: parallelize_fn is identity, so no tie in parallelize_*; checkpoint + # load materializes separate tensors per key. Re-bind after load so embed and + # lm_head share storage (disabled path unchanged: tie remains in parallelize_*). + if self.tree_training_mode == "dta": + for model in self.model_parts: + if ( + model.model_args.enable_weight_tying + and model.output is not None + and model.tok_embeddings is not None + ): + model.output.weight = model.tok_embeddings.weight + self.logger.info( + "DTA: applied weight tying (output.weight = tok_embeddings.weight) " + "after loading weights" + ) + for model in self.model_parts: model.init_buffers(buffer_device=buffer_device) @@ -1156,9 +1263,16 @@ def _create_optimizer(self, ft_spec: FinetuneSpec): tik = time.perf_counter() - self.optimizer = create_optimizer( - self._get_all_parameters(), self.optimizer_config - ) + if self.tree_training_mode == "dta": + self.optimizer = create_zero1_optimizer( + self._get_all_parameters(), + self.optimizer_config, + self.data_parallel_group, + ) + else: + self.optimizer = create_optimizer( + self._get_all_parameters(), self.optimizer_config + ) self.lr_scheduler = create_lr_scheduler( self.optimizer, self.optimizer_config, ft_spec.total_train_steps ) @@ -1179,7 +1293,7 
@@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: # Tree training path # Note: CP/PP incompatibility is validated in initialize(). - if self.enable_tree_training: + if self.tree_training_mode == "sparse": mb_list = build_packed_tree_batch( input_, mb_spec=self.config.mb_spec, @@ -1195,29 +1309,50 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: input_ = amend_position_ids(input_) - # Pipeline parallelism requires n_microbatches >= num_total_stages - if self.parallel_dims.pp_enabled: - pp_size = self.parallel_dims.pp - stages_per_rank = len(self.pp_stages) - num_total_stages = pp_size * stages_per_rank - n_seqs = input_["attention_mask"].shape[0] - if n_seqs < num_total_stages: - raise RuntimeError( - f"Pipeline parallelism requires at least {num_total_stages} " - f"sequences (pp_size={pp_size} * stages_per_rank=" - f"{stages_per_rank}), but got {n_seqs}. " - f"Increase batch size or reduce PP degree/stages." - ) - min_n_mbs = num_total_stages + if self.tree_training_mode == "dta": + # DTA uses one sequence per microbatch for sequence-level loss via DTAEngine. + # PP/CP incompatibility is validated in initialize(). + n_seqs = input_["input_ids"].shape[0] mb_spec = MicroBatchSpec.new( self.config.mb_spec, - n_mbs=max(min_n_mbs, self.config.mb_spec.n_mbs or 1), - n_mbs_divisor=pp_size, + n_mbs=n_seqs, + granularity=1, + max_tokens_per_mb=self.config.mb_spec.max_tokens_per_mb, + ) + # Keep DTA per-rank independent: one sequence per microbatch, no + # cross-rank synced microbatch-count alignment. + mb_list = split_padded_tensor_dict_into_mb_list( + input_, mb_spec, one_seq_per_mb=True + ) + assert len(mb_list.mbs) == n_seqs, ( + f"DTA requires one microbatch per sequence, " + f"expected {n_seqs} microbatches but got {len(mb_list.mbs)}." ) else: - mb_spec = self.config.mb_spec + # Pipeline parallelism requires n_microbatches >= num_total_stages. + # DTA path above is PP-incompatible and therefore bypasses this branch. 
+ if self.parallel_dims.pp_enabled: + pp_size = self.parallel_dims.pp + stages_per_rank = len(self.pp_stages) + num_total_stages = pp_size * stages_per_rank + n_seqs = input_["attention_mask"].shape[0] + if n_seqs < num_total_stages: + raise RuntimeError( + f"Pipeline parallelism requires at least {num_total_stages} " + f"sequences (pp_size={pp_size} * stages_per_rank=" + f"{stages_per_rank}), but got {n_seqs}. " + f"Increase batch size or reduce PP degree/stages." + ) + min_n_mbs = num_total_stages + mb_spec = MicroBatchSpec.new( + self.config.mb_spec, + n_mbs=max(min_n_mbs, self.config.mb_spec.n_mbs or 1), + n_mbs_divisor=pp_size, + ) + else: + mb_spec = self.config.mb_spec + mb_list = split_padded_tensor_dict_into_mb_list(input_, mb_spec) - mb_list = split_padded_tensor_dict_into_mb_list(input_, mb_spec) mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs] # LCM ensures page-aligned memory and exact CP slicing without extra padding. @@ -1304,7 +1439,7 @@ def _gather_actor_train_outputs( ctx: ArchonTrainContext, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None: """Compute (logprobs, entropy, vocab_min, vocab_max) for actor training.""" - if self.enable_tree_training: + if self.tree_training_mode == "sparse": # Handle dummy trie (empty tree for DP synchronization) if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids: return None @@ -1347,7 +1482,7 @@ def _gather_actor_forward_output( ctx: ArchonTrainContext, ) -> torch.Tensor | dict[int, torch.Tensor]: """Compute actor logprobs for forward-only path.""" - if self.enable_tree_training: + if self.tree_training_mode == "sparse": assert ctx.trie_node is not None return _gather_packed_tree_logprobs( logits, diff --git a/areal/experimental/engine/archon_utils.py b/areal/experimental/engine/archon_utils.py index cf9734587f..03e4a838da 100644 --- a/areal/experimental/engine/archon_utils.py +++ b/areal/experimental/engine/archon_utils.py @@ -258,11 +258,11 @@ def 
force_pad_to_maximum( config: TrainEngineConfig, parallel_dims: ArchonParallelDims, enable_compile: bool, - enable_tree_training: bool, + tree_training_mode: str, logger: logging.Logger, ) -> None: """Force ``config.pad_to_maximum = True`` when compile, PP, or tree training - requires it. Also validates tree training constraints. + requires it. Also validates tree training / DTA constraints. """ # Force pad_to_maximum when compile is enabled to avoid dynamic shape issues if enable_compile and not config.pad_to_maximum: @@ -280,8 +280,8 @@ def force_pad_to_maximum( ) config.pad_to_maximum = True - # Tree training constraints - if enable_tree_training: + # Sparse tree training constraints + if tree_training_mode == "sparse": if config.is_critic: raise NotImplementedError( "Tree training with critic model is not supported yet." @@ -299,6 +299,23 @@ def force_pad_to_maximum( ) config.pad_to_maximum = True + # DTA constraints + if tree_training_mode == "dta": + if ( + parallel_dims.pp_enabled + or parallel_dims.cp_enabled + or parallel_dims.tp_enabled + or parallel_dims.ep_enabled + or parallel_dims.etp_enabled + ): + raise ValueError( + "DTA currently supports only data parallelism. " + "Found unsupported parallel dimensions enabled among " + "{pp, cp, tp, ep, etp}. " + f"Current sizes: pp={parallel_dims.pp}, cp={parallel_dims.cp}, " + f"tp={parallel_dims.tp}, ep={parallel_dims.ep}, etp={parallel_dims.etp}." + ) + # ========================================================================= # Combined Config Preparation @@ -309,7 +326,7 @@ def prepare_training_config( config: TrainEngineConfig, parallel_dims: ArchonParallelDims, model_config: PretrainedConfig, - enable_tree_training: bool, + tree_training_mode: str, logger: logging.Logger, ) -> tuple[ActivationCheckpointConfig | None, bool]: """Build and validate all training configs before parallelism setup. 
@@ -345,7 +362,7 @@ def prepare_training_config( config=config, parallel_dims=parallel_dims, enable_compile=enable_compile, - enable_tree_training=enable_tree_training, + tree_training_mode=tree_training_mode, logger=logger, ) diff --git a/areal/experimental/engine/archon_zero1.py b/areal/experimental/engine/archon_zero1.py new file mode 100644 index 0000000000..40016932d1 --- /dev/null +++ b/areal/experimental/engine/archon_zero1.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.optim import ZeroRedundancyOptimizer + +from areal.api.cli_args import OptimizerConfig + + +def parallelize_fn_zero1( + model: nn.Module, +) -> nn.Module: + """Zero-1 path keeps full parameter replicas without model wrapper.""" + return model + + +def create_zero1_optimizer( + params: list[nn.Parameter], + optimizer_config: OptimizerConfig, + dp_group: dist.ProcessGroup, +) -> ZeroRedundancyOptimizer: + """Create ZeroRedundancyOptimizer from optimizer config.""" + common_kwargs: dict[str, object] = { + "lr": optimizer_config.lr, + "weight_decay": optimizer_config.weight_decay, + } + if optimizer_config.type == "adam": + return ZeroRedundancyOptimizer( + params, + optimizer_class=torch.optim.AdamW, + process_group=dp_group, + betas=(optimizer_config.beta1, optimizer_config.beta2), + eps=optimizer_config.eps, + fused=True, + **common_kwargs, + ) + if optimizer_config.type == "sgd": + return ZeroRedundancyOptimizer( + params, + optimizer_class=torch.optim.SGD, + process_group=dp_group, + **common_kwargs, + ) + raise ValueError(f"Unsupported optimizer type for Zero1: {optimizer_config.type}") + + +def zero1_clip_grad_norm( + parameters: list[nn.Parameter], + max_norm: float, + dp_group: dist.ProcessGroup, + eps: float = 1e-6, +) -> float: + """Clip gradients by global norm across DP ranks.""" + grads = [p.grad for p in parameters if 
p.grad is not None] + if not grads: + return 0.0 + + device = grads[0].device + total_sq = torch.zeros((), device=device, dtype=torch.float32) + for grad in grads: + total_sq += grad.detach().float().pow(2).sum() + + total_norm = total_sq.sqrt() + total_norm_value = float(total_norm) + if not math.isfinite(total_norm_value): + return total_norm_value + + clip_coef = (max_norm / (total_norm + eps)).clamp(max=1.0) + for grad in grads: + grad.mul_(clip_coef.to(device=grad.device, dtype=grad.dtype)) + return total_norm_value + + +def all_reduce_zero1_gradients( + parameters: list[nn.Parameter], + dp_group: dist.ProcessGroup, +) -> None: + """Synchronize gradients across DP ranks for Zero-1 training.""" + for parameter in parameters: + if parameter.grad is None: + continue + dist.all_reduce(parameter.grad, group=dp_group) diff --git a/areal/experimental/models/archon/attention/sdpa.py b/areal/experimental/models/archon/attention/sdpa.py index 96588f5bae..c1b5d8cc36 100644 --- a/areal/experimental/models/archon/attention/sdpa.py +++ b/areal/experimental/models/archon/attention/sdpa.py @@ -16,8 +16,10 @@ def create_block_causal_mask_2d( - cu_seqlens: torch.Tensor, - seq_len: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + q_len: int, + k_len: int, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: @@ -28,35 +30,71 @@ def create_block_causal_mask_2d( - Across sequences: no attention allowed Args: - cu_seqlens: Cumulative sequence lengths, shape [num_seqs + 1]. - For example, [0, 3, 5, 7] means 3 sequences with lengths 3, 2, 2. - seq_len: Total sequence length (should equal cu_seqlens[-1]). + cu_seqlens_q: Query cumulative sequence lengths, shape [num_seqs + 1]. + cu_seqlens_k: Key cumulative sequence lengths, shape [num_seqs + 1]. + For self-attention, ``cu_seqlens_q == cu_seqlens_k``. + For KV-cache attention, K can be longer than Q for each sequence. + q_len: Total number of query tokens. + k_len: Total number of key/value tokens. 
device: Target device for the mask tensor. dtype: Target dtype (float mask with 0.0 and -inf). Returns: - Attention mask of shape [seq_len, seq_len]. + Attention mask of shape [q_len, k_len]. 0.0 = attend, -inf = mask out. - Example for cu_seqlens=[0, 3, 5, 7]:: - - [ 0, -inf, -inf, | -inf, -inf, | -inf, -inf] - [ 0, 0 , -inf, | -inf, -inf, | -inf, -inf] - [ 0, 0 , 0 , | -inf, -inf, | -inf, -inf] - [-inf, -inf, -inf, | 0, -inf, | -inf, -inf] - [-inf, -inf, -inf, | 0, 0 , | -inf, -inf] - [-inf, -inf, -inf, | -inf, -inf, | 0, -inf] - [-inf, -inf, -inf, | -inf, -inf, | 0, 0 ] + Examples: + 1) Standard packed self-attention (q_len == k_len):: + + cu_q = cu_k = [0, 3] + # sequence length = 3 + # allowed key positions per query row: + # q0 -> [0] + # q1 -> [0, 1] + # q2 -> [0, 1, 2] + + 2) Right-aligned KV-cache attention (q_len < k_len):: + + cu_q = [0, 4] + cu_k = [0, 6] + # q tokens correspond to the last 4 positions in key timeline. + # Local alignment offset = k_seq_len - q_seq_len = 2 + # allowed key positions per query row: + # q0 -> [0, 1, 2] + # q1 -> [0, 1, 2, 3] + # q2 -> [0, 1, 2, 3, 4] + # q3 -> [0, 1, 2, 3, 4, 5] """ - positions = torch.arange(seq_len, device=device) - seq_ids = torch.searchsorted(cu_seqlens, positions, side="right") - 1 + if cu_seqlens_q.numel() != cu_seqlens_k.numel(): + raise ValueError( + "cu_seqlens_q and cu_seqlens_k must have same number of sequences, " + f"got {cu_seqlens_q.numel()} vs {cu_seqlens_k.numel()}." + ) + + q_positions = torch.arange(q_len, device=device) + k_positions = torch.arange(k_len, device=device) + cu_q = cu_seqlens_q.to(device) + cu_k = cu_seqlens_k.to(device) + q_seq_ids = torch.searchsorted(cu_q, q_positions, side="right") - 1 + k_seq_ids = torch.searchsorted(cu_k, k_positions, side="right") - 1 + + # Query/key must belong to the same packed sequence. 
+ same_seq = q_seq_ids.unsqueeze(1) == k_seq_ids.unsqueeze(0) - # same_seq: query and key must be in the same sequence - # causal: key position <= query position - same_seq = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(0) - causal = positions.unsqueeze(0) <= positions.unsqueeze(1) + # Sequence-local token indices. + q_local = q_positions - cu_q[q_seq_ids] + k_local = k_positions - cu_k[k_seq_ids] - mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype) + # Right-align Q inside K for KV-cache style attention: + # q_abs = (k_seq_len - q_seq_len) + q_local. + q_seq_lens = cu_q[q_seq_ids + 1] - cu_q[q_seq_ids] + k_seq_lens = cu_k[q_seq_ids + 1] - cu_k[q_seq_ids] + right_offset = (k_seq_lens - q_seq_lens).clamp(min=0) + + # Causal condition: key local index <= aligned query absolute index. + causal = k_local.unsqueeze(0) <= (q_local + right_offset).unsqueeze(1) + + mask = torch.full((q_len, k_len), float("-inf"), device=device, dtype=dtype) mask = mask.masked_fill(same_seq & causal, 0.0) return mask @@ -93,6 +131,7 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + cu_seqlens_k: torch.Tensor | None = None, ) -> torch.Tensor: """Compute attention with block-diagonal causal mask. @@ -101,17 +140,29 @@ def forward( k: Key tensor, shape [batch, heads, seq_len, head_dim] v: Value tensor, shape [batch, heads, seq_len, head_dim] scale: Optional scale factor for attention scores. - cu_seqlens: Cumulative sequence lengths, shape [num_seqs + 1]. + cu_seqlens: Query cumulative sequence lengths, shape [num_seqs + 1]. max_seqlen: Maximum sequence length (unused, for API compatibility). tree_attn_meta: Unused. Accepted for interface compatibility with TreeAttentionWrapper. + cu_seqlens_k: Optional key cumulative sequence lengths. If not set, + defaults to ``cu_seqlens``. 
Returns: Attention output, shape [batch, heads, seq_len, head_dim] """ - seq_len = q.shape[2] + q_len = q.shape[2] + k_len = k.shape[2] + if cu_seqlens_k is None: + cu_seqlens_k = cu_seqlens # TODO: Mask should be precomputed and passed in, not computed here. - attn_mask = create_block_causal_mask_2d(cu_seqlens, seq_len, q.device, q.dtype) + attn_mask = create_block_causal_mask_2d( + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens_k, + q_len=q_len, + k_len=k_len, + device=q.device, + dtype=q.dtype, + ) with sdpa_kernel(self.sdpa_backends, set_priority=True): return F.scaled_dot_product_attention( diff --git a/areal/experimental/models/archon/attention/varlen.py b/areal/experimental/models/archon/attention/varlen.py index 75fad6ac6b..61c5cb9c40 100644 --- a/areal/experimental/models/archon/attention/varlen.py +++ b/areal/experimental/models/archon/attention/varlen.py @@ -276,6 +276,7 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + cu_seqlens_k: torch.Tensor | None = None, ) -> torch.Tensor: """Compute attention with varlen_attn. 
@@ -308,15 +309,20 @@ def forward( v_3d = v.squeeze(0).transpose(0, 1).contiguous() # Ensure cu_seqlens is int32 (required by flash_attn) - cu_seqlens_i32 = cu_seqlens.to(torch.int32) + cu_seqlens_q_i32 = cu_seqlens.to(torch.int32) + + if cu_seqlens_k is None: + cu_seqlens_k_i32 = cu_seqlens_q_i32 + else: + cu_seqlens_k_i32 = cu_seqlens_k.to(torch.int32) # Call varlen_attn (self-attention: q and k have same cu_seqlens) out = varlen_attn( q_3d, k_3d, v_3d, - cu_seqlens_i32, - cu_seqlens_i32, + cu_seqlens_q_i32, + cu_seqlens_k_i32, max_seqlen, max_seqlen, is_causal=True, diff --git a/areal/experimental/models/archon/qwen2/model/model.py b/areal/experimental/models/archon/qwen2/model/model.py index c8a36f884f..8ce890c44d 100644 --- a/areal/experimental/models/archon/qwen2/model/model.py +++ b/areal/experimental/models/archon/qwen2/model/model.py @@ -4,11 +4,14 @@ from __future__ import annotations +from types import SimpleNamespace + import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.distributed import ProcessGroup +from transformers.cache_utils import DynamicCache from areal.experimental.models.archon.attention import ( SDPAWrapper, @@ -131,7 +134,9 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_key_values: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: bs, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) @@ -177,6 +182,21 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) + cu_seqlens_k = None + # Preserve per-step KV for cache update. Attention may still consume + # concatenated (past + current) KV when past_key_values is provided. 
+ kv_step = (xk, xv) + # KV cache for attention compute path: concat past K/V with newly computed K/V + if past_key_values is not None: + past_k, past_v = past_key_values + xk = torch.cat([past_k, xk], dim=2) + xv = torch.cat([past_v, xv], dim=2) + cu_seqlens_k = cu_seqlens.clone() + cu_seqlens_k += past_k.shape[2] + cu_seqlens_k[0] = 0 + + new_kv = kv_step if use_cache else None + output = self.packed_attn( xq, xk, @@ -185,6 +205,7 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, tree_attn_meta=tree_attn_meta, + cu_seqlens_k=cu_seqlens_k, ) output = output.transpose(1, 2).contiguous() @@ -197,7 +218,11 @@ def forward( seqlen = output.shape[1] output = output.view(bs, seqlen, -1) - return self.wo(output) + output = self.wo(output) + + if use_cache: + return output, new_kv + return output class FeedForward(nn.Module): @@ -246,16 +271,28 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: - x = x + self.attention( + past_key_values: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + attn_out = self.attention( self.attention_norm(x), rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_key_values=past_key_values, + use_cache=use_cache, ) + new_kv: tuple[torch.Tensor, torch.Tensor] | None = None + if use_cache: + assert isinstance(attn_out, tuple) + attn_out, new_kv = attn_out + x = x + attn_out x = x + self.feed_forward(self.ffn_norm(x)) + if use_cache: + assert new_kv is not None + return x, new_kv return x def init_weights(self): @@ -342,11 +379,38 @@ def init_buffers(self, buffer_device: torch.device | str): def forward( self, tokens: torch.Tensor, - positions: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int | torch.Tensor, + positions: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | 
torch.Tensor | None = None, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | SimpleNamespace: + if past_key_values is not None: + if ( + positions is not None + or cu_seqlens is not None + or max_seqlen is not None + ): + raise ValueError( + "When past_key_values is provided, positions/cu_seqlens/max_seqlen " + "must be None and are inferred internally." + ) + past_len = 0 + if len(past_key_values.layers) > 0: + past_len = int(past_key_values.layers[0].keys.shape[2]) + seq_len = tokens.shape[1] + positions = torch.arange( + past_len, + past_len + seq_len, + dtype=torch.long, + device=tokens.device, + ).unsqueeze(0) + cu_seqlens = torch.tensor( + [0, tokens.shape[1]], dtype=torch.int32, device=tokens.device + ) + max_seqlen = int(tokens.shape[1]) + past_len + # When pipeline parallelism enabled, cu_seqlens is [1, B+1] if cu_seqlens.ndim == 2: cu_seqlens = cu_seqlens.squeeze(0) @@ -357,15 +421,34 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - for layer in self.layers.values(): - h = layer( + if use_cache: + if past_key_values is not None: + next_cache = past_key_values + else: + next_cache = DynamicCache() + for layer_idx, layer in enumerate(self.layers.values()): + layer_past = None + if past_key_values is not None and layer_idx < len(past_key_values.layers): + layer_entry = past_key_values.layers[layer_idx] + layer_past = (layer_entry.keys, layer_entry.values) + + layer_out = layer( h, self.rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_key_values=layer_past, + use_cache=use_cache, ) + if use_cache: + assert isinstance(layer_out, tuple) + h, layer_kv = layer_out + assert next_cache is not None + next_cache.update(layer_kv[0], layer_kv[1], layer_idx=layer_idx) + else: + h = layer_out h = self.norm(h) if self.norm else h @@ -373,6 +456,8 @@ def forward( output = self.score(h) 
if self.score else h else: output = self.output(h) if self.output else h + if use_cache: + return SimpleNamespace(logits=output, past_key_values=next_cache) return output diff --git a/areal/experimental/models/archon/qwen3/model/model.py b/areal/experimental/models/archon/qwen3/model/model.py index 062326be17..2f04935ab1 100644 --- a/areal/experimental/models/archon/qwen3/model/model.py +++ b/areal/experimental/models/archon/qwen3/model/model.py @@ -4,12 +4,15 @@ from __future__ import annotations +from types import SimpleNamespace + import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +from transformers.cache_utils import DynamicCache from areal.experimental.models.archon.attention import ( SDPAWrapper, @@ -195,6 +198,8 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + past_key_values: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, ) -> torch.Tensor: bs, seqlen, _ = x.shape @@ -251,6 +256,21 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) + cu_seqlens_k = None + # Preserve per-step KV for cache update. Attention may still consume + # concatenated (past + current) KV when past_key_values is provided. 
+ kv_step = (xk, xv) + # KV cache for attention compute path: concat past K/V with newly computed K/V + if past_key_values is not None: + past_k, past_v = past_key_values + xk = torch.cat([past_k, xk], dim=2) + xv = torch.cat([past_v, xv], dim=2) + cu_seqlens_k = cu_seqlens.clone() + cu_seqlens_k += past_k.shape[2] + cu_seqlens_k[0] = 0 + + new_kv = kv_step if use_cache else None + output = self.packed_attn( xq, xk, @@ -259,6 +279,7 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, tree_attn_meta=tree_attn_meta, + cu_seqlens_k=cu_seqlens_k, ) output = output.transpose(1, 2).contiguous() @@ -271,7 +292,11 @@ def forward( seqlen = output.shape[1] output = output.view(bs, seqlen, -1) - return self.wo(output) + output = self.wo(output) + + if use_cache: + return output, new_kv + return output class FeedForward(nn.Module): @@ -334,19 +359,31 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + past_key_values: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, ) -> torch.Tensor: - x = x + self.attention( + attn_out = self.attention( self.attention_norm(x), rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_key_values=past_key_values, + use_cache=use_cache, ) + new_kv: tuple[torch.Tensor, torch.Tensor] | None = None + if use_cache: + assert isinstance(attn_out, tuple) + attn_out, new_kv = attn_out + x = x + attn_out if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: x = x + self.feed_forward(self.ffn_norm(x)) + if use_cache: + assert new_kv is not None + return x, new_kv return x def init_weights(self): @@ -456,11 +493,42 @@ def init_buffers(self, buffer_device: torch.device | str): def forward( self, tokens: torch.Tensor, - positions: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int | torch.Tensor, + positions: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | torch.Tensor | None = 
None, tree_attn_meta: TreeAttentionMeta | None = None, - ) -> torch.Tensor: + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + ) -> torch.Tensor | SimpleNamespace: + if past_key_values is not None: + if ( + positions is not None + or cu_seqlens is not None + or max_seqlen is not None + ): + raise ValueError( + "When past_key_values is provided, positions/cu_seqlens/max_seqlen " + "must be None and are inferred internally." + ) + past_len = 0 + if len(past_key_values.layers) > 0: + past_len = int(past_key_values.layers[0].keys.shape[2]) + seq_len = tokens.shape[1] + positions = torch.arange( + past_len, + past_len + seq_len, + dtype=torch.long, + device=tokens.device, + ).unsqueeze(0) + cu_seqlens = torch.tensor( + [0, tokens.shape[1]], dtype=torch.int32, device=tokens.device + ) + max_seqlen = int(tokens.shape[1]) + past_len + + assert positions is not None + assert cu_seqlens is not None + assert max_seqlen is not None + # When pipeline parallelism enabled, cu_seqlens is [1, B+1] if cu_seqlens.ndim == 2: cu_seqlens = cu_seqlens.squeeze(0) @@ -471,15 +539,34 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - for layer in self.layers.values(): - h = layer( + if use_cache: + if past_key_values is not None: + next_cache = past_key_values + else: + next_cache = DynamicCache() + for layer_idx, layer in enumerate(self.layers.values()): + layer_past = None + if past_key_values is not None and layer_idx < len(past_key_values.layers): + layer_entry = past_key_values.layers[layer_idx] + layer_past = (layer_entry.keys, layer_entry.values) + + layer_out = layer( h, self.rope_cache, positions, cu_seqlens, max_seqlen, tree_attn_meta=tree_attn_meta, + past_key_values=layer_past, + use_cache=use_cache, ) + if use_cache: + assert isinstance(layer_out, tuple) + h, layer_kv = layer_out + assert next_cache is not None + next_cache.update(layer_kv[0], layer_kv[1], layer_idx=layer_idx) + else: + h = layer_out h = self.norm(h) if 
self.norm else h @@ -487,6 +574,8 @@ def forward( output = self.score(h) if self.score else h else: output = self.output(h) if self.output else h + if use_cache: + return SimpleNamespace(logits=output, past_key_values=next_cache) return output diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 97a2f7a684..ad1ff83b9a 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -77,6 +77,7 @@ def _dispatch_tensors( item_list: list[dict[str, Any]], dp_size: int, group_size: int = 1, + packing_algorithm: str = "ffd", ) -> tuple[list[list[dict[str, Any]]], list[list[int]]]: """Partition trajectories across DP groups by balanced token count. @@ -87,6 +88,34 @@ def _dispatch_tensors( partitioning. """ n = len(item_list) + + if packing_algorithm == "dta": + has_rtensor = any( + _find_in_structure(item, RTensor) is not None for item in item_list + ) + + if has_rtensor: + # DTA requires sequence-level data. If we have grouped RTensors, we MUST + # localize everything on the Controller to properly slice the tensors. 
+ item_list = RTensor.localize(item_list) + has_rtensor = False + + if not has_rtensor: + from areal.infra.dist_rollout import _dta_allocate + from areal.utils.data import unpack_groups_to_sequences + + # Flatten grouped trajectories into sequence level for DTA + flat_item_list = unpack_groups_to_sequences(item_list) + + dta_result = _dta_allocate(flat_item_list, dp_size) + stats_tracker.scalar(**dta_result.metrics.to_stats()) + group_indices = dta_result.group_indices + splits: list[list[dict[str, Any]]] = [] + for gidxs in group_indices: + splits.append([flat_item_list[idx] for idx in gidxs]) + + return splits, group_indices + if n % group_size != 0: raise ValueError( f"item count ({n}) must be divisible by group_size ({group_size})" @@ -535,7 +564,10 @@ def _split(item: Any) -> list[Any]: if _is_tensor_like(item): if group_indices is None: splits, group_indices = _dispatch_tensors( - item, dp_size, group_size=group_size + item, + dp_size, + group_size=group_size, + packing_algorithm=self.config.packing_algorithm, ) return splits return [[item[i] for i in idxs] for idxs in group_indices] diff --git a/areal/infra/dist_rollout.py b/areal/infra/dist_rollout.py index 28e363066f..c86c65b74a 100644 --- a/areal/infra/dist_rollout.py +++ b/areal/infra/dist_rollout.py @@ -2,28 +2,134 @@ from collections.abc import Callable from dataclasses import dataclass +from types import SimpleNamespace from typing import Any +import torch import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader from areal.api import InferenceEngine, TrainEngine, WorkflowLike from areal.infra.platforms import current_platform +from areal.utils import stats_tracker from areal.utils.data import ( all_gather_tensor_container, broadcast_tensor_container, + extract_single_valid_token_sequence, + get_total_valid_tokens, split_and_unpad_tensor, tensor_container_to, ) from areal.utils.seqpack import get_allocate_fn +class _TreeTokenOnlyTimeModel: + def pred(self, stats: 
dict[str, Any]) -> float: + return float(stats["n_tree_tokens"]) + + +def _validate_group_indices( + group_indices: list[list[int]], n_groups: int, n_items: int +) -> None: + if len(group_indices) != n_groups: + raise ValueError( + f"group_indices must contain exactly {n_groups} groups, got {len(group_indices)}." + ) + flat_indices = [idx for group in group_indices for idx in group] + if len(flat_indices) != n_items: + raise ValueError( + f"group_indices must assign exactly {n_items} items, got {len(flat_indices)}." + ) + if sorted(flat_indices) != list(range(n_items)): + raise ValueError( + "group_indices must be a permutation of [0, ..., n_items-1] " + "(no duplicates, no missing/out-of-range indices)." + ) + + @dataclass class RedistributedData: all_data: list[dict[str, Any]] data: list[dict[str, Any]] rank: int group_indices: list[list[int]] + dta_metrics: "DTAMetrics | None" = None + + +@dataclass(slots=True) +class DTAMetrics: + n_tokens: float + n_tree_tokens_before_allocation: float + n_tree_tokens_after_allocation: float + compression_ratio_before_allocation: float + compression_ratio_after_allocation: float + + def to_stats(self) -> dict[str, float]: + return { + "dta/n_tokens": self.n_tokens, + "dta/n_tree_tokens_before_allocation": self.n_tree_tokens_before_allocation, + "dta/n_tree_tokens_after_allocation": self.n_tree_tokens_after_allocation, + "dta/compression_ratio_before_allocation": self.compression_ratio_before_allocation, + "dta/compression_ratio_after_allocation": self.compression_ratio_after_allocation, + } + + +@dataclass(slots=True) +class DTAAllocationResult: + group_indices: list[list[int]] + metrics: DTAMetrics + + +def _dta_allocate( + trajectories: list[dict[str, Any]], + n_groups: int, +) -> DTAAllocationResult: + from areal.experimental.dta.dp import LB_by_DFS_and_TM + from areal.experimental.dta.token_trie import TokenTrie + + token_seqs: list[torch.Tensor] = [] + for idx, trajectory in enumerate(trajectories): + try: + seq = 
extract_single_valid_token_sequence(trajectory) + except (TypeError, ValueError) as err: + raise ValueError( + f"Invalid trajectory format at index {idx} for DTA partitioning." + ) from err + token_seqs.append(seq) + + all_stats = TokenTrie(token_seqs).get_stats(mode="backward") + n_total_tokens = float(all_stats["n_tokens"]) + n_tree_tokens_before = float(all_stats["n_tree_tokens"]) + + config = SimpleNamespace(K=n_groups, mode="backward", block_size=None) + group_indices = LB_by_DFS_and_TM(token_seqs, _TreeTokenOnlyTimeModel(), config) + + n_tree_tokens_after = 0.0 + for group in group_indices: + if not group: + continue + group_token_seqs = [token_seqs[idx] for idx in group] + group_stats = TokenTrie(group_token_seqs).get_stats(mode="backward") + n_tree_tokens_after += float(group_stats["n_tree_tokens"]) + + compression_ratio_before = ( + n_total_tokens / n_tree_tokens_before + if n_tree_tokens_before > 0 + else float("nan") + ) + compression_ratio_after = ( + n_total_tokens / n_tree_tokens_after + if n_tree_tokens_after > 0 + else float("nan") + ) + metrics = DTAMetrics( + n_tokens=n_total_tokens, + n_tree_tokens_before_allocation=n_tree_tokens_before, + n_tree_tokens_after_allocation=n_tree_tokens_after, + compression_ratio_before_allocation=compression_ratio_before, + compression_ratio_after_allocation=compression_ratio_after, + ) + return DTAAllocationResult(group_indices=group_indices, metrics=metrics) def redistribute_trajectories( @@ -45,7 +151,9 @@ def redistribute_trajectories( group : dist.ProcessGroup, optional The process group for communication. If None, uses the default group. packing_algorithm : str, optional - Packing algorithm to use ("ffd" or "kk"). Default is "ffd". + How to pack trajectories across data-parallel ranks: ``"ffd"`` or ``"kk"`` + balance by total sequence length; ``"dta"`` uses DTA DFS-order partitioning + with ``n_tree_tokens`` as cost. Default ``"ffd"``. 
Returns ------- @@ -65,7 +173,7 @@ def redistribute_trajectories( all_data.extend(traj_list) # Compute sequence lengths for load balancing - seqlens = [d["attention_mask"].sum().item() for d in all_data] + seqlens = [get_total_valid_tokens(d) for d in all_data] # Remove pad positions from each trajectory (split_and_unpad_tensor # auto-derives trim lengths from attention_mask when traj_seqlens=None) @@ -76,21 +184,40 @@ def redistribute_trajectories( for d in all_data ] - allocate_fn = get_allocate_fn(packing_algorithm) - # Allocate trajectories to ranks using the configured packing algorithm - # No capacity limit leads to balanced partition across this group - group_indices = allocate_fn( - seqlens, capacity=int(1e12), min_groups=dist.get_world_size(group) - ) - local_indices = group_indices[dist.get_rank(group=group)] + n_groups = dist.get_world_size(group) + if packing_algorithm == "dta": + # Unpack group-level trajectories into sequence-level for DTA + from areal.utils.data import unpack_groups_to_sequences + + all_data = unpack_groups_to_sequences(all_data) + + dta_result = _dta_allocate(all_data, n_groups) + group_indices = dta_result.group_indices + dta_metrics = dta_result.metrics + elif packing_algorithm in ("ffd", "kk"): + allocate_fn = get_allocate_fn(packing_algorithm) + # Allocate trajectories to ranks using the configured packing algorithm + # No capacity limit leads to balanced partition across this group + group_indices = allocate_fn( + seqlens, capacity=int(1e12), min_groups=dist.get_world_size(group) + ) + dta_metrics = None + else: + raise ValueError( + f"Unsupported packing_algorithm: {packing_algorithm!r}. " + "Expected one of {'ffd', 'kk', 'dta'}." 
+ ) + _validate_group_indices(group_indices, n_groups=n_groups, n_items=len(all_data)) # Select assigned trajectories for this rank (no concatenation — deferred to train side) + local_indices = group_indices[dist.get_rank(group=group)] data = [all_data[i] for i in local_indices] return RedistributedData( all_data=all_data, data=data, rank=dist.get_rank(group=group), group_indices=group_indices, + dta_metrics=dta_metrics, ) @@ -122,22 +249,32 @@ def _broadcast_and_redistribute_trajectories( list[dict[str, Any]] Redistributed and broadcast batch available on all ranks (list of trajs) """ + rollout_packing = self.train_engine.config.packing_algorithm + if trajectories is not None: - config = getattr(self.train_engine, "config", None) - mb_spec = getattr(config, "mb_spec", None) - packing_algorithm = getattr(mb_spec, "packing_algorithm", "ffd") redist = redistribute_trajectories( trajectories, group=self.train_engine.data_parallel_group, - packing_algorithm=packing_algorithm, + packing_algorithm=rollout_packing, ) batch = redist.data + dta_metrics_payload = [redist.dta_metrics] else: batch = None + dta_metrics_payload = [None] current_platform.synchronize() dist.barrier(group=self.train_engine.cpu_group) + dist.broadcast_object_list( + dta_metrics_payload, + src=self.train_engine.current_data_parallel_head(), + group=self.train_engine.context_and_model_parallel_group, + ) + dta_metrics = dta_metrics_payload[0] + if dta_metrics is not None: + stats_tracker.scalar(**dta_metrics.to_stats()) + batch = broadcast_tensor_container( batch, src_rank=self.train_engine.current_data_parallel_head(), diff --git a/areal/models/tree_attn/module_archon.py b/areal/models/tree_attn/module_archon.py index d039228e69..49140bf109 100644 --- a/areal/models/tree_attn/module_archon.py +++ b/areal/models/tree_attn/module_archon.py @@ -93,6 +93,7 @@ def forward( cu_seqlens: torch.Tensor, max_seqlen: int, tree_attn_meta: TreeAttentionMeta | None = None, + cu_seqlens_k: torch.Tensor | None = 
None, ) -> torch.Tensor: """Compute tree attention. @@ -107,6 +108,8 @@ def forward( kept for API compatibility with VarlenAttentionWrapper). tree_attn_meta: Tree attention metadata containing either a BlockMask (flex attention) or TreeAttentionData (Triton). + cu_seqlens_k: Unused. Accepted for interface compatibility with + VarlenAttentionWrapper. Returns: Attention output, shape [batch, heads, seq_len, head_dim] diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 7c38d8090e..de46077389 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -48,8 +48,10 @@ from areal.infra.data_service import DataController from areal.infra.data_service.controller.config import DataServiceConfig from areal.infra.data_service.rdataset import RDataset +from areal.infra.rpc.rtensor import RTensor from areal.infra.utils.concurrent import call_maybe_async from areal.utils import logging, perf_tracer, seeding, stats_tracker +from areal.utils.data import unpack_groups_to_sequences from areal.utils.dataloader import create_dataloader from areal.utils.environ import is_single_controller from areal.utils.evaluator import Evaluator @@ -369,6 +371,9 @@ def __init__( ) self._config_perf_tracer() + self._cumulative_training_tokens = 0.0 + self._cumulative_train_step_time = 0.0 + self._cumulative_step_time = 0.0 self._apply_initial_offload_policy() @staticmethod @@ -567,6 +572,9 @@ def train( group_size=config.gconfig.n_samples, dynamic_bs=self.config.dynamic_bs, ) + if config.actor.packing_algorithm == "dta": + rollout_batch = RTensor.localize(rollout_batch) + rollout_batch = unpack_groups_to_sequences(rollout_batch) if self._should_offload_rollout: self._offload_rollout() @@ -582,6 +590,15 @@ def train( ), ): values = self.critic.compute_values(rollout_batch) + if config.actor.packing_algorithm == "dta": + assert isinstance(values, list), ( + f"values must return list under DTA, got {type(values)}" + ) + assert len(values) == len(rollout_batch), 
( + "values length mismatch under DTA: " + f"len(rollout_batch)={len(rollout_batch)}, " + f"len(values)={len(values)}" + ) for traj, v in zip(rollout_batch, values): traj["values"] = v self.critic.get_device_stats().log("critic values") @@ -599,6 +616,16 @@ def train( ), ): ref_logps = self.ref.compute_logp(rollout_batch) + if config.actor.packing_algorithm == "dta": + assert isinstance(ref_logps, list), ( + "ref_logps must return list under DTA, " + f"got {type(ref_logps)}" + ) + assert len(ref_logps) == len(rollout_batch), ( + "ref_logps length mismatch under DTA: " + f"len(rollout_batch)={len(rollout_batch)}, " + f"len(ref_logps)={len(ref_logps)}" + ) for traj, logp in zip(rollout_batch, ref_logps): traj["ref_logp"] = logp self.ref.get_device_stats().log("ref logp") @@ -617,6 +644,16 @@ def train( ), ): teacher_logps = self.teacher.compute_logp(rollout_batch) + if config.actor.packing_algorithm == "dta": + assert isinstance(teacher_logps, list), ( + "teacher_logps must return list under DTA, " + f"got {type(teacher_logps)}" + ) + assert len(teacher_logps) == len(rollout_batch), ( + "teacher_logps length mismatch under DTA: " + f"len(rollout_batch)={len(rollout_batch)}, " + f"len(teacher_logps)={len(teacher_logps)}" + ) for traj, logp in zip(rollout_batch, teacher_logps): traj["teacher_logp"] = logp traj["rl_loss_weight"] = self.config.teacher.rl_loss_weight @@ -639,6 +676,16 @@ def train( ), ): prox_logps = self.actor.compute_logp(rollout_batch) + if config.actor.packing_algorithm == "dta": + assert isinstance(prox_logps, list), ( + "prox_logps must return list under DTA, " + f"got {type(prox_logps)}" + ) + assert len(prox_logps) == len(rollout_batch), ( + "prox_logps length mismatch under DTA: " + f"len(rollout_batch)={len(rollout_batch)}, " + f"len(prox_logps)={len(prox_logps)}" + ) for traj, logp in zip(rollout_batch, prox_logps): traj["prox_logp"] = logp self.actor.get_device_stats().log("recompute logp") @@ -1174,11 +1221,55 @@ def 
_export_and_commit_stats(self, epoch: int, epoch_step: int, global_step: int stats.update(self.rollout.export_stats()) if self.eval_rollout is not None: stats.update(self.eval_rollout.export_stats()) + self._add_throughput_metrics(stats) self.stats_logger.commit(epoch, epoch_step, global_step, stats) dist.barrier(group=self.actor.cpu_group) current_platform.synchronize() + def _add_throughput_metrics(self, stats: dict[str, float]) -> None: + # TODO(agent): Not enabled yet, will be implemented in the future. + return + if "ppo_actor/update/n_tokens" not in stats: + raise ValueError( + "Missing required metric `ppo_actor/update/n_tokens` for throughput computation." + ) + if "timeperf/train_step" not in stats: + raise ValueError( + "Missing required metric `timeperf/train_step` for throughput computation." + ) + + n_tokens = float(stats["ppo_actor/update/n_tokens"]) + train_step_time = float(stats["timeperf/train_step"]) + step_total_time = sum( + float(value) + for key, value in stats.items() + if key.startswith("timeperf/") and not key.endswith("__count") + ) + stats["timeperf/step_total"] = step_total_time + + self._cumulative_training_tokens += n_tokens + self._cumulative_train_step_time += max(train_step_time, 0.0) + self._cumulative_step_time += max(step_total_time, 0.0) + stats["timeperf/cumulative_step_total"] = self._cumulative_step_time + + if n_tokens > 0.0 and train_step_time > 0.0: + stats["training_throughput"] = n_tokens / train_step_time + if ( + self._cumulative_training_tokens > 0.0 + and self._cumulative_train_step_time > 0.0 + ): + stats["cumulative_training_throughput"] = ( + self._cumulative_training_tokens / self._cumulative_train_step_time + ) + + if n_tokens > 0.0 and step_total_time > 0.0: + stats["throughput"] = n_tokens / step_total_time + if self._cumulative_training_tokens > 0.0 and self._cumulative_step_time > 0.0: + stats["cumulative_throughput"] = ( + self._cumulative_training_tokens / self._cumulative_step_time + ) + def 
_validate_cfg(self): """validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models.""" rollout_backend = self.rollout_alloc.backend diff --git a/areal/utils/data.py b/areal/utils/data.py index 09368e4ef2..81dc5060e8 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -47,6 +47,62 @@ def get_batch_size(data: dict[str, Any]) -> int: return 0 +def extract_valid_token_sequences( + input_ids_batch: torch.Tensor, + attention_mask: torch.Tensor, +) -> tuple[list[torch.Tensor], int]: + """Extract unpadded token sequences from a [B, S] batch.""" + if not (torch.is_tensor(input_ids_batch) and torch.is_tensor(attention_mask)): + raise TypeError("input_ids_batch and attention_mask must be torch.Tensor.") + if input_ids_batch.ndim != 2 or attention_mask.ndim != 2: + raise ValueError("input_ids_batch and attention_mask must be rank-2 tensors.") + if input_ids_batch.shape != attention_mask.shape: + raise ValueError( + "input_ids_batch and attention_mask must have identical shapes." + ) + + max_seq_len = 0 + input_ids_list: list[torch.Tensor] = [] + for i in range(input_ids_batch.shape[0]): + valid_length = int(attention_mask[i].sum().item()) + max_seq_len = max(max_seq_len, valid_length) + input_ids_list.append(input_ids_batch[i, :valid_length]) + return input_ids_list, max_seq_len + + +def extract_single_valid_token_sequence( + trajectory: dict[str, Any], +) -> torch.Tensor: + """Extract one unpadded token sequence from a trajectory dict. + + Raises + ------ + ValueError + If required fields are missing, malformed, or trajectory batch size is not 1. + """ + if "input_ids" not in trajectory or "attention_mask" not in trajectory: + raise ValueError( + "trajectory must contain both 'input_ids' and 'attention_mask'." 
+ ) + + input_ids = trajectory["input_ids"] + attention_mask = trajectory["attention_mask"] + seqs, _ = extract_valid_token_sequences(input_ids, attention_mask) + if len(seqs) != 1: + raise ValueError( + f"trajectory must contain exactly one sequence, got {len(seqs)}." + ) + return seqs[0] + + +def get_total_valid_tokens(trajectory: dict[str, Any]) -> int: + """Return total valid token count inferred from attention_mask when available.""" + attention_mask = trajectory.get("attention_mask") + if torch.is_tensor(attention_mask): + return int(attention_mask.sum().item()) + return 0 + + def reorder_list(xs: Sequence, indices: list[int]) -> list: assert len(set(indices)) == len(xs) return [xs[i] for i in indices] @@ -357,6 +413,46 @@ def split_and_unpad_tensor( return result +def unpack_groups_to_sequences( + item_list: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Flatten grouped trajectories into fully independent sequence-level dicts. + + For example, if an item in item_list has shape [8, seq_len, ...] for 8 samples + (group_size=8), it will be split into 8 separate dictionaries, each with + shape [1, seq_len, ...]. This is required for algorithms like DTA that operate + on individual sequences rather than groups. + + Args: + item_list: List of trajectory dictionaries. + + Returns: + A new list where every dictionary represents a single sequence. 
+ """ + flat_item_list = [] + for item in item_list: + attn_mask = item.get("attention_mask") + if ( + attn_mask is not None + and isinstance(attn_mask, torch.Tensor) + and attn_mask.ndim >= 2 + ): + n_seqs = attn_mask.shape[0] + if n_seqs > 1: + splits = split_and_unpad_tensor( + item, n_trajs=n_seqs, traj_group_sizes=1 + ) + if isinstance(splits, list): + flat_item_list.extend(splits) + else: + flat_item_list.append(splits) + else: + flat_item_list.append(item) + else: + flat_item_list.append(item) + return flat_item_list + + @dataclass class TrajBatchMeta: """Metadata for reversing concat_batch: traj counts, group sizes, seqlens.""" @@ -697,6 +793,7 @@ def split_padded_tensor_dict_into_mb_list( data: dict[str, Any], mb_spec: MicroBatchSpec, group: dist.ProcessGroup | None = None, + one_seq_per_mb: bool = False, ) -> MicroBatchList: """Split a padded dict of tensors into micro-batches based on the attention mask. @@ -704,6 +801,10 @@ def split_padded_tensor_dict_into_mb_list( data (Dict): Dictionary containing padded tensors. mb_spec (MicroBatchSpec): Specification for micro-batch splitting. group (Optional[dist.ProcessGroup]): Process group for distributed synchronization. + one_seq_per_mb (bool): If True, each micro-batch contains exactly one sequence + and skips cross-rank synchronized micro-batch count alignment. + Requires every row's valid token count ``<= mb_spec.max_tokens_per_mb`` + and ``granularity=1`` (errors in this path mention ``one_seq_per_mb``). Returns: MicroBatchList: A structure containing the split micro-batches and metadata. 
@@ -721,6 +822,20 @@ def split_padded_tensor_dict_into_mb_list( raise RuntimeError(f"Batch size {bs} cannot divide granularity {granularity}.") max_seqlen = data["attention_mask"].shape[1] seq_lens = data["attention_mask"].sum(1).long().cpu().numpy().tolist() + if one_seq_per_mb: + if granularity != 1: + raise RuntimeError( + f"split_padded_tensor_dict_into_mb_list: one_seq_per_mb=True requires " + f"granularity=1, but got granularity={granularity}." + ) + cap = mb_spec.max_tokens_per_mb + for i, L in enumerate(seq_lens): + if L > cap: + raise RuntimeError( + f"split_padded_tensor_dict_into_mb_list: one_seq_per_mb=True, " + f"but sequence at index {i} has {L} valid tokens, which exceeds " + f"max_tokens_per_mb={cap}." + ) input_lens = ( data["attention_mask"] .view(bs // granularity, granularity, -1) @@ -748,13 +863,19 @@ def split_padded_tensor_dict_into_mb_list( not_to_split[key] = value # split - group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group) - group_indices = [ - seqpack.flat2d( - [list(range(i * granularity, (i + 1) * granularity)) for i in group_index] - ) - for group_index in group_indices - ] + if one_seq_per_mb: + group_indices = [[i] for i in range(bs)] + else: + group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group) + group_indices = [ + seqpack.flat2d( + [ + list(range(i * granularity, (i + 1) * granularity)) + for i in group_index + ] + ) + for group_index in group_indices + ] splitted_lens = [ [seq_lens[i] for i in group_index] for group_index in group_indices ] diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 6752134a76..3266c7018c 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -87,6 +87,7 @@ "TreeAttentionCore": "light_cyan", "TreeAttentionConstants": "light_cyan", "TreeAttentionViz": "light_cyan", + "DTA": "light_cyan", # Checkpoint - blue (infrastructure) "Saver": "blue", "AsyncCheckpoint": "blue", diff --git a/docs/en/algorithms/grpo_series.md 
b/docs/en/algorithms/grpo_series.md index 0eb3dc1f8d..b9a519a021 100644 --- a/docs/en/algorithms/grpo_series.md +++ b/docs/en/algorithms/grpo_series.md @@ -124,7 +124,7 @@ When `eps_clip_higher` is `None`, symmetric clipping is used: $\text{clip}(r, 1-\epsilon, 1+\epsilon)$. When `eps_clip_higher` is set (DAPO-style), asymmetric clipping is used: -$\text{clip}(r, 1-\epsilon_{\text{low}}, 1+\epsilon_{\text{high}})$. +$\\text{clip}(r, 1-\\epsilon\_{\\text{low}}, 1+\\epsilon\_{\\text{high}})$. ### Importance Sampling Level (`actor.importance_sampling_level`) @@ -258,7 +258,7 @@ r_{i,t}(\theta) \hat{A}_{i,t}, \text{clip}\left( r_{i,t}(\theta), 1-\epsilon_{\text{low}}, 1+\epsilon_{\text{high}} \right) \hat{A}_{i,t} \right) \right] $$ -where $\hat{A}_{i,t}$ is the group-normalized advantage and $r_{i,t}(\theta)$ is the +where $\\hat{A}_{i,t}$ is the group-normalized advantage and $r_{i,t}(\\theta)$ is the token-level policy ratio. **Asymmetric clipping parameters:** diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index b4ec9149ac..a42ee32b7e 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -369,7 +369,10 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. 
If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -414,45 +417,48 @@ Configuration for PPO actor model, a subclass of a TrainEngine. Configuration for PPO critic model, a subclass of a TrainEngine. 
-| Parameter | Type | Default | Description | -| ------------------------ | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. 
| -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. 
| -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.5` | Clipping factor for value loss | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. 
| +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. 
If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. 
| +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.5` | Clipping factor for value loss | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | (section-train-engine)= @@ -460,42 +466,45 @@ Configuration for PPO critic model, a subclass of a TrainEngine. Core configuration for model training, including optimization and backend settings. -| Parameter | Type | Default | Description | -| ------------------------ | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. 
| -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. 
Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. 
| +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. 
**Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. 
| +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | (section-generation-hyperparameters)= @@ -966,44 +975,47 @@ Configuration for Direct Preference Optimization (DPO) experiments. Engine configuration for DPO training, extending TrainEngineConfig with DPO-specific fields. -| Parameter | Type | Default | Description | -| ------------------------ | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. 
| -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. 
Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `beta` | float | `0.1` | KL penalty coefficient for DPO loss. | -| `loss_type` | string | `"sigmoid"` | DPO loss variant. 'sigmoid': original DPO loss (Rafailov et al. 2023). 'ipo': Identity Preference Optimization with per-token length normalization (Azar et al. 2023). **Choices:** `sigmoid`, `ipo` | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. 
**Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. 
| +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. 
| +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| `beta` | float | `0.1` | KL penalty coefficient for DPO loss. | +| `loss_type` | string | `"sigmoid"` | DPO loss variant. 'sigmoid': original DPO loss (Rafailov et al. 2023). 'ipo': Identity Preference Optimization with per-token length normalization (Azar et al. 2023). **Choices:** `sigmoid`, `ipo` | (section-distributed-data-parallel)= @@ -1256,7 +1268,10 @@ Configuration class: TeacherConfig | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 
'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | diff --git a/docs/figures/experimental/gsm8k_ppo_dta_reward_compare.png b/docs/figures/experimental/gsm8k_ppo_dta_reward_compare.png new file mode 100644 index 0000000000..68d74aa939 Binary files /dev/null and b/docs/figures/experimental/gsm8k_ppo_dta_reward_compare.png differ diff --git a/docs/testing/archon_training_test.md b/docs/testing/archon_training_test.md new file mode 100644 index 0000000000..75519f5c78 --- /dev/null +++ b/docs/testing/archon_training_test.md @@ -0,0 +1,318 @@ +# ArchonEngine 训练侧测试工具 + +本文档介绍 `tests/experimental/archon/torchrun/` 下新增的 ArchonEngine 训练侧 对拍工具,包括三个文件: + +| 文件 | 作用 | +| ----------------------------- | -------------------------------------------------------------------- | +| `training_test_config.py` | YAML + CLI 覆盖配置加载器,定义 `ArchonTrainingTestConfig` | +| `run_archon_training_test.py` | `torchrun` 入口脚本,运行 N 步训练并落盘全局 loss / 显存 / `diff.pt` | +| `compare_training_dumps.py` | 离线对拍脚本,比较两个 dump 目录的 per-step loss 与 `diff.pt` 差异 | +| `archon_training_test.yaml` | 示例配置 | + +目标场景:在同一份配置上跑两次(比如 DTA 开启 vs 关闭、或两种 rollout 
`packing_algorithm`:`ffd`/`kk`/`dta`), 然后用
+`compare_training_dumps.py` 做 "对拍",验证训练逻辑的等价性 / 回归性。
+
+______________________________________________________________________
+
+## 1. 总体流程
+
+```
+ +-------------------+    torchrun    +-----------------------------+
+ | archon_*.yaml     | ─────────────▶ | run_archon_training_test.py |
+ | + CLI overrides   |                +--------------┬--------------+
+ +-------------------+                               │
+                                                     ▼
+                                            <output_dir>/
+                                            ├── stats.jsonl (每 step 一行全局 JSON)
+                                            ├── diff.pt     (rank 0, 参数更新统计)
+
+ 两次跑完后 → compare_training_dumps.py --dump-a <dir_a> --dump-b <dir_b>
+```
+
+______________________________________________________________________
+
+## 2. 配置:`training_test_config.py`
+
+该工具现在只支持:**普通 AReaL YAML + `test_config`**。
+
+```yaml
+experiment_name: xxx
+trial_name: xxx
+cluster:
+  fileroot: xxx
+actor: # AReaL 标准 actor 段(含 backend)
+test_config: # 本工具专属的测试参数
+```
+
+- 不再支持测试专用顶层 `engine`/`parallel`。
+- 并行策略统一从 `actor.backend` 解析(例如 `archon:d8`)。
+
+### 2.1 `test_config` 字段
+
+| 字段 | 类型 | 默认值 | 说明 |
+| --------------------- | ------ | ---------- | ------------------------------------------------------------------------------------------- |
+| `step` | `int` | 必填(>0) | 训练迭代步数;不复用 AReaL 原生 epoch 相关配置 |
+| `data_dir` | `str` | 必填 | 放一组 `.pt` 文件的目录,每个文件是 `list[1-D Tensor]` 形式的 `input_ids`;按字典序循环使用 |
+| `disable_optimizer` | `bool` | `False` | 开启后 engine 不创建优化器、不更新参数,也不分配优化器状态显存 |
+| `save_diff` | `bool` | `True` | 训练结束后在 rank 0 保存 `diff.pt`(参数更新统计) |
+| `save_params` | `bool` | `False` | 兼容旧字段;低显存模式下忽略,不再导出 `params.pt` |
+| `save_initial_params` | `bool` | `False` | 兼容旧字段;低显存模式下忽略,不再导出 `params_initial.pt` |
+| `seed` | `int` | `42` | 构造 `advantages`/`logprobs` 等合成字段的随机种子基值 |
+
+### 2.2 配置加载器做了什么
+
+- 使用 OmegaConf 做 YAML 读取 + dotlist 覆盖(`test_config.step=5`、
+  `actor.mb_spec.max_tokens_per_mb=8192` 等语法)。
+- 并行策略固定从 `actor.backend` 读取并解析(复用 `AllocationMode.from_str`)。
+- 训练 backend 当前只允许 `archon`;如果字符串解析成 `fsdp`/`megatron` 会直接报错。
+- 输出目录自动对齐普通训练日志根目录:
+  `<fileroot>/logs/<user>/<experiment_name>/<trial_name>/<tree_mode>_<parallel>_<model>`。
+- `fileroot` 优先级:`stats_logger.fileroot > 
cluster.fileroot`。
+- 手写了一个 dataclass 构造器 `_build_dataclass` / `_coerce_value`,递归把 `DictConfig` 子节点塞进
+  `TrainEngineConfig` 等结构。原因是 `OmegaConf.structured` 会在 `Literal[...]` 类型的字段上报
+  `ValidationError`(`cli_args.py` 中 `tree_training_mode` 正因此声明为 `str` 并以
+  metadata `choices` 约束取值;配置树中其他 dataclass 仍可能含 `Literal` 字段)。
+- `actor` 中 `TrainEngineConfig` 未使用的字段(例如 PPO/GRPO 专属字段)会自动过滤,便于直接复用普通 YAML。
+
+### 2.3 示例 YAML
+
+`tests/experimental/archon/torchrun/archon_training_test.yaml` 是一份可直接运行的模板(以
+Qwen2.5-0.5B-Instruct + DTA + dp=2 为例):
+
+```yaml
+experiment_name: archon_train_test
+trial_name: trial0
+cluster:
+  fileroot: /storage/openpsi/experiments
+actor:
+  backend: archon:d2
+  path: /storage/openpsi/models/Qwen__Qwen2.5-0.5B-Instruct/
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 5596
+  optimizer:
+    type: adam
+    lr: 1.0e-5
+    weight_decay: 0.01
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+  tree_training_mode: dta # {disabled, sparse, dta}
+  dta_block_size: 2048
+  packing_algorithm: ffd # rollout {ffd, kk, dta};与 mb_spec.packing_algorithm 不同
+
+test_config:
+  step: 4
+  data_dir: "" # 必须通过 CLI 覆盖
+  disable_optimizer: false
+  save_diff: true
+  seed: 42
+```
+
+______________________________________________________________________
+
+## 3. 数据格式
+
+`data_dir` 下每个 `.pt` 文件对应 **一个训练步** 的候选输入。格式约定:
+
+```python
+torch.save([
+    torch.tensor([id0, id1, ...], dtype=torch.long), # 第 0 条 1-D input_ids
+    torch.tensor([id0, id1, ...], dtype=torch.long), # 第 1 条
+    ...
+], "step_000.pt") +``` + +- 每个张量必须是 1-D `input_ids`(不包含 attention_mask、label 等)。 +- 文件按 **字典序** 排序,训练时 `step_idx` 取模轮转使用。 +- 序列数量必须 `>= dp_world_size`,否则抛错;多出来的尾巴会按 `len // dp_world_size * dp_world_size` 截断以保证每个 + rank 供给同样多条 `redistribute_trajectories` 预期下的输入。 + +每条 `input_ids` 会被自动补齐成一个 GRPO 用的 trajectory dict: `attention_mask` 全 1、`loss_mask` 前 30% +token 置 0 后半置 1(当前固定比例)、 `logprobs/old_logprobs/advantages/rewards/values/prox_logp` 由 +`seed + step*100003 + global_idx` 确定,**保证同 step、同条序列在所有 rank 上生成的合成字段一致**。 + +______________________________________________________________________ + +## 4. 训练入口:`run_archon_training_test.py` + +### 4.1 启动命令 + +```bash +torchrun --nproc_per_node=$NPROC \ + tests/experimental/archon/torchrun/run_archon_training_test.py \ + --config tests/experimental/archon/torchrun/archon_training_test.yaml \ + test_config.step=4 \ + test_config.data_dir=/path/to/data_dir +``` + +`--config` 后面的参数为 OmegaConf dotlist 覆盖,想改任意 `actor.*` / `test_config.*` 直接写 `key=value` +即可。 + +直接复用普通训练 YAML 即可(只额外补 `test_config.*` 覆盖): + +```bash +torchrun --nproc_per_node=8 \ + tests/experimental/archon/torchrun/run_archon_training_test.py \ + --config examples/math/gsm8k_sft_archon_fp8.yaml \ + test_config.step=4 \ + test_config.data_dir=/path/to/data_dir +``` + +### 4.2 每一步做的事 + +1. 按字典序取 `step_files[step % len]`,`torch.load` 成 `list[Tensor]`。 + +1. 每个 rank 按 `dp_rank::dp_world_size` 步长选自己的本地子集,用上述合成字段 构造 trajectory dict。 + +1. `redistribute_trajectories(..., packing_algorithm=engine.config.packing_algorithm)` + 按配置(`ffd` / `kk` / `dta`)重新分配到各 rank。 + +1. `torch.cuda.reset_peak_memory_stats()` → `engine.train_batch(...)` → 计时。 + +1. 
对每个 rank 的局部统计做 all-reduce 聚合后,记录一条 **全局** JSON 到 `<output_dir>/stats.jsonl`:
+
+   ```json
+   {
+     "step": 0, "file": "/.../000.pt", "world_size": 2, "dp_world_size": 2,
+     "num_global_sequences": 8, "num_global_tokens": 30568,
+     "elapsed_s_max": 1.234, "peak_mem_mib_max": 8192.5,
+     "loss": 0.1234, "loss_source": "train_batch_return_global_token_weighted",
+     "grad_norm_max": 0.56, "update_successful": 1.0, "lr_max": 1.0e-5
+   }
+   ```
+
+### 4.3 loss 怎么抓
+
+- 通过 `engine.train_batch(..., return_loss=True)` 直接拿每个 rank 的本地 step loss。
+- 不再依赖 monkey-patch 或额外 side-channel 采样。
+- 再按 token 做跨 rank 加权平均,得到单个 `global loss` 写入 `stats.jsonl`,方便直接对拍。
+
+### 4.4 `disable_optimizer=true` 的行为
+
+1. 进 `_create_engine` 前把
+   `engine_cfg.optimizer = None`,`ArchonLMEngine._create_optimizer` 会 early-return,因此
+   **完全不分配优化器状态显存**。
+1. monkey-patch `engine.optimizer_step / optimizer_zero_grad` 为 no-op,但仍会:
+   - 计算 `grad_norm`(对所有参数 `.grad` 做 L2)并返回;
+   - 把 `param.grad` 置空,保证下一步梯度不累加。
+1. 前向 + 反向正常执行,`loss` 曲线仍有意义,只是 **参数不变**。
+
+因此 `disable_optimizer=true` 场景下,`diff.pt` 的更新指标会接近 0,主要用来对拍 "两种配置的纯前向 / 反向 loss 是否一致"。
+
+### 4.5 `diff.pt` 导出(低额外显存)
+
+启动训练前,脚本按参数顺序调用 `.full_tensor()`,把每个参数的初始值转成 CPU `float32` 并只在 rank 0 保存。训练结束后,按相同顺序再次
+materialize 全量参数, 在 rank 0 上与初始 CPU 快照做差并计算更新统计:
+
+- `mean_abs_update`
+- `max_abs_update`
+- `l2_update`
+- `rel_l2_update`(`l2_update / ||initial||_2`)
+
+实现上是"**一次只处理一个参数**":虽然用了 `.full_tensor()`,但峰值额外显存被限制在 单个参数 full tensor 的量级,可读性更高,也不需要对象
+gather 逻辑。最终由 rank 0 写出 `diff.pt`:
+
+参数命名上,`diff.pt["params"]` 会优先使用 `state_dict_adapter.convert_single_to_hf` 生成 HuggingFace
+key;若无法映射,则回退到去掉 wrapper 前缀(如 `._orig_mod`)后的 Archon 原始 key。
+
+```
+<output_dir>/stats.jsonl   # 每 step 全局汇总
+<output_dir>/diff.pt       # 参数更新统计(非全量参数)
+```
+
+______________________________________________________________________
+
+## 5. 
对拍入口:`compare_training_dumps.py`
+
+用法:
+
+```bash
+python tests/experimental/archon/torchrun/compare_training_dumps.py \
+    --dump-a <dump_dir_a> \
+    --dump-b <dump_dir_b> \
+    --loss-atol 1e-6 \
+    --loss-rtol 1e-6 \
+    --compare-initial # 可选:仅对 legacy params_initial.pt 有效
+```
+
+### 5.1 loss 严格对拍
+
+- 从两边 `<dump_dir>/stats.jsonl` 读所有记录,按 `step` 分组。
+- 每 step 直接比较全局 `loss`。
+- 判定公式:`|loss_a - loss_b| <= atol + rtol * |loss_b|`。
+- 打印每一步的 `loss_a / loss_b / abs_gap / rel_gap / OK|MISMATCH`,再给出 整体 `PASS / FAIL`。
+- 任一步 `MISMATCH` 脚本 `exit code = 1`,方便接入 CI。
+
+### 5.2 `diff.pt` 差异(信息性,**不做强对齐**)
+
+默认加载两边 `diff.pt`,逐参数比较更新统计 gap:
+
+| 指标 | 含义 |
+| -------------- | ---------------- |
+| `max_abs_gap` | `abs(max_abs_update_a - max_abs_update_b)` |
+| `mean_abs_gap` | `abs(mean_abs_update_a - mean_abs_update_b)` |
+| `l2_gap` | `abs(l2_update_a - l2_update_b)` |
+| `rel_l2_gap` | `abs(rel_l2_update_a - rel_l2_update_b)` |
+| `numel_match` | `numel` 是否一致 |
+
+脚本会打印全局汇总 + top-K gap 最大张量。若两边都没有 `diff.pt`,会自动回退到旧版 `params.pt` 比较逻辑。
+
+______________________________________________________________________
+
+## 6. 典型使用姿势
+
+### 6.1 DTA 开关对拍
+
+```bash
+# Run A: DTA 开启
+torchrun --nproc_per_node=2 \
+    tests/experimental/archon/torchrun/run_archon_training_test.py \
+    --config tests/experimental/archon/torchrun/archon_training_test.yaml \
+    test_config.step=8 \
+    test_config.data_dir=/data/token_samples \
+    test_config.disable_optimizer=true
+
+# Run B: DTA 关闭
+torchrun --nproc_per_node=2 \
+    tests/experimental/archon/torchrun/run_archon_training_test.py \
+    --config tests/experimental/archon/torchrun/archon_training_test.yaml \
+    test_config.step=8 \
+    test_config.data_dir=/data/token_samples \
+    test_config.disable_optimizer=true \
+    actor.tree_training_mode=disabled \
+    actor.packing_algorithm=ffd
+
+# 对拍
+python tests/experimental/archon/torchrun/compare_training_dumps.py \
+    --dump-a /storage/openpsi/experiments/logs/$USER/archon_train_test/trial0/dta_d2_Qwen__Qwen2.5-0.5B-Instruct \
+    --dump-b /storage/openpsi/experiments/logs/$USER/archon_train_test/trial0/disabled_d2_Qwen__Qwen2.5-0.5B-Instruct \
+    --loss-atol 1e-4 --loss-rtol 1e-4
+```
+`disable_optimizer=true` 保证两次跑的初始参数完全相同,loss 差异只来自实现 差异;tolerance 按实际算子误差调整。 + +### 6.2 训练 benchmark(实际更新参数) + +把 `disable_optimizer` 去掉并把 `save_diff=false`(仅关掉 `diff.pt` 落盘)以节省磁盘,即可拿到 `elapsed_s_max` +/ `peak_mem_mib_max` / `loss` 曲线作为性能基线。 + +______________________________________________________________________ + +## 7. 代码位置速查 + +``` +tests/experimental/archon/torchrun/ +├── archon_training_test.yaml # 示例配置 +├── training_test_config.py # 配置加载 +├── run_archon_training_test.py # 训练入口(torchrun) +└── compare_training_dumps.py # 离线对拍 +``` + +相关生产代码参考: + +- `areal/experimental/engine/archon_engine.py` -- `ArchonLMEngine.train_batch`。 +- `areal/experimental/dta/wrapper.py` -- DTA 前向/反向包装。 +- `areal/trainer/ppo/actor.py` -- `grpo_loss_fn` 签名与默认参数。 +- `areal/infra/dist_rollout.py` -- `redistribute_trajectories`。 +- `areal/api/cli_args.py` -- `TrainEngineConfig` 与树训练相关字段。 +- `tests/experimental/archon/test_dta.py` -- 原有 DTA 单测中的对拍模式,供参考。 diff --git a/docs/zh/algorithms/grpo_series.md b/docs/zh/algorithms/grpo_series.md index d882b8a6e8..ede028bbc2 100644 --- a/docs/zh/algorithms/grpo_series.md +++ b/docs/zh/algorithms/grpo_series.md @@ -108,10 +108,10 @@ actor: | `eps_clip` | float | `0.2` | 下裁剪边界:比率裁剪到 `[1-eps_clip, ...]` | | `eps_clip_higher` | float \| None | `None` | 上裁剪边界:设置时,比率裁剪到 `[1-eps_clip, 1+eps_clip_higher]` | -当 `eps_clip_higher` 为 `None` 时,使用对称裁剪: $\text{clip}(r, 1-\epsilon, 1+\epsilon)$。 +当 `eps_clip_higher` 为 `None` 时,使用对称裁剪: $\\text{clip}(r, 1-\\epsilon, 1+\\epsilon)$。 -当设置 `eps_clip_higher` 时(DAPO风格),使用非对称裁剪: $\text{clip}(r, 1-\epsilon_{\text{low}}, -1+\epsilon_{\text{high}})$。 +当设置 `eps_clip_higher` 时(DAPO风格),使用非对称裁剪: $\\text{clip}(r, 1-\\epsilon\_{\\text{low}}, +1+\\epsilon\_{\\text{high}})$。 ### 重要性采样级别(`actor.importance_sampling_level`) @@ -197,7 +197,7 @@ SAPO用软sigmoid门替换PPO的硬裁剪,提供平滑梯度和非对称控制 **标准PPO:** -$$ L^{\text{PPO}} = -\mathbb{E}_t[\min(r_t A_t, r_t^{\text{clip}} A_t)] $$ +$$ L^{\\text{PPO}} = -\\mathbb{E}\_t\[\\min(r_t A_t, r_t^{\\text{clip}} 
A_t)\] $$ **SAPO(带软门):** @@ -235,7 +235,7 @@ r_{i,t}(\theta) \hat{A}_{i,t}, \text{clip}\left( r_{i,t}(\theta), 1-\epsilon_{\text{low}}, 1+\epsilon_{\text{high}} \right) \hat{A}_{i,t} \right) \right] $$ -其中 $\hat{A}_{i,t}$ 是分组归一化优势,$r_{i,t}(\theta)$ 是token级策略比率。 +其中 $\\hat{A}_{i,t}$ 是分组归一化优势,$r_{i,t}(\\theta)$ 是token级策略比率。 **非对称裁剪参数:** diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index ff6a4aac24..ea45b5539a 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -367,7 +367,10 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. 
**Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -412,45 +415,48 @@ Configuration for PPO actor model, a subclass of a TrainEngine. Configuration for PPO critic model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| ------------------------ | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. 
| -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. 
Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. 
| -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.5` | Clipping factor for value loss | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. 
| +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. 
| +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. 
| +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.5` | Clipping factor for value loss | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | (section-train-engine)= @@ -458,42 +464,45 @@ Configuration for PPO critic model, a subclass of a TrainEngine. Core configuration for model training, including optimization and backend settings. -| Parameter | Type | Default | Description | -| ------------------------ | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. 
| -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. 
Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. 
| +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. 
**Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. 
| +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | (section-generation-hyperparameters)= @@ -964,44 +973,47 @@ Configuration for Direct Preference Optimization (DPO) experiments. Engine configuration for DPO training, extending TrainEngineConfig with DPO-specific fields. -| Parameter | Type | Default | Description | -| ------------------------ | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. 
| -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. 
Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `beta` | float | `0.1` | KL penalty coefficient for DPO loss. | -| `loss_type` | string | `"sigmoid"` | DPO loss variant. 'sigmoid': original DPO loss (Rafailov et al. 2023). 'ipo': Identity Preference Optimization with per-token length normalization (Azar et al. 2023). **Choices:** `sigmoid`, `ipo` | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. 
**Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. 
| +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. 
| +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| `beta` | float | `0.1` | KL penalty coefficient for DPO loss. | +| `loss_type` | string | `"sigmoid"` | DPO loss variant. 'sigmoid': original DPO loss (Rafailov et al. 2023). 'ipo': Identity Preference Optimization with per-token length normalization (Azar et al. 2023). **Choices:** `sigmoid`, `ipo` | (section-distributed-data-parallel)= @@ -1254,7 +1266,10 @@ Configuration class: TeacherConfig | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `tree_training_mode` | string | `"disabled"` | Tree training mode. 'sparse' enables tree training with Flex Attention module (flex attention), 'dta' enables Dynamic Tree Attention (dynamic tree training), 'disabled' disables tree training. **Choices:** `disabled`, `sparse`, `dta` | +| `enable_tree_training` | boolean | `False` | \[DEPRECATED\] Use tree_training_mode instead. enable_tree_training=True maps to tree_training_mode='sparse'. If both are set, tree_training_mode takes precedence. | +| `dta_block_size` | integer | `2048` | Block size for Dynamic Tree Attention. Set to -1 to disable block-size limit. Only effective when tree_training_mode='dta'. | +| `packing_algorithm` | string | `"ffd"` | Trajectory packing across data-parallel ranks during distributed rollout (`redistribute_trajectories`). 
'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order n_tree_tokens. Not to be confused with `mb_spec.packing_algorithm`, which only controls micro-batch formation (ffd/kk) during training. **Choices:** `ffd`, `kk`, `dta` | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | diff --git a/examples/experimental/dta/gsm8k_grpo.yaml b/examples/experimental/dta/gsm8k_grpo.yaml new file mode 100644 index 0000000000..9da395469c --- /dev/null +++ b/examples/experimental/dta/gsm8k_grpo.yaml @@ -0,0 +1,194 @@ +# GSM8K GRPO: Archon training + DTA, decoupled 1+1 (1 GPU SGLang rollout + 1 GPU train). +# Train side: actor + ref colocate on archon:d1 (no critic). Rollout: sglang:d1. +# +# Requirements: +# - 2+ visible GPUs on the node; cluster.n_gpus_per_node: 2 +# - enable_offload=true with actor offload uses torch_memory_saver; set up TMS +# (e.g. LD_PRELOAD) per AReaL docs before running. +# - DTA is incompatible with gradient_checkpointing on Archon; keep it false. 
+# +# Example: +# uv run python examples/math/gsm8k_rl.py \ +# --config examples/experimental/dta/gsm8k_grpo.yaml \ +# scheduler.type=local + +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 2 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +scheduler: + type: local + +rollout: + backend: "sglang:d1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 64 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "archon:d1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /storage/openpsi/models/Qwen__Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 2048 + archon: + attn_type: varlen + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + tree_training_mode: dta + dta_block_size: 512 + packing_algorithm: dta + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + 
backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + gradient_checkpointing: false + mb_spec: + max_tokens_per_mb: 8192 + archon: + attn_type: varlen + optimizer: null + tree_training_mode: disabled + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +train_dataset: + batch_size: 8 + shuffle: true + pin_memory: true + num_workers: 4 + path: /storage/openpsi/data/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 64 + pin_memory: true + num_workers: 4 + path: ${train_dataset.path} + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: online + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/math/gsm8k_ppo_dta.yaml b/examples/math/gsm8k_ppo_dta.yaml new file mode 100644 index 0000000000..a644317ac1 --- 
/dev/null +++ b/examples/math/gsm8k_ppo_dta.yaml @@ -0,0 +1,209 @@ +experiment_name: gsm8k-ppo-dta +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "sglang:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "archon:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + archon: + attn_type: varlen + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + tree_training_mode: dta + packing_algorithm: dta + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +critic: + backend: ${actor.backend} + is_critic: true + experiment_name: 
${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: ${actor.dtype} + archon: + attn_type: varlen + eps_clip: 0.5 + ppo_n_minibatches: ${actor.ppo_n_minibatches} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: ${actor.optimizer} + tree_training_mode: dta + packing_algorithm: ${actor.packing_algorithm} + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: ${actor.dtype} + archon: + attn_type: varlen + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + tree_training_mode: dta + packing_algorithm: ${actor.packing_algorithm} + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + 
experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/tau2/README.md b/examples/tau2/README.md index f7b68277e9..2ea06f9fc2 100644 --- a/examples/tau2/README.md +++ b/examples/tau2/README.md @@ -151,7 +151,7 @@ For reward curves of experiments on a larger scale, please refer to the `generated/` directory under `cluster.fileroot`. You can analyze these for debugging and evaluation. -1. **Tree training**: The configs enable `enable_tree_training=true` by default, which +1. **Tree training**: The configs use `tree_training_mode=sparse` by default, which optimizes training by sharing prefix computations across rollouts with the same prompt. This option can largely accelerate training but will possibly increase GPU memory usage if `actor.mb_spec.max_tokens_per_mb` is large. 
And this setting may diff --git a/examples/tau2/config_1.7b_airline.yaml b/examples/tau2/config_1.7b_airline.yaml index 95a8f8d8e6..f29f81d5c5 100644 --- a/examples/tau2/config_1.7b_airline.yaml +++ b/examples/tau2/config_1.7b_airline.yaml @@ -86,7 +86,7 @@ actor: std_level: batch max_new_tokens: ${gconfig.max_new_tokens} pad_to_maximum: true - enable_tree_training: true + tree_training_mode: sparse scheduling_spec: - task_type: worker port_count: 2 diff --git a/examples/tau2/config_235b_moe_airline.yaml b/examples/tau2/config_235b_moe_airline.yaml index b134e92294..5ea68dd7ce 100644 --- a/examples/tau2/config_235b_moe_airline.yaml +++ b/examples/tau2/config_235b_moe_airline.yaml @@ -124,7 +124,7 @@ actor: std_unbiased: true eps: 1.0e-05 max_new_tokens: ${gconfig.max_new_tokens} - enable_tree_training: false + tree_training_mode: disabled scheduling_spec: - task_type: worker port_count: 2 @@ -158,7 +158,7 @@ ref: max_tokens_per_mb: 32768 n_mbs_divisor: 1 optimizer: null - enable_tree_training: false + tree_training_mode: disabled # Tau2 environment configuration econfig: diff --git a/examples/tau2/config_30b_moe_airline.yaml b/examples/tau2/config_30b_moe_airline.yaml index 07b201fc36..a8ee9d5e1c 100644 --- a/examples/tau2/config_30b_moe_airline.yaml +++ b/examples/tau2/config_30b_moe_airline.yaml @@ -124,7 +124,7 @@ actor: std_unbiased: true eps: 1.0e-05 max_new_tokens: ${gconfig.max_new_tokens} - enable_tree_training: false + tree_training_mode: disabled scheduling_spec: - task_type: worker port_count: 2 @@ -158,7 +158,7 @@ ref: max_tokens_per_mb: 32768 n_mbs_divisor: 1 optimizer: null - enable_tree_training: false + tree_training_mode: disabled # Tau2 environment configuration econfig: diff --git a/examples/tau2/config_8b_airline.yaml b/examples/tau2/config_8b_airline.yaml index 77a366984a..6b0232d4c8 100644 --- a/examples/tau2/config_8b_airline.yaml +++ b/examples/tau2/config_8b_airline.yaml @@ -86,7 +86,7 @@ actor: std_level: batch max_new_tokens: 
${gconfig.max_new_tokens} pad_to_maximum: true - enable_tree_training: true + tree_training_mode: sparse scheduling_spec: - task_type: worker port_count: 2 diff --git a/examples/tau2/dta/1.7b-dta.png b/examples/tau2/dta/1.7b-dta.png new file mode 100644 index 0000000000..084f217f05 Binary files /dev/null and b/examples/tau2/dta/1.7b-dta.png differ diff --git a/examples/tau2/dta/README.md b/examples/tau2/dta/README.md new file mode 100644 index 0000000000..4c07774ffe --- /dev/null +++ b/examples/tau2/dta/README.md @@ -0,0 +1,11 @@ +# DTA Config Notes + +This directory is reserved for TAU2 DTA-focused variants. + +Unlike the baseline TAU2 configs under `examples/tau2/` (which use `sparse` or `disabled`), the DTA variant configs in this directory set: + +- `tree_training_mode: dta` +- `packing_algorithm: dta` +- `gradient_checkpointing: false` + +Add future DTA-specific ablation configs here. diff --git a/examples/tau2/dta/config_1.7b_airline_dta.yaml b/examples/tau2/dta/config_1.7b_airline_dta.yaml new file mode 100644 index 0000000000..58b037733f --- /dev/null +++ b/examples/tau2/dta/config_1.7b_airline_dta.yaml @@ -0,0 +1,206 @@ +experiment_name: tau2-rl +trial_name: 1.5b-airline + +seed: 1 +enable_offload: false +total_train_epochs: 200 +total_train_steps: 500 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /path/to/experiments # Replace with your experiment directory + name_resolve: + type: nfs + nfs_record_root: /path/to/name_resolve # Replace with your name resolve directory + + +scheduler: + type: local + +gconfig: + n_samples: 8 + min_new_tokens: 0 + max_new_tokens: 8192 + max_tokens: 16384 + greedy: false + temperature: 1.0 + +rollout: + backend: "sglang:d6" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: true + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: 
${tokenizer_path} + dump_to_file: true + agent: + mode: inline + tool_call_parser: qwen25 + reasoning_parser: qwen3 + engine_max_tokens: ${gconfig.max_tokens} + export_style: individual + turn_discount: 1.0 + +actor: + backend: "archon:d2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-1.7B # HuggingFace model path + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 24576 + optimizer: + type: adam + lr: 1.7e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + reward_norm: null + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + pad_to_maximum: true + tree_training_mode: dta + packing_algorithm: dta + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + cpu: 2 + mem: 16 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: + # NCCL settings - adjust based on your cluster network configuration + AREAL_PROXY_WARN_ONCE: "1" # Avoid warning spamming + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 24576 + optimizer: null + +# Tau2 environment configuration +econfig: + domain: airline + max_steps: 50 + add_thinking_tool: false + solo_mode: false + user_llm_base_url: http://localhost:8000/v1/ # Replace with your user LLM endpoint + user_llm: openai/self-hosted-Qwen2.5-72B + user_llm_args: + temperature: 0.0 + max_completion_tokens: 512 + turn_discount: 1.0 + invalid_format_penalty: 
0.1 + +# SGLang inference server configuration +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# Datasets +train_dataset: + batch_size: 30 + shuffle: true + pin_memory: true + num_workers: 4 + path: tau2/train + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 30 + shuffle: true + pin_memory: true + num_workers: 4 + path: tau2/test + type: rl + drop_last: false + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: online + +perf_tracer: + enabled: false + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + save_interval: 1 + session_tracer: + enabled: true + flush_threshold: 100 diff --git a/examples/tau2/dta/config_8b_airline_dta.yaml b/examples/tau2/dta/config_8b_airline_dta.yaml new file mode 100644 index 0000000000..c5bbe20e31 --- /dev/null +++ b/examples/tau2/dta/config_8b_airline_dta.yaml @@ -0,0 +1,214 @@ +experiment_name: tau2-rl +trial_name: 7b-airline + +seed: 1 +enable_offload: false +total_train_epochs: 200 +total_train_steps: 500 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 2 + n_gpus_per_node: 8 + 
fileroot: /path/to/experiments # Replace with your experiment directory + name_resolve: + type: nfs + nfs_record_root: /path/to/name_resolve # Replace with your name resolve directory + + +scheduler: + type: slurm + +gconfig: + n_samples: 8 + min_new_tokens: 0 + max_new_tokens: 8192 + max_tokens: 32767 + greedy: false + temperature: 1.0 + +rollout: + backend: "sglang:d8" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + agent: + mode: inline + tool_call_parser: qwen25 + reasoning_parser: qwen3 + engine_max_tokens: ${gconfig.max_tokens} + export_style: individual + turn_discount: 1.0 + +actor: + backend: "archon:d8" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-8B # HuggingFace model path + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 32768 + optimizer: + type: adam + lr: 1.7e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + reward_norm: null + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + pad_to_maximum: true + tree_training_mode: dta + packing_algorithm: dta + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + cpu: 4 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + # image: /path/to/container.sif # Optional: specify container 
image for Slurm + env_vars: + # NCCL settings - adjust based on your cluster network configuration + NCCL_DEBUG: "WARN" + TAU2_DATA_DIR: /path/to/tau2-bench/data # Set if using tau2-bench + WANDB_BASE_URL: your-wandb-url # Optional: for W&B logging + WANDB_API_KEY: your-wandb-api-key # Optional: for W&B logging + AREAL_PROXY_WARN_ONCE: "1" # Avoid warning spamming + additional_bash_cmds: + # install tau2-bench with async generation + - uv pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 32768 + optimizer: null + +# Tau2 environment configuration +econfig: + domain: airline + max_steps: 50 + add_thinking_tool: false + solo_mode: false + user_llm_base_url: http://localhost:8000/v1/ # Replace with your user LLM endpoint + user_llm: openai/self-hosted-Qwen2.5-72B + user_llm_args: + temperature: 0.0 + max_completion_tokens: 512 + turn_discount: 1.0 + invalid_format_penalty: 0.1 + +# SGLang inference server configuration +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.9 + +# Datasets +train_dataset: + batch_size: 32 + shuffle: true + pin_memory: true + num_workers: 4 + path: tau2/train + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 32 + shuffle: true + pin_memory: true + num_workers: 4 + path: tau2/test + type: rl + drop_last: false + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 
null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: online + +perf_tracer: + enabled: false + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + save_interval: 1 + session_tracer: + enabled: true + flush_threshold: 100 diff --git a/tests/experimental/archon/README.md b/tests/experimental/archon/README.md new file mode 100644 index 0000000000..a984edbdbd --- /dev/null +++ b/tests/experimental/archon/README.md @@ -0,0 +1,90 @@ +# Archon 测试说明 + +## `test_dta.py` 简介 + +`test_dta.py` 主要验证 Archon 的 DTA 路径,包括: + +- `forward_batch` 冒烟检查 +- `train_batch` 冒烟检查 +- 与 FSDP 的数值一致性对比 + +## 测试函数说明 + +- `test_engine_is_initialized`:检查引擎能否正常初始化,并确认 DTA 开关状态正确。 +- `test_forward_batch_runs`:只验证 Archon 的 `forward_batch` 在 DTA 开启时可正常跑通。 +- `test_train_batch_runs`:只验证 Archon 的 `train_batch` 在 DTA 开启时可正常跑通并返回结果。 +- `test_forward_batch_matches_fsdp`:对比 Archon 与 FSDP 的 `forward_batch` + 输出,检查形状和数值误差是否在可接受范围内。 +- `test_train_batch_matches_fsdp`:对比 Archon 与 FSDP + 一次训练步后的梯度范数和参数更新量,检查训练信号一致性。很难强对齐,建议观察 grad_norm 和 delta_norm 是否对齐。 + +## 输入数据格式 + +通过 `--dta-data` 传入一个 `.pt` 文件,内容要求: + +- 类型是 `list[torch.Tensor]` +- 每个元素是 1-D token 序列(不做 padding) + +示例: + +```python +[ + torch.tensor([101, 2023, 2003, 1037, 3231], dtype=torch.long), + torch.tensor([101, 2064, 2017, 2393, 1029], dtype=torch.long), +] +``` + +## 参数说明 + +- `--dta-data PATH`:DTA 数据文件路径;不传会跳过 DTA 测试 +- `--dta-limit INT`:最多使用前 N 条序列,`-1` 表示全部使用 +- `--max-tokens-per-mb INT`:单条序列token 上限(用于序列/微批控制) +- `--no-dta`:关闭 DTA +- `--use-hf`:model 使用 HuggingFace 
模型路径分支,即去掉 archon 包装 +- `--model-path PATH`:模型路径(与 `--use-hf` 搭配) + +## 用法示例(`python -m pytest`) + +只跑 DTA 测试: + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt +``` + +限制样本数量(快速迭代): + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt \ + --dta-limit 16 +``` + +调整 token 上限: + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt \ + --max-tokens-per-mb 4096 +``` + +使用 HF 模型路径: + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py \ + --dta-data /path/to/dta_samples.pt \ + --use-hf \ + --model-path /path/to/model +``` + +按函数精确运行(`::`): + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py::test_forward_batch_runs \ + --dta-data /path/to/dta_samples.pt +``` + +```bash +python -m pytest -v -s tests/experimental/archon/test_dta.py::test_train_batch_matches_fsdp \ + --dta-data /path/to/dta_samples.pt +``` diff --git a/tests/experimental/archon/conftest.py b/tests/experimental/archon/conftest.py index 3782ff2b28..18c094938f 100644 --- a/tests/experimental/archon/conftest.py +++ b/tests/experimental/archon/conftest.py @@ -14,6 +14,7 @@ import sys import types from pathlib import Path +from types import SimpleNamespace import pytest import torch @@ -42,6 +43,61 @@ collect_ignore_glob.extend(["test_qwen3_5*.py", "test_hf_parity_qwen3_5*.py"]) +def pytest_addoption(parser): + parser.addoption( + "--dta-data", + type=str, + default=None, + help="Path to .pt file with DTA token sequences (list[Tensor]).", + ) + parser.addoption( + "--no-dta", + action="store_true", + default=False, + help="Disable DTA.", + ) + parser.addoption( + "--max-tokens-per-mb", + type=int, + default=5596, + help="Cap sequence length and set mb_spec.max_tokens_per_mb for archon tests.", + ) + parser.addoption( + "--dta-limit", + type=int, + default=-1, + help="Use at most N sequences from 
--dta-data; -1 keeps all sequences.", + ) + parser.addoption( + "--use-hf", + action="store_true", + default=False, + help="Use HuggingFace model for Archon DTA tests.", + ) + parser.addoption( + "--model-path", + type=str, + default="/storage/openpsi/models/Qwen__Qwen2.5-0.5B-Instruct/", + help="Path to model.", + ) + + +@pytest.fixture(scope="module") +def archon_test_config(request) -> SimpleNamespace: + """Expose archon runtime config to tests/fixtures.""" + Ans = SimpleNamespace( + max_tokens_per_mb=int(request.config.getoption("--max-tokens-per-mb")), + tree_training_mode=( + "disabled" if request.config.getoption("--no-dta") else "dta" + ), + dta_data=request.config.getoption("--dta-data"), + dta_limit=int(request.config.getoption("--dta-limit")), + use_hf=request.config.getoption("--use-hf"), + model_path=request.config.getoption("--model-path"), + ) + return Ans + + def pytest_collection_modifyitems(config, items): """Skip archon tests based on version requirements.""" if _TORCH_VERSION >= _MIN_TORCH_VERSION: diff --git a/tests/experimental/archon/test_dta.py b/tests/experimental/archon/test_dta.py new file mode 100644 index 0000000000..de2a619183 --- /dev/null +++ b/tests/experimental/archon/test_dta.py @@ -0,0 +1,391 @@ +"""DTA tests for ArchonEngine with numerical checks against FSDP. + +This suite keeps smoke-level checks and adds dual-engine numerical validation for: +1) ``forward_batch`` output consistency +2) ``train_batch`` optimization signal consistency + +Requires ``--dta-data=`` pointing to a file containing +``list[Tensor]`` (1-D token sequences without padding). 
+""" + +import time +from typing import Any + +import pytest +import torch + +from tests.experimental.archon.utils import ( + compare_tensors, + create_archon_engine, + create_fsdp_engine, + destroy_test_engine, + dta_dummy_loss_fn, + dta_loss_weight_fn, + load_pt_batch, + snapshot_module_parameters, + strip_wrapper_prefixes, +) + +_CUDA_AVAILABLE = torch.cuda.is_available() + +pytestmark = [ + pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA not available"), +] + + +def _canonicalize_param_dict( + params: dict[str, torch.Tensor], + *, + source: str, + archon_adapter: Any | None = None, +) -> dict[str, torch.Tensor]: + """Normalize parameter-key namespace for cross-engine comparisons.""" + canonical: dict[str, torch.Tensor] = {} + for raw_name, value in params.items(): + if source == "archon": + assert archon_adapter is not None, ( + "archon_adapter is required for archon canonicalization" + ) + mapped = archon_adapter.convert_single_to_hf(raw_name, value) + if not mapped: + key = strip_wrapper_prefixes(raw_name) + else: + key, _ = mapped[0] + else: + key = strip_wrapper_prefixes(raw_name) + canonical[key] = value + return canonical + + +def _clone_batch(batch: dict[str, Any]) -> dict[str, Any]: + out: dict[str, Any] = {} + for key, value in batch.items(): + out[key] = value.clone() if isinstance(value, torch.Tensor) else value + return out + + +def _assert_tensor_finite(t: torch.Tensor, name: str) -> None: + assert torch.isfinite(t).all(), f"{name} contains non-finite values" + + +def _assert_grad_norm_finite(result: dict[str, Any], name: str) -> float: + assert "grad_norm" in result, f"{name} missing grad_norm: {list(result.keys())}" + grad_norm = float(result["grad_norm"]) + assert torch.isfinite(torch.tensor(grad_norm)).item(), ( + f"{name} grad_norm is NaN/Inf: {grad_norm}" + ) + return grad_norm + + +def _run_forward_batch( + engine: Any, batch: dict[str, Any], name: str +) -> tuple[torch.Tensor, float, float]: + """Run one forward step and validate 
finite output.""" + engine.eval() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + start_time = time.perf_counter() + with torch.no_grad(): + output = engine.forward_batch(batch) + if torch.cuda.is_available(): + torch.cuda.synchronize() + elapsed_s = time.perf_counter() - start_time + peak_mem_mib = ( + torch.cuda.max_memory_allocated() / (1024**2) + if torch.cuda.is_available() + else 0.0 + ) + print( + f"[{name}] forward_batch elapsed: {elapsed_s:.2f} s, peak_mem: {peak_mem_mib:.2f} MiB" + ) + assert output.shape[0] == batch["input_ids"].shape[0], ( + f"{name} forward output shape mismatch" + ) + assert output is not None, f"{name} forward output is None" + _assert_tensor_finite(output, f"{name} forward output") + return output, elapsed_s, peak_mem_mib + + +def _run_train_batch_and_snapshot( + engine: Any, + batch: dict[str, Any], + name: str, +) -> tuple[ + dict[str, Any], + dict[str, torch.Tensor], + dict[str, torch.Tensor], + dict[str, torch.Tensor], + float, + float, + float, +]: + """Run one train step and return result, snapshots, deltas, grad_norm, elapsed_s, peak_mem_mib.""" + engine.train() + engine.optimizer_zero_grad() + before = snapshot_module_parameters( + engine.model, + to_cpu=True, + param_filter=lambda n, p: p.requires_grad, + ) + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + start_time = time.perf_counter() + result = engine.train_batch(batch, dta_dummy_loss_fn, dta_loss_weight_fn) + if torch.cuda.is_available(): + torch.cuda.synchronize() + elapsed_s = time.perf_counter() - start_time + peak_mem_mib = ( + torch.cuda.max_memory_allocated() / (1024**2) + if torch.cuda.is_available() + else 0.0 + ) + print( + f"[{name}] train_batch elapsed: {elapsed_s:.2f} s, peak_mem: {peak_mem_mib:.2f} MiB" + ) + grad_norm = _assert_grad_norm_finite(result, name) + after = snapshot_module_parameters( + engine.model, + to_cpu=True, + param_filter=lambda n, p: 
p.requires_grad, + ) + before_names = set(before.keys()) + after_names = set(after.keys()) + assert before_names == after_names, ( + f"{name} trainable parameter names changed after train_batch: " + f"only_before={sorted(before_names - after_names)[:20]}, " + f"only_after={sorted(after_names - before_names)[:20]}" + ) + + deltas: dict[str, torch.Tensor] = {} + for param_name in sorted(before_names): + _assert_tensor_finite( + after[param_name], f"{name} param after train_batch: {param_name}" + ) + delta = after[param_name] - before[param_name] + _assert_tensor_finite(delta, f"{name} delta after train_batch: {param_name}") + deltas[param_name] = delta + + engine.optimizer_zero_grad() + return result, before, after, deltas, grad_norm, elapsed_s, peak_mem_mib + + +def _assert_train_consistency( + *, + archon_before: dict[str, torch.Tensor], + fsdp_before: dict[str, torch.Tensor], + archon_deltas: dict[str, torch.Tensor], + fsdp_deltas: dict[str, torch.Tensor], + archon_grad_norm: float, + fsdp_grad_norm: float, + archon_adapter: Any, +) -> None: + """Validate train-step consistency between Archon and FSDP.""" + grad_norm_gap = abs(archon_grad_norm - fsdp_grad_norm) + grad_norm_rel_gap = grad_norm_gap / max(abs(fsdp_grad_norm), 1e-6) + print( + f"[Archon vs FSDP] archon_grad_norm={archon_grad_norm:.6f}, fsdp_grad_norm={fsdp_grad_norm:.6f}, gap={grad_norm_gap:.6f}, rel_gap={grad_norm_rel_gap:.6f}" + ) + assert grad_norm_rel_gap < 0.25, ( + "train_batch grad_norm differs too much: " + f"archon={archon_grad_norm:.6f}, fsdp={fsdp_grad_norm:.6f}, rel_gap={grad_norm_rel_gap:.3f}" + ) + + archon_before_canonical = _canonicalize_param_dict( + archon_before, source="archon", archon_adapter=archon_adapter + ) + fsdp_before_canonical = _canonicalize_param_dict(fsdp_before, source="fsdp") + archon_deltas_canonical = _canonicalize_param_dict( + archon_deltas, source="archon", archon_adapter=archon_adapter + ) + fsdp_deltas_canonical = _canonicalize_param_dict(fsdp_deltas, 
source="fsdp") + + archon_names = set(archon_before_canonical.keys()) + fsdp_names = set(fsdp_before_canonical.keys()) + assert archon_names == fsdp_names, ( + "Canonical trainable parameter names are not aligned between Archon and FSDP: " + f"only_archon={sorted(archon_names - fsdp_names)[:20]}, " + f"only_fsdp={sorted(fsdp_names - archon_names)[:20]}" + ) + + # Compute L2 norm of all delta tensors for Archon and FSDP (sqrt of sum of squares of all elements) + def _global_l2_norm(param_dict): + # param_dict: dict[param_name, tensor] + return float( + torch.sqrt(sum((param.float() ** 2).sum() for param in param_dict.values())) + ) + + archon_delta_norm = _global_l2_norm(archon_deltas_canonical) + fsdp_delta_norm = _global_l2_norm(fsdp_deltas_canonical) + + print( + f"[delta norm] archon_delta_norm={archon_delta_norm:.6f}, " + f"fsdp_delta_norm={fsdp_delta_norm:.6f}, " + f"abs_gap={abs(archon_delta_norm - fsdp_delta_norm):.6f}, " + f"rel_gap={abs(archon_delta_norm - fsdp_delta_norm) / (abs(fsdp_delta_norm) + 1e-8):.4f}" + ) + + mismatches = [] + for name in sorted(archon_names): + delta_metrics = compare_tensors( + archon_deltas_canonical[name], + fsdp_deltas_canonical[name], + atol=1e-8, + rtol=0.3, + ) + if not delta_metrics.shape_match or not delta_metrics.allclose: + mismatches.append((name, str(delta_metrics))) + + if mismatches: + print( + f"Note: Found {len(mismatches)} parameter delta mismatches out of {len(archon_names)} parameters after train_batch (showing up to 20): " + f"{mismatches[:20]}" + ) + + +@pytest.fixture(scope="module") +def batch(request, archon_test_config): + pt_path = archon_test_config.dta_data + if pt_path is None: + pytest.skip("Skipped: pass --dta-data= to run DTA tests") + return load_pt_batch(test_config=archon_test_config) + + +def test_engine_is_initialized(archon_test_config): + """Engine initializes with DTA flag from CLI option.""" + engine = create_archon_engine(test_config=archon_test_config) + try: + assert engine.initialized 
+ assert engine.tree_training_mode == archon_test_config.tree_training_mode + assert hasattr(engine, "dta_wrapper") == ( + archon_test_config.tree_training_mode == "dta" + ) + finally: + destroy_test_engine(engine) + + +def test_forward_batch_runs(batch, archon_test_config): + """Smoke check for DTA forward path on Archon engine.""" + archon_batch = _clone_batch(batch) + archon_engine = create_archon_engine(test_config=archon_test_config) + try: + _, _, _ = _run_forward_batch(archon_engine, archon_batch, name="Archon") + finally: + destroy_test_engine(archon_engine) + + +def test_train_batch_runs(batch, archon_test_config): + """Smoke check for DTA train path on Archon engine.""" + archon_batch = _clone_batch(batch) + archon_engine = create_archon_engine(test_config=archon_test_config) + try: + result, _, _, _, _, _, _ = _run_train_batch_and_snapshot( + archon_engine, archon_batch, name="Archon" + ) + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + finally: + destroy_test_engine(archon_engine) + + +def test_forward_batch_matches_fsdp(batch, archon_test_config): + """Numerical check: DTA-enabled Archon forward_batch ~= FSDP forward_batch.""" + archon_batch = _clone_batch(batch) + fsdp_batch = _clone_batch(batch) + + archon_engine = create_archon_engine(test_config=archon_test_config) + try: + archon_out, archon_elapsed_s, archon_peak_mem_mib = _run_forward_batch( + archon_engine, archon_batch, name="Archon" + ) + finally: + destroy_test_engine(archon_engine) + + fsdp_engine = create_fsdp_engine(test_config=archon_test_config) + try: + fsdp_out, fsdp_elapsed_s, fsdp_peak_mem_mib = _run_forward_batch( + fsdp_engine, fsdp_batch, name="FSDP" + ) + finally: + destroy_test_engine(fsdp_engine) + + assert archon_out.shape == fsdp_out.shape, ( + f"forward_batch shape mismatch: archon={archon_out.shape}, fsdp={fsdp_out.shape}" + ) + + metrics = compare_tensors(archon_out, fsdp_out, atol=1e-4, rtol=1e-2) + assert metrics.mean_diff < 0.25, 
f"forward_batch mean_diff too large: {metrics}" + forward_speedup = fsdp_elapsed_s / max(archon_elapsed_s, 1e-12) + print( + "[Forward speedup] " + f"archon={archon_elapsed_s:.4f}s, fsdp={fsdp_elapsed_s:.4f}s, " + f"speedup={forward_speedup:.3f}x" + ) + print( + "[Forward peak memory] " + f"archon={archon_peak_mem_mib:.2f}MiB, fsdp={fsdp_peak_mem_mib:.2f}MiB" + ) + + +def test_train_batch_matches_fsdp(batch, archon_test_config): + """Numerical check: DTA-enabled Archon train signal ~= FSDP train signal.""" + archon_batch = _clone_batch(batch) + fsdp_batch = _clone_batch(batch) + + archon_engine = create_archon_engine(test_config=archon_test_config) + archon_adapter = None + try: + ( + archon_result, + archon_before, + _, + archon_deltas, + archon_grad_norm, + archon_elapsed_s, + archon_peak_mem_mib, + ) = _run_train_batch_and_snapshot( + archon_engine, + archon_batch, + name="Archon", + ) + archon_adapter = archon_engine.state_dict_adapter + assert archon_adapter is not None, ( + "Archon state_dict_adapter should be initialized" + ) + finally: + destroy_test_engine(archon_engine) + + fsdp_engine = create_fsdp_engine(test_config=archon_test_config) + try: + ( + fsdp_result, + fsdp_before, + _, + fsdp_deltas, + fsdp_grad_norm, + fsdp_elapsed_s, + fsdp_peak_mem_mib, + ) = _run_train_batch_and_snapshot(fsdp_engine, fsdp_batch, name="FSDP") + finally: + destroy_test_engine(fsdp_engine) + assert archon_adapter is not None + + train_speedup = fsdp_elapsed_s / max(archon_elapsed_s, 1e-12) + print( + "[Train speedup] " + f"archon={archon_elapsed_s:.4f}s, fsdp={fsdp_elapsed_s:.4f}s, " + f"speedup={train_speedup:.3f}x" + ) + print( + "[Train peak memory] " + f"archon={archon_peak_mem_mib:.2f}MiB, fsdp={fsdp_peak_mem_mib:.2f}MiB" + ) + _assert_train_consistency( + archon_before=archon_before, + fsdp_before=fsdp_before, + archon_deltas=archon_deltas, + fsdp_deltas=fsdp_deltas, + archon_grad_norm=archon_grad_norm, + fsdp_grad_norm=fsdp_grad_norm, + archon_adapter=archon_adapter, 
+ ) diff --git a/tests/experimental/archon/test_dynamic_nccl_p2p.py b/tests/experimental/archon/test_dynamic_nccl_p2p.py new file mode 100644 index 0000000000..15075f7100 --- /dev/null +++ b/tests/experimental/archon/test_dynamic_nccl_p2p.py @@ -0,0 +1,25 @@ +"""Minimal NCCL P2P test for dynamic request-driven communication.""" + +import pytest +import torch + +from tests.experimental.archon.utils import run_torchrun_test + +from areal.infra.platforms import current_platform + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" +) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_dynamic_nccl_p2p_mailbox(): + """Verify dynamic header-driven send/recv without a fixed global order.""" + if current_platform.device_count() < 3: + pytest.skip("This test requires at least 3 GPUs") + + run_torchrun_test( + "tests/experimental/archon/torchrun/run_dynamic_nccl_p2p.py", + n_gpus=3, + ) diff --git a/tests/experimental/archon/torchrun/archon_training_test.yaml b/tests/experimental/archon/torchrun/archon_training_test.yaml new file mode 100644 index 0000000000..7a369504d2 --- /dev/null +++ b/tests/experimental/archon/torchrun/archon_training_test.yaml @@ -0,0 +1,72 @@ +# Example config for tests/experimental/archon/torchrun/run_archon_training_test.py. +# +# Launch (2 GPUs, dp=2, DTA enabled): +# torchrun --nproc_per_node=2 \ +# tests/experimental/archon/torchrun/run_archon_training_test.py \ +# --config tests/experimental/archon/torchrun/archon_training_test.yaml \ +# test_config.step=4 \ +# test_config.data_dir=/path/to/data_dir +# +# CLI overrides use OmegaConf dotlist syntax; eg ``actor.mb_spec.max_tokens_per_mb=8192``. 
+ +experiment_name: archon_train_test +trial_name: trial0 + +cluster: + fileroot: /storage/openpsi/experiments + +actor: + backend: archon:d2 + path: /storage/openpsi/models/Qwen__Qwen3-0.6B + dtype: bfloat16 + init_from_scratch: false + gradient_checkpointing: false + + mb_spec: + max_tokens_per_mb: 16384 + + archon: + attn_type: varlen + + optimizer: + type: adam + lr: 1.0e-5 + weight_decay: 0.01 + warmup_steps_proportion: 0.0 + lr_scheduler_type: constant + gradient_clipping: 1.0 + + # Dynamic Tree Attention toggles. + tree_training_mode: dta # {disabled, sparse, dta} + dta_block_size: 2048 + packing_algorithm: dta # rollout: {ffd, kk, dta} (not mb_spec) + +test_config: + # step is required; set via CLI or here. + step: 4 + # Optional: max sequences to use from each .pt after load (0 = use all). + max_sequences_per_pt: 16 + # data_dir is required; must be set via CLI override, eg + # test_config.data_dir=/path/to/data_dir + data_dir: "" + disable_optimizer: false + save_diff: true + save_params: false + save_initial_params: false + seed: 42 + is_critic: false + # Optional entropy regularization for debugging token-length sensitivity. + # entropy_mode=sum makes longer (less truncated) batches contribute more. + entropy_coef: 1e-4 + entropy_mode: sum # {mean, sum} + # Optional: store full fp32 parameter deltas in diff.pt for elementwise compare. + save_full_diff_tensors_fp32: true + # Optional: after the last step, write last_grads.pt (before teardown). Safe with + # disable_optimizer=true (captures .grad inside the patched step before zeroing). + dump_last_grads: true + # Full fp32 grad tensors in last_grads.pt (set false for stats-only / smaller files). + save_full_last_grad_tensors_fp32: true + # Optional: ``null``/omit = no HF export after init; else directory for ``save_model_to_hf``. + save_hf_checkpoint_dir: null + # Optional: dump forward_batch parity payload (current mode vs disabled baseline). 
+ dump_forward_compare: true diff --git a/tests/experimental/archon/torchrun/compare_training_dumps.py b/tests/experimental/archon/torchrun/compare_training_dumps.py new file mode 100644 index 0000000000..0a54fba926 --- /dev/null +++ b/tests/experimental/archon/torchrun/compare_training_dumps.py @@ -0,0 +1,1273 @@ +"""Offline comparison tool for two ArchonEngine training-test dumps. + +Given two directories produced by :mod:`run_archon_training_test`, this script: + +1. Loads global ``stats.jsonl`` and builds a per-step view, then performs a + strict loss alignment check. +2. Preferentially loads ``diff.pt`` from each dump and compares parameter update + signatures. If ``diff.pt`` is absent, falls back to legacy ``params.pt`` + (and optionally ``params_initial.pt``) tensor diffs. +3. If both dumps contain ``last_grads.pt`` (from ``test_config.dump_last_grads``), + checks A/B alignment (signature + optional full ``grad_tensors_fp32``) and + prints one compact table like the diff dump's "Top delta_gap" plus a one-line + PASS/FAIL summary. ``requires_grad=True`` only means a parameter *may* receive + gradients; a zero ``.grad`` after backward is normal if the loss graph does not + flow to that parameter for that step (e.g. critic value head only). +4. If both dumps contain ``forward.step*.summary.pt`` (from + ``test_config.dump_forward_compare``), compares valid-token forward outputs. + +The tool is launched as a plain Python script -- no distributed setup required. + +Example:: + + python tests/experimental/archon/torchrun/compare_training_dumps.py \\ + --dump-a /tmp/run_a --dump-b /tmp/run_b \\ + --loss-rtol 1e-6 --loss-atol 1e-6 + +Exit code is non-zero when any enabled alignment check fails. 
+""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch + +# ----------------------------------------------------------------------------- +# Terminal colors +# ----------------------------------------------------------------------------- + + +_ANSI_RESET = "\033[0m" +_ANSI_BOLD = "\033[1m" +_ANSI_RED = "\033[31m" +_ANSI_GREEN = "\033[32m" +_ANSI_YELLOW = "\033[33m" + + +def _supports_color() -> bool: + term = os.environ.get("TERM", "").lower() + no_color = os.environ.get("NO_COLOR") + return sys.stdout.isatty() and bool(term) and term != "dumb" and not no_color + + +def _colorize(text: str, color: str, *, bold: bool = False) -> str: + if not _supports_color(): + return text + prefix = f"{_ANSI_BOLD}{color}" if bold else color + return f"{prefix}{text}{_ANSI_RESET}" + + +# ----------------------------------------------------------------------------- +# Loading +# ----------------------------------------------------------------------------- + + +def _stats_file(dump_dir: str) -> str: + path = os.path.join(dump_dir, "stats.jsonl") + if not os.path.isfile(path): + raise FileNotFoundError( + f"Missing stats file: {path}. Did the training run finish?" + ) + return path + + +def _load_stats(dump_dir: str) -> dict[int, list[dict[str, Any]]]: + """Return ``{step -> list[records]}`` for one dump.""" + by_step: dict[int, list[dict[str, Any]]] = {} + path = _stats_file(dump_dir) + with open(path) as fp: + for line in fp: + line = line.strip() + if not line: + continue + rec = json.loads(line) + step = int(rec["step"]) + by_step.setdefault(step, []).append(rec) + return by_step + + +def _rank_head_loss(records: list[dict[str, Any]]) -> float: + """Pick one valid loss for a step. + + Current dumps store one global record per step; we read the first valid one. 
+ """ + for r in records: + loss = r.get("loss") + if loss is None: + continue + if isinstance(loss, float) and math.isnan(loss): + continue + return float(loss) + return float("nan") + + +# ----------------------------------------------------------------------------- +# Loss comparison +# ----------------------------------------------------------------------------- + + +@dataclass +class LossDiff: + step: int + loss_a: float + loss_b: float + abs_gap: float + rel_gap: float + aligned: bool + + +def _compare_losses( + stats_a: dict[int, list[dict[str, Any]]], + stats_b: dict[int, list[dict[str, Any]]], + *, + atol: float, + rtol: float, +) -> list[LossDiff]: + steps_a = set(stats_a.keys()) + steps_b = set(stats_b.keys()) + shared = sorted(steps_a & steps_b) + if steps_a != steps_b: + print( + f"[warn] step sets differ: only_a={sorted(steps_a - steps_b)[:10]} " + f"only_b={sorted(steps_b - steps_a)[:10]}" + ) + + diffs: list[LossDiff] = [] + for step in shared: + la = _rank_head_loss(stats_a[step]) + lb = _rank_head_loss(stats_b[step]) + gap = abs(la - lb) + rel = gap / max(abs(lb), 1e-12) + aligned = gap <= (atol + rtol * abs(lb)) + diffs.append( + LossDiff( + step=step, + loss_a=la, + loss_b=lb, + abs_gap=gap, + rel_gap=rel, + aligned=aligned, + ) + ) + return diffs + + +# ----------------------------------------------------------------------------- +# Parameter comparison +# ----------------------------------------------------------------------------- + + +@dataclass +class ParamDiff: + name: str + shape_match: bool + max_diff: float + mean_diff: float + l2_diff: float + l2_a: float + l2_b: float + norm_gap: float + delta_gap: float + + +@dataclass +class ParamUpdateStat: + name: str + numel: float + mean_abs_update: float + max_abs_update: float + l2_update: float + rel_l2_update: float + + +@dataclass +class GradSnapshotStat: + """Per-parameter stats from ``last_grads.pt`` (final-step .grad snapshot).""" + + name: str + numel: float + mean_abs: float + max_abs: 
float + l2: float + + +@dataclass +class DiffFileGap: + name: str + numel_match: bool + mean_abs_gap: float + mean_abs_rel_gap: float + max_abs_gap: float + max_abs_rel_gap: float + l2_gap: float + l2_rel_gap: float + + +@dataclass +class ForwardTensorGap: + key: str + shape_match: bool + max_diff: float + mean_diff: float + l2_rel_gap: float + + +_FORWARD_SUMMARY_RE = re.compile(r"forward\.step(\d+)\.summary\.pt$") + + +def _load_diff_signatures(path: str) -> dict[str, ParamUpdateStat]: + payload = torch.load(path, map_location="cpu") + if not isinstance(payload, dict): + raise ValueError(f"Expected dict payload in {path}, got {type(payload)}") + params = payload.get("params") + if not isinstance(params, dict): + raise ValueError(f"Expected 'params' dict in {path}, got {type(params)}") + + out: dict[str, ParamUpdateStat] = {} + for name, item in params.items(): + if not isinstance(item, dict): + raise ValueError( + f"Expected metrics dict for parameter '{name}' in {path}, got {type(item)}" + ) + out[name] = ParamUpdateStat( + name=str(name), + numel=float(item.get("numel", 0.0)), + mean_abs_update=float(item.get("mean_abs_update", 0.0)), + max_abs_update=float(item.get("max_abs_update", 0.0)), + l2_update=float(item.get("l2_update", 0.0)), + rel_l2_update=float(item.get("rel_l2_update", 0.0)), + ) + return out + + +def _print_last_grads_requires_grad_one_line(path_a: str, path_b: str) -> None: + """One summary line; list frozen param names only when non-empty.""" + + payloads: list[dict[str, Any]] = [] + for path in (path_a, path_b): + raw = torch.load(path, map_location="cpu") + payloads.append(raw if isinstance(raw, dict) else {}) + + line_parts: list[str] = [] + for label, payload in zip(("A", "B"), payloads, strict=True): + meta = payload.get("requires_grad_meta") + if not isinstance(meta, dict): + line_parts.append(f"{label}: (no requires_grad_meta)") + else: + nt = meta.get("num_named_requires_grad_true") + nf = meta.get("num_named_requires_grad_false") + 
line_parts.append(f"{label}: requires_grad true={nt} false={nf}") + print(f" {'; '.join(line_parts)}") + + for label, payload in zip(("A", "B"), payloads, strict=True): + meta = payload.get("requires_grad_meta") + if not isinstance(meta, dict): + continue + false_names = meta.get("named_requires_grad_false") or [] + if false_names: + print(f" {label} requires_grad=False (sample): {false_names[:8]}") + + +def _load_grad_signatures(path: str) -> dict[str, GradSnapshotStat]: + payload = torch.load(path, map_location="cpu") + if not isinstance(payload, dict): + raise ValueError(f"Expected dict payload in {path}, got {type(payload)}") + params = payload.get("params") + if not isinstance(params, dict): + raise ValueError(f"Expected 'params' dict in {path}, got {type(params)}") + + out: dict[str, GradSnapshotStat] = {} + for name, item in params.items(): + if not isinstance(item, dict): + raise ValueError( + f"Expected metrics dict for parameter '{name}' in {path}, got {type(item)}" + ) + out[name] = GradSnapshotStat( + name=str(name), + numel=float(item.get("numel", 0.0)), + mean_abs=float(item.get("mean_abs", 0.0)), + max_abs=float(item.get("max_abs", 0.0)), + l2=float(item.get("l2", 0.0)), + ) + return out + + +def _compare_grad_signatures( + stats_a: dict[str, GradSnapshotStat], + stats_b: dict[str, GradSnapshotStat], +) -> tuple[list[DiffFileGap], list[str], list[str]]: + """Compare ``last_grads.pt`` entries; reuse :class:`DiffFileGap` for reporting.""" + names_a = set(stats_a.keys()) + names_b = set(stats_b.keys()) + shared = sorted(names_a & names_b) + only_a = sorted(names_a - names_b) + only_b = sorted(names_b - names_a) + + gaps: list[DiffFileGap] = [] + for name in shared: + a = stats_a[name] + b = stats_b[name] + gaps.append( + DiffFileGap( + name=name, + numel_match=int(round(a.numel)) == int(round(b.numel)), + mean_abs_gap=abs(a.mean_abs - b.mean_abs), + mean_abs_rel_gap=_relative_gap(a.mean_abs, b.mean_abs), + max_abs_gap=abs(a.max_abs - b.max_abs), + 
max_abs_rel_gap=_relative_gap(a.max_abs, b.max_abs), + l2_gap=abs(a.l2 - b.l2), + l2_rel_gap=_relative_gap(a.l2, b.l2), + ) + ) + return gaps, only_a, only_b + + +def _global_grad_l2_relative_gap( + stats_a: dict[str, GradSnapshotStat], + stats_b: dict[str, GradSnapshotStat], +) -> float: + shared = set(stats_a.keys()) & set(stats_b.keys()) + if not shared: + return 0.0 + total_l2_a = math.sqrt(sum(float(stats_a[name].l2) ** 2 for name in shared)) + total_l2_b = math.sqrt(sum(float(stats_b[name].l2) ** 2 for name in shared)) + return _relative_gap(total_l2_a, total_l2_b) + + +def _load_full_delta_tensors(path: str) -> dict[str, torch.Tensor]: + payload = torch.load(path, map_location="cpu") + if not isinstance(payload, dict): + return {} + tensors = payload.get("delta_tensors_fp32") + if not isinstance(tensors, dict): + return {} + out: dict[str, torch.Tensor] = {} + for name, tensor in tensors.items(): + if isinstance(tensor, torch.Tensor): + out[str(name)] = tensor.detach().float().cpu() + return out + + +def _load_full_grad_tensors(path: str) -> dict[str, torch.Tensor]: + payload = torch.load(path, map_location="cpu") + if not isinstance(payload, dict): + return {} + tensors = payload.get("grad_tensors_fp32") + if not isinstance(tensors, dict): + return {} + out: dict[str, torch.Tensor] = {} + for name, tensor in tensors.items(): + if isinstance(tensor, torch.Tensor): + out[str(name)] = tensor.detach().float().cpu() + return out + + +def _relative_gap(a: float, b: float) -> float: + return abs(a - b) / max(abs(a), abs(b), 1e-12) + + +def _compare_diff_signatures( + stats_a: dict[str, ParamUpdateStat], + stats_b: dict[str, ParamUpdateStat], +) -> tuple[list[DiffFileGap], list[str], list[str]]: + names_a = set(stats_a.keys()) + names_b = set(stats_b.keys()) + shared = sorted(names_a & names_b) + only_a = sorted(names_a - names_b) + only_b = sorted(names_b - names_a) + + gaps: list[DiffFileGap] = [] + for name in shared: + a = stats_a[name] + b = stats_b[name] + 
gaps.append( + DiffFileGap( + name=name, + numel_match=int(round(a.numel)) == int(round(b.numel)), + mean_abs_gap=abs(a.mean_abs_update - b.mean_abs_update), + mean_abs_rel_gap=_relative_gap(a.mean_abs_update, b.mean_abs_update), + max_abs_gap=abs(a.max_abs_update - b.max_abs_update), + max_abs_rel_gap=_relative_gap(a.max_abs_update, b.max_abs_update), + l2_gap=abs(a.l2_update - b.l2_update), + l2_rel_gap=_relative_gap(a.l2_update, b.l2_update), + ) + ) + return gaps, only_a, only_b + + +def _compare_full_deltas( + tensors_a: dict[str, torch.Tensor], + tensors_b: dict[str, torch.Tensor], +) -> tuple[list[ParamDiff], list[ParamDiff], list[ParamDiff]]: + """Compare full fp32 delta tensors. + + Returns (shared_diffs, only_a_vs_zero, only_b_vs_zero). + """ + shared, only_a, only_b = _compare_state_dicts(tensors_a, tensors_b) + + def _vs_zero(name: str, tensor: torch.Tensor) -> ParamDiff: + zeros = torch.zeros_like(tensor) + delta = tensor - zeros + l2_tensor = float(tensor.norm().item()) + l2_delta = float(delta.norm().item()) + return ParamDiff( + name=name, + shape_match=True, + max_diff=float(delta.abs().max().item()) if delta.numel() > 0 else 0.0, + mean_diff=float(delta.abs().mean().item()) if delta.numel() > 0 else 0.0, + l2_diff=l2_delta, + l2_a=l2_tensor, + l2_b=0.0, + norm_gap=_relative_gap(l2_tensor, 0.0), + delta_gap=l2_delta / max(l2_tensor, 1e-12), + ) + + only_a_diffs = [_vs_zero(name, tensors_a[name]) for name in only_a] + only_b_diffs = [_vs_zero(name, tensors_b[name]) for name in only_b] + return shared, only_a_diffs, only_b_diffs + + +def _summarize_diff_gaps(gaps: list[DiffFileGap]) -> dict[str, float]: + if not gaps: + return { + "num_params": 0.0, + "numel_mismatch": 0.0, + "max_abs_gap_max": 0.0, + "max_abs_rel_gap_max": 0.0, + "mean_abs_gap_mean": 0.0, + "mean_abs_rel_gap_mean": 0.0, + "l2_gap_mean": 0.0, + "l2_rel_gap_mean": 0.0, + "mean_abs_rel_gap_max": 0.0, + "l2_rel_gap_max": 0.0, + } + return { + "num_params": float(len(gaps)), + 
"numel_mismatch": float(sum(0 if g.numel_match else 1 for g in gaps)), + "max_abs_gap_max": float(max(g.max_abs_gap for g in gaps)), + "max_abs_rel_gap_max": float(max(g.max_abs_rel_gap for g in gaps)), + "mean_abs_gap_mean": float(sum(g.mean_abs_gap for g in gaps) / len(gaps)), + "mean_abs_rel_gap_mean": float( + sum(g.mean_abs_rel_gap for g in gaps) / len(gaps) + ), + "l2_gap_mean": float(sum(g.l2_gap for g in gaps) / len(gaps)), + "l2_rel_gap_mean": float(sum(g.l2_rel_gap for g in gaps) / len(gaps)), + "mean_abs_rel_gap_max": float(max(g.mean_abs_rel_gap for g in gaps)), + "l2_rel_gap_max": float(max(g.l2_rel_gap for g in gaps)), + } + + +def _global_l2_update_relative_gap( + stats_a: dict[str, ParamUpdateStat], + stats_b: dict[str, ParamUpdateStat], +) -> float: + shared = set(stats_a.keys()) & set(stats_b.keys()) + if not shared: + return 0.0 + total_l2_a = math.sqrt(sum(float(stats_a[name].l2_update) ** 2 for name in shared)) + total_l2_b = math.sqrt(sum(float(stats_b[name].l2_update) ** 2 for name in shared)) + return _relative_gap(total_l2_a, total_l2_b) + + +def _load_state_dict(path: str) -> dict[str, torch.Tensor]: + state = torch.load(path, map_location="cpu", weights_only=True) + if not isinstance(state, dict): + raise ValueError(f"Expected dict state_dict in {path}, got {type(state)}") + return {k: v.detach().float() for k, v in state.items()} + + +def _compare_state_dicts( + state_a: dict[str, torch.Tensor], + state_b: dict[str, torch.Tensor], +) -> tuple[list[ParamDiff], list[str], list[str]]: + names_a = set(state_a.keys()) + names_b = set(state_b.keys()) + shared = sorted(names_a & names_b) + only_a = sorted(names_a - names_b) + only_b = sorted(names_b - names_a) + + diffs: list[ParamDiff] = [] + for name in shared: + a = state_a[name] + b = state_b[name] + shape_ok = a.shape == b.shape + if not shape_ok: + diffs.append( + ParamDiff( + name=name, + shape_match=False, + max_diff=float("inf"), + mean_diff=float("inf"), + l2_diff=float("inf"), + 
l2_a=float("inf"), + l2_b=float("inf"), + norm_gap=float("inf"), + delta_gap=float("inf"), + ) + ) + continue + delta = a - b + l2_a = float(a.norm().item()) + l2_b = float(b.norm().item()) + l2_delta = float(delta.norm().item()) + diffs.append( + ParamDiff( + name=name, + shape_match=True, + max_diff=float(delta.abs().max().item()), + mean_diff=float(delta.abs().mean().item()), + l2_diff=l2_delta, + l2_a=l2_a, + l2_b=l2_b, + norm_gap=_relative_gap(l2_a, l2_b), + delta_gap=l2_delta / max(l2_a, l2_b, 1e-12), + ) + ) + return diffs, only_a, only_b + + +def _discover_forward_summary_files(dump_dir: str) -> dict[int, str]: + out: dict[int, str] = {} + for p in Path(dump_dir).glob("forward.step*.summary.pt"): + m = _FORWARD_SUMMARY_RE.fullmatch(p.name) + if m is None: + continue + step = int(m.group(1)) + out[step] = str(p) + return out + + +def _load_forward_compare_payload(path: str) -> dict[str, Any]: + payload = torch.load(path, map_location="cpu") + if not isinstance(payload, dict): + raise ValueError(f"Expected dict payload in {path}, got {type(payload)}") + return payload + + +def _compare_forward_tensor( + a: torch.Tensor, + b: torch.Tensor, + *, + key: str, +) -> ForwardTensorGap: + ta = a.detach().float().cpu() + tb = b.detach().float().cpu() + if ta.shape != tb.shape: + return ForwardTensorGap( + key=key, + shape_match=False, + max_diff=float("inf"), + mean_diff=float("inf"), + l2_rel_gap=float("inf"), + ) + delta = (ta - tb).abs() + max_diff = float(delta.max().item()) if delta.numel() > 0 else 0.0 + mean_diff = float(delta.mean().item()) if delta.numel() > 0 else 0.0 + l2a = float(ta.norm().item()) + l2b = float(tb.norm().item()) + l2_rel_gap = _relative_gap(l2a, l2b) + return ForwardTensorGap( + key=key, + shape_match=True, + max_diff=max_diff, + mean_diff=mean_diff, + l2_rel_gap=l2_rel_gap, + ) + + +def _summarize_param_diffs(diffs: list[ParamDiff]) -> dict[str, float]: + if not diffs: + return { + "num_params": 0, + "global_max_diff": 0.0, + 
"global_mean_diff": 0.0, + "global_l2_diff": 0.0, + "global_norm_gap": 0.0, + "max_norm_gap": 0.0, + "global_delta_gap": 0.0, + "max_delta_gap": 0.0, + } + matched = [d for d in diffs if d.shape_match] + if not matched: + return { + "num_params": len(diffs), + "global_max_diff": float("inf"), + "global_mean_diff": float("inf"), + "global_l2_diff": float("inf"), + "global_norm_gap": float("inf"), + "max_norm_gap": float("inf"), + "global_delta_gap": float("inf"), + "max_delta_gap": float("inf"), + } + max_diff = max(d.max_diff for d in matched) + total_l2 = math.sqrt(sum(d.l2_diff**2 for d in matched)) + mean_diff = sum(d.mean_diff for d in matched) / len(matched) + total_l2_a = math.sqrt(sum(d.l2_a**2 for d in matched)) + total_l2_b = math.sqrt(sum(d.l2_b**2 for d in matched)) + global_norm_gap = _relative_gap(total_l2_a, total_l2_b) + max_norm_gap = max(d.norm_gap for d in matched) + global_delta_gap = total_l2 / max(total_l2_a, total_l2_b, 1e-12) + max_delta_gap = max(d.delta_gap for d in matched) + return { + "num_params": len(diffs), + "global_max_diff": float(max_diff), + "global_mean_diff": float(mean_diff), + "global_l2_diff": float(total_l2), + "global_norm_gap": float(global_norm_gap), + "max_norm_gap": float(max_norm_gap), + "global_delta_gap": float(global_delta_gap), + "max_delta_gap": float(max_delta_gap), + } + + +# ----------------------------------------------------------------------------- +# Reporting +# ----------------------------------------------------------------------------- + + +def _print_loss_report( + diffs: list[LossDiff], + atol: float, + rtol: float, +) -> bool: + print("\n=== Per-step loss comparison (strict) ===") + print(f"atol={atol:.3e} rtol={rtol:.3e}") + print( + f"{'step':>4} {'loss_a':>14} {'loss_b':>14} " + f"{'abs_gap':>12} {'rel_gap':>12} {'status':>8}" + ) + ok = True + for d in diffs: + status_raw = "OK" if d.aligned else "MISMATCH" + status = ( + _colorize(status_raw, _ANSI_GREEN, bold=True) + if d.aligned + else 
_colorize(status_raw, _ANSI_RED, bold=True) + ) + ok = ok and d.aligned + print( + f"{d.step:>4d} {d.loss_a:>14.6f} {d.loss_b:>14.6f} " + f"{d.abs_gap:>12.3e} {d.rel_gap:>12.3e} {status:>8}" + ) + overall = ( + _colorize("PASS", _ANSI_GREEN, bold=True) + if ok + else _colorize("FAIL", _ANSI_RED, bold=True) + ) + print(f"Loss alignment overall: {overall}") + return ok + + +def _print_param_report( + label: str, + diffs: list[ParamDiff], + only_a: list[str], + only_b: list[str], + top_k: int = 10, +) -> None: + summary = _summarize_param_diffs(diffs) + shared_count = len(diffs) + mismatch_count = len(only_a) + len(only_b) + print(f"\n=== {label} parameter comparison ===") + print( + f" key_coverage: shared={shared_count} only_a={len(only_a)} " + f"only_b={len(only_b)} total_union={shared_count + mismatch_count}" + ) + if only_a or only_b: + warn = ( + f"[warn] parameter key mismatch: only_a={only_a[:5]}{'...' if len(only_a) > 5 else ''} " + f"only_b={only_b[:5]}{'...' if len(only_b) > 5 else ''}" + ) + print(_colorize(warn, _ANSI_YELLOW, bold=True)) + print(f" global_norm_gap: {summary['global_norm_gap']}") + print(f" max_norm_gap: {summary['max_norm_gap']}") + print(f" global_delta_gap: {summary['global_delta_gap']}") + print(f" max_delta_gap: {summary['max_delta_gap']}") + worst = sorted(diffs, key=lambda d: d.max_diff, reverse=True)[:top_k] + print(f" top-{len(worst)} tensors by max_diff:") + for d in worst: + print( + f" {d.name[:80]:<80} max={d.max_diff:.3e} " + f"mean={d.mean_diff:.3e} norm_gap={d.norm_gap:.3e} " + f"delta_gap={d.delta_gap:.3e} " + f"shape_match={d.shape_match}" + ) + + +def _print_diff_file_report( + title: str, + gaps: list[DiffFileGap], + only_a: list[str], + only_b: list[str], + top_k: int = 10, +) -> None: + summary = _summarize_diff_gaps(gaps) + shared_count = len(gaps) + mismatch_count = len(only_a) + len(only_b) + print(f"\n=== {title} ===") + print( + f" key_coverage: shared={shared_count} only_a={len(only_a)} " + 
f"only_b={len(only_b)} total_union={shared_count + mismatch_count}" + ) + if only_a or only_b: + warn = ( + f"[warn] parameter key mismatch: only_a={only_a[:5]}{'...' if len(only_a) > 5 else ''} " + f"only_b={only_b[:5]}{'...' if len(only_b) > 5 else ''}" + ) + print(_colorize(warn, _ANSI_YELLOW, bold=True)) + for k, v in summary.items(): + print(f" {k}: {v}") + worst = sorted(gaps, key=lambda g: g.max_abs_gap, reverse=True)[:top_k] + print(f" top-{len(worst)} tensors by max_abs_gap:") + for g in worst: + print( + f" {g.name[:80]:<80} max_abs_gap={g.max_abs_gap:.3e} " + f"max_abs_rel_gap={g.max_abs_rel_gap:.3e} " + f"mean_abs_gap={g.mean_abs_gap:.3e} " + f"mean_abs_rel_gap={g.mean_abs_rel_gap:.3e} " + f"l2_gap={g.l2_gap:.3e} " + f"l2_rel_gap={g.l2_rel_gap:.3e} " + f"numel_match={g.numel_match}" + ) + + +def _print_diff_alignment_report( + summary: dict[str, float], + *, + rel_gap_tol: float, + title: str = "diff.pt alignment", +) -> bool: + l2_rel_max_value = float(summary.get("l2_rel_gap_max", float("inf"))) + global_l2_rel_value = float(summary.get("global_l2_rel_gap", float("inf"))) + l2_ok = l2_rel_max_value < rel_gap_tol + global_l2_ok = global_l2_rel_value < rel_gap_tol + aligned_ok = l2_ok and global_l2_ok + + print(f"\n=== {title} ===") + print( + f"max(l2_rel_gap) < {rel_gap_tol:.3e}: " + f"{l2_rel_max_value:.3e} " + f"{_colorize('PASS', _ANSI_GREEN, bold=True) if l2_ok else _colorize('FAIL', _ANSI_RED, bold=True)}" + ) + print( + f"global_l2_rel_gap < {rel_gap_tol:.3e}: " + f"{global_l2_rel_value:.3e} " + f"{_colorize('PASS', _ANSI_GREEN, bold=True) if global_l2_ok else _colorize('FAIL', _ANSI_RED, bold=True)}" + ) + print( + f"{title} overall: " + + ( + _colorize("PASS", _ANSI_GREEN, bold=True) + if aligned_ok + else _colorize("FAIL", _ANSI_RED, bold=True) + ) + ) + return aligned_ok + + +def _print_unmatched_tensor_report( + label: str, + diffs: list[ParamDiff], + side: str, + top_k: int = 10, +) -> None: + print(f"\n=== {label} unmatched full-delta 
tensors ({side} vs zeros) ===") + if not diffs: + print(" none") + return + summary = _summarize_param_diffs(diffs) + print(f" global_norm_gap: {summary['global_norm_gap']}") + print(f" max_norm_gap: {summary['max_norm_gap']}") + print(f" global_delta_gap: {summary['global_delta_gap']}") + print(f" max_delta_gap: {summary['max_delta_gap']}") + worst = sorted(diffs, key=lambda d: d.max_diff, reverse=True)[:top_k] + print(f" top-{len(worst)} tensors by max_diff:") + for d in worst: + print( + f" {d.name[:80]:<80} max={d.max_diff:.3e} " + f"mean={d.mean_diff:.3e} norm_gap={d.norm_gap:.3e} " + f"delta_gap={d.delta_gap:.3e}" + ) + + +def _print_full_delta_alignment_report( + param_diffs: list[ParamDiff], + *, + rel_gap_tol: float, + heading: str = "full-delta fp32 alignment", +) -> bool: + summary = _summarize_param_diffs(param_diffs) + global_value = float(summary.get("global_delta_gap", float("inf"))) + max_value = float(summary.get("max_delta_gap", float("inf"))) + global_ok = global_value < rel_gap_tol + max_ok = max_value < rel_gap_tol + ok = global_ok and max_ok + print(f"\n=== {heading} ===") + print( + f"global_delta_gap < {rel_gap_tol:.3e}: " + f"{global_value:.3e} " + f"{_colorize('PASS', _ANSI_GREEN, bold=True) if global_ok else _colorize('FAIL', _ANSI_RED, bold=True)}" + ) + print( + f"max_delta_gap < {rel_gap_tol:.3e}: " + f"{max_value:.3e} " + f"{_colorize('PASS', _ANSI_GREEN, bold=True) if max_ok else _colorize('FAIL', _ANSI_RED, bold=True)}" + ) + return ok + + +def _print_name_gap_report( + *, + norm_gaps: dict[str, float], + delta_gaps: dict[str, float], + l2_delta_a: dict[str, float] | None = None, + l2_delta_b: dict[str, float] | None = None, + top_k: int = 10, + heading: str = "Top delta_gap parameters (shared names)", + l2_note: str | None = None, +) -> None: + """Print top-k shared parameter gaps sorted by delta_gap. 
+ + When ``l2_delta_a`` / ``l2_delta_b`` are provided, each column is the L2 + norm of the saved fp32 update tensor ``Δθ = current - initial`` for that + dump (so you can see whether a parameter actually moved vs. stayed at 0). + """ + shared_names = sorted(set(norm_gaps.keys()) & set(delta_gaps.keys())) + ranked = sorted( + shared_names, + key=lambda name: delta_gaps.get(name, float("-inf")), + reverse=True, + )[:top_k] + print(f"\n=== {heading} ===") + if l2_delta_a is not None and l2_delta_b is not None: + print( + l2_note + or "(||A||, ||B|| are L2 norms of Δθ from each dump's delta_tensors_fp32.)" + ) + print( + f"{'name':<72} {'||A||':>12} {'||B||':>12} " + f"{'norm_gap':>12} {'delta_gap':>12}" + ) + for name in ranked: + norm_value = norm_gaps.get(name, float("inf")) + delta_value = delta_gaps.get(name, float("inf")) + na = l2_delta_a.get(name, float("nan")) + nb = l2_delta_b.get(name, float("nan")) + print( + f"{name[:72]:<72} {na:>12.3e} {nb:>12.3e} " + f"{norm_value:>12.3e} {delta_value:>12.3e}" + ) + else: + print(f"{'name':<80} {'norm_gap':>12} {'delta_gap':>12}") + for name in ranked: + norm_value = norm_gaps.get(name, float("inf")) + delta_value = delta_gaps.get(name, float("inf")) + print(f"{name[:80]:<80} {norm_value:>12.3e} {delta_value:>12.3e}") + + +def _print_forward_compare_report( + dump_a: str, + dump_b: str, + *, + rel_gap_tol: float, + max_abs_tol: float, +) -> bool: + summary_a = _discover_forward_summary_files(dump_a) + summary_b = _discover_forward_summary_files(dump_b) + if not summary_a and not summary_b: + print("\n[info] no forward.step*.summary.pt in either dump.") + return True + if not summary_a or not summary_b: + print( + "\n[warn] forward summaries present on only one side " + f"(a={bool(summary_a)}, b={bool(summary_b)}); skipping forward comparison." 
+ ) + return False + + steps_a = set(summary_a.keys()) + steps_b = set(summary_b.keys()) + shared_steps = sorted(steps_a & steps_b) + only_a_steps = sorted(steps_a - steps_b) + only_b_steps = sorted(steps_b - steps_a) + + print("\n=== valid-token forward output comparison ===") + print( + f" step_coverage: shared={len(shared_steps)} only_a={len(only_a_steps)} " + f"only_b={len(only_b_steps)}" + ) + if only_a_steps or only_b_steps: + print( + _colorize( + f"[warn] forward summary mismatch: only_a={only_a_steps[:6]} " + f"only_b={only_b_steps[:6]}", + _ANSI_YELLOW, + bold=True, + ) + ) + + ok = not only_a_steps and not only_b_steps + tensor_gaps: list[ForwardTensorGap] = [] + global_max_gap = 0.0 + global_mean_gap = 0.0 + for step in shared_steps: + pa = _load_forward_compare_payload(summary_a[step]) + pb = _load_forward_compare_payload(summary_b[step]) + meta_ok = ( + pa.get("output_kind") == pb.get("output_kind") + and int(pa.get("global_input_tokens", -1)) + == int(pb.get("global_input_tokens", -2)) + and int(pa.get("global_valid_output_numel", -1)) + == int(pb.get("global_valid_output_numel", -2)) + and int(pa.get("global_mask_mismatch_ranks", -1)) + == int(pb.get("global_mask_mismatch_ranks", -2)) + ) + ok = ok and meta_ok + ranks_a = { + int(x["rank"]): x for x in pa.get("per_rank", []) if isinstance(x, dict) + } + ranks_b = { + int(x["rank"]): x for x in pb.get("per_rank", []) if isinstance(x, dict) + } + rank_shared = sorted(set(ranks_a) & set(ranks_b)) + ok = ok and set(ranks_a) == set(ranks_b) + for rank in rank_shared: + ra = ranks_a[rank] + rb = ranks_b[rank] + pos_a = ra.get("valid_token_positions") + pos_b = rb.get("valid_token_positions") + positions_ok = ( + isinstance(pos_a, torch.Tensor) + and isinstance(pos_b, torch.Tensor) + and torch.equal(pos_a, pos_b) + ) + ok = ok and positions_ok + ta = ra.get("valid_forward_fp32") + tb = rb.get("valid_forward_fp32") + if isinstance(ta, torch.Tensor) and isinstance(tb, torch.Tensor): + tg = 
_compare_forward_tensor( + ta, tb, key=f"step={step},rank={rank},valid_forward" + ) + tensor_gaps.append(tg) + global_max_gap = max(global_max_gap, tg.max_diff) + global_mean_gap = max(global_mean_gap, tg.mean_diff) + ok = ( + ok + and tg.shape_match + and tg.max_diff <= max_abs_tol + and tg.l2_rel_gap < rel_gap_tol + ) + else: + ok = False + print( + _colorize( + f"[warn] missing valid_forward_fp32 for step={step}, rank={rank}", + _ANSI_YELLOW, + bold=True, + ) + ) + + print( + f" valid_forward_max_diff_max={global_max_gap:.3e}, " + f"valid_forward_mean_diff_max={global_mean_gap:.3e} " + f"(tol abs={max_abs_tol:.3e}, rel={rel_gap_tol:.3e})" + ) + if tensor_gaps: + worst = sorted(tensor_gaps, key=lambda g: g.max_diff, reverse=True)[:10] + print(f" top-{len(worst)} forward tensor gaps by max_diff:") + for g in worst: + print( + f" {g.key:<52} max={g.max_diff:.3e} " + f"mean={g.mean_diff:.3e} l2_rel={g.l2_rel_gap:.3e} " + f"shape_match={g.shape_match}" + ) + + print( + " forward_compare overall: " + + ( + _colorize("PASS", _ANSI_GREEN, bold=True) + if ok + else _colorize("FAIL", _ANSI_RED, bold=True) + ) + ) + return ok + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- + + +def run_comparison( + dump_a: str, + dump_b: str, + *, + loss_atol: float, + loss_rtol: float, + diff_rel_gap_tol: float, + full_delta_rel_gap_tol: float, + compare_initial: bool, +) -> bool: + print(f"[compare] dump_a={dump_a}") + print(f"[compare] dump_b={dump_b}") + + stats_a = _load_stats(dump_a) + stats_b = _load_stats(dump_b) + loss_diffs = _compare_losses(stats_a, stats_b, atol=loss_atol, rtol=loss_rtol) + loss_ok = _print_loss_report(loss_diffs, loss_atol, loss_rtol) + diff_ok = True + full_delta_ok = True + grad_ok = True + grad_full_ok = True + forward_ok = _print_forward_compare_report( + dump_a, + dump_b, + rel_gap_tol=diff_rel_gap_tol, + 
max_abs_tol=full_delta_rel_gap_tol, + ) + + diff_a = Path(dump_a) / "diff.pt" + diff_b = Path(dump_b) / "diff.pt" + if diff_a.exists() and diff_b.exists(): + sig_a = _load_diff_signatures(str(diff_a)) + sig_b = _load_diff_signatures(str(diff_b)) + gaps, only_a, only_b = _compare_diff_signatures(sig_a, sig_b) + if only_a or only_b: + print("\n=== diff.pt param keys only in one dump ===") + print(f" only_in_a: {only_a}") + print(f" only_in_b: {only_b}") + diff_summary = _summarize_diff_gaps(gaps) + diff_summary["global_l2_rel_gap"] = _global_l2_update_relative_gap(sig_a, sig_b) + diff_ok = ( + float(diff_summary.get("l2_rel_gap_max", float("inf"))) < diff_rel_gap_tol + and float(diff_summary.get("global_l2_rel_gap", float("inf"))) + < diff_rel_gap_tol + ) + + full_a = _load_full_delta_tensors(str(diff_a)) + full_b = _load_full_delta_tensors(str(diff_b)) + if full_a and full_b: + shared_diffs, only_a_diffs, only_b_diffs = _compare_full_deltas( + full_a, full_b + ) + all_full_diffs = [*shared_diffs, *only_a_diffs, *only_b_diffs] + full_summary = _summarize_param_diffs(all_full_diffs) + full_delta_ok = ( + float(full_summary.get("global_delta_gap", float("inf"))) + < full_delta_rel_gap_tol + and float(full_summary.get("max_delta_gap", float("inf"))) + < full_delta_rel_gap_tol + ) + + norm_gaps: dict[str, float] = {} + for g in gaps: + norm_gaps[g.name] = float(g.l2_rel_gap) + + delta_gaps: dict[str, float] = {} + l2_delta_a: dict[str, float] = {} + l2_delta_b: dict[str, float] = {} + for d in shared_diffs: + delta_gaps[d.name] = float(d.delta_gap) + l2_delta_a[d.name] = float(d.l2_a) + l2_delta_b[d.name] = float(d.l2_b) + + _print_name_gap_report( + norm_gaps=norm_gaps, + delta_gaps=delta_gaps, + l2_delta_a=l2_delta_a, + l2_delta_b=l2_delta_b, + top_k=10, + ) + else: + print( + "\n[info] full fp32 delta tensors not found in both diff.pt " + "(enable test_config.save_full_diff_tensors_fp32=true in both runs)." 
+ ) + else: + final_a = Path(dump_a) / "params.pt" + final_b = Path(dump_b) / "params.pt" + if final_a.exists() and final_b.exists(): + print( + "\n[info] diff.pt not found in both dumps; " + "falling back to legacy params.pt comparison." + ) + state_a = _load_state_dict(str(final_a)) + state_b = _load_state_dict(str(final_b)) + diffs, only_a, only_b = _compare_state_dicts(state_a, state_b) + _print_param_report("Final", diffs, only_a, only_b) + else: + print( + f"\n[info] skipping final param comparison " + f"(diff exists: a={diff_a.exists()}, b={diff_b.exists()}; " + f"params exists: a={final_a.exists()}, b={final_b.exists()})" + ) + + if compare_initial: + init_a = Path(dump_a) / "params_initial.pt" + init_b = Path(dump_b) / "params_initial.pt" + if init_a.exists() and init_b.exists(): + state_a = _load_state_dict(str(init_a)) + state_b = _load_state_dict(str(init_b)) + diffs, only_a, only_b = _compare_state_dicts(state_a, state_b) + _print_param_report("Initial", diffs, only_a, only_b) + + lg_a = Path(dump_a) / "last_grads.pt" + lg_b = Path(dump_b) / "last_grads.pt" + if lg_a.is_file() and lg_b.is_file(): + print("\n=== last_grads.pt ===") + _print_last_grads_requires_grad_one_line(str(lg_a), str(lg_b)) + + gsig_a = _load_grad_signatures(str(lg_a)) + gsig_b = _load_grad_signatures(str(lg_b)) + ggaps, g_only_a, g_only_b = _compare_grad_signatures(gsig_a, gsig_b) + if g_only_a or g_only_b: + print(" param keys only in one dump:") + print(f" only_a={g_only_a[:8]}{'...' if len(g_only_a) > 8 else ''}") + print(f" only_b={g_only_b[:8]}{'...' 
if len(g_only_b) > 8 else ''}") + grad_summary = _summarize_diff_gaps(ggaps) + grad_summary["global_l2_rel_gap"] = _global_grad_l2_relative_gap(gsig_a, gsig_b) + grad_ok = ( + float(grad_summary.get("l2_rel_gap_max", float("inf"))) < diff_rel_gap_tol + and float(grad_summary.get("global_l2_rel_gap", float("inf"))) + < diff_rel_gap_tol + ) + + full_ga = _load_full_grad_tensors(str(lg_a)) + full_gb = _load_full_grad_tensors(str(lg_b)) + if full_ga and full_gb: + shared_g, only_a_g, only_b_g = _compare_full_deltas(full_ga, full_gb) + all_g = [*shared_g, *only_a_g, *only_b_g] + full_g_summary = _summarize_param_diffs(all_g) + grad_full_ok = ( + float(full_g_summary.get("global_delta_gap", float("inf"))) + < full_delta_rel_gap_tol + and float(full_g_summary.get("max_delta_gap", float("inf"))) + < full_delta_rel_gap_tol + ) + + norm_gaps_g: dict[str, float] = {} + for g in ggaps: + norm_gaps_g[g.name] = float(g.l2_rel_gap) + delta_gaps_g: dict[str, float] = {} + l2_ga: dict[str, float] = {} + l2_gb: dict[str, float] = {} + for d in shared_g: + delta_gaps_g[d.name] = float(d.delta_gap) + l2_ga[d.name] = float(d.l2_a) + l2_gb[d.name] = float(d.l2_b) + _print_name_gap_report( + norm_gaps=norm_gaps_g, + delta_gaps=delta_gaps_g, + l2_delta_a=l2_ga, + l2_delta_b=l2_gb, + top_k=10, + heading="Top last_grads gaps (shared names)", + l2_note=( + "(||A||, ||B|| are L2 norms of grad_tensors_fp32; " + "norm_gap from signature l2_rel; zero ||·||₂ is an all-zero grad.)" + ), + ) + print( + f" alignment: signature={'PASS' if grad_ok else 'FAIL'} " + f"full_tensor={'PASS' if grad_full_ok else 'FAIL'} " + f"(tol signature={diff_rel_gap_tol} elementwise={full_delta_rel_gap_tol})" + ) + elif full_ga or full_gb: + grad_full_ok = False + print( + "\n[warn] last_grads.pt grad_tensors_fp32 present on only one side " + f"(a={bool(full_ga)}, b={bool(full_gb)}); skipping full grad tensor compare." 
+ ) + print( + f" alignment: signature={'PASS' if grad_ok else 'FAIL'} " + "full_tensor=FAIL (asymmetric dumps)" + ) + else: + print( + "\n [info] no grad_tensors_fp32 in both dumps " + "(enable test_config.save_full_last_grad_tensors_fp32=true); " + f"signature_only={'PASS' if grad_ok else 'FAIL'}." + ) + elif lg_a.is_file() or lg_b.is_file(): + print( + "\n[warn] last_grads.pt present on only one side " + f"(a={lg_a.is_file()}, b={lg_b.is_file()}); skipping grad comparison." + ) + + return ( + loss_ok + and diff_ok + and full_delta_ok + and grad_ok + and grad_full_ok + and forward_ok + ) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description="Compare two ArchonEngine training-test dump_dirs." + ) + parser.add_argument("--dump-a", type=str, required=True, help="First dump dir.") + parser.add_argument("--dump-b", type=str, required=True, help="Second dump dir.") + parser.add_argument( + "--loss-atol", + type=float, + default=1e-6, + help="Absolute tolerance for per-step loss alignment (default 1e-6).", + ) + parser.add_argument( + "--loss-rtol", + type=float, + default=1e-3, + help="Relative tolerance for per-step loss alignment (default 1e-3).", + ) + parser.add_argument( + "--diff-rel-gap-tol", + type=float, + default=1e-2, + help=( + "Relative-gap threshold for diff.pt (max(l2_rel_gap), global_l2_rel_gap) " + "and, when both dumps include last_grads.pt, the same checks on " + "gradient signature gaps. Default 1e-2." + ), + ) + parser.add_argument( + "--full-delta-rel-gap-tol", + type=float, + default=1e-2, + help=( + "Threshold for full tensor alignment: both global_delta_gap and " + "max_delta_gap must be < tol (diff.pt delta_tensors_fp32 when present; " + "last_grads.pt grad_tensors_fp32 when present). Default 1e-2." 
+ ), + parser.add_argument( + "--compare-initial", + action="store_true", + help="Also compare legacy params_initial.pt if present in both dumps.", + ) + args = parser.parse_args(argv) + + ok = run_comparison( + dump_a=args.dump_a, + dump_b=args.dump_b, + loss_atol=args.loss_atol, + loss_rtol=args.loss_rtol, + diff_rel_gap_tol=args.diff_rel_gap_tol, + full_delta_rel_gap_tol=args.full_delta_rel_gap_tol, + compare_initial=args.compare_initial, + ) + return 0 if ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/experimental/archon/torchrun/run_archon_training_test.py b/tests/experimental/archon/torchrun/run_archon_training_test.py new file mode 100644 index 0000000000..c9a3c1ca91 --- /dev/null +++ b/tests/experimental/archon/torchrun/run_archon_training_test.py @@ -0,0 +1,1161 @@ +"""Training-side smoke/benchmark runner for ``ArchonLMEngine``. + +This script runs ``N`` training steps of ``ArchonLMEngine.train_batch`` under +``torch.distributed`` and records global per-step loss, elapsed time and peak +GPU memory. At the end it optionally dumps ``diff.pt`` (parameter update +statistics) so two runs can be compared offline via +:mod:`compare_training_dumps`. + +Launch with torchrun:: + + torchrun --nproc_per_node=$WORLD_SIZE \\ + tests/experimental/archon/torchrun/run_archon_training_test.py \\ + --config tests/experimental/archon/torchrun/archon_training_test.yaml \\ + test_config.step=4 \\ + test_config.data_dir=/path/to/data + +Primary outputs land under ``<dump_dir>/``: + +- ``stats.jsonl`` -- one global JSON record per step (rank-aggregated) +- ``diff.pt`` -- per-parameter update stats (saved on rank 0 only) +- ``last_grads.pt`` -- optional per-parameter gradient stats and, by default, full + fp32 gradient tensors (``grad_tensors_fp32``) after the final step + (``test_config.dump_last_grads=true``; see :class:`TestOnlyConfig`).
+ Each ``params`` entry includes ``requires_grad``; ``requires_grad_meta`` summarizes + all ``named_parameters()`` without ``full_tensor``. + +The runner is intentionally narrow: inputs are assumed to be ``list[Tensor]`` +(1-D token ids) per ``.pt`` file, and the loss function is hard-wired to a +typical ``grpo_loss_fn`` setup (or ``ppo_critic_loss_fn`` when +``test_config.is_critic`` is true). + +``test_config.is_critic`` is applied to ``TrainEngineConfig.is_critic`` before +constructing the engine so the model (e.g. ``score`` vs ``output`` head), HF +load path, and loss stay aligned. Checkpoints are always read from YAML +``actor.path`` (``cfg.engine.path``). For critic models, whether the value +head is filled from the HF ``lm_head`` tensor depends on the Archon +``state_dict_adapter`` for that architecture (see engine load warnings for +missing / unexpected keys). +""" + +from __future__ import annotations + +import dataclasses +import functools +import glob +import json +import math +import os +import sys +import time +import types +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist + +# Make repo root importable when invoked via torchrun from any cwd. 
+_THIS_FILE = Path(__file__).resolve() +_REPO_ROOT = _THIS_FILE.parents[4] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from tests.experimental.archon.torchrun.training_test_config import ( # noqa: E402 + ArchonTrainingTestConfig, + ensure_dump_dir, + load_training_test_config, +) + +from areal.api.io_struct import FinetuneSpec, SaveLoadMeta # noqa: E402 +from areal.experimental.archon.utils import strip_wrapper_prefixes # noqa: E402 +from areal.experimental.engine.archon_engine import ArchonLMEngine # noqa: E402 +from areal.infra.platforms import current_platform # noqa: E402 +from areal.trainer.ppo.actor import grpo_loss_fn # noqa: E402 +from areal.trainer.ppo.critic import ppo_loss_fn as ppo_critic_loss_fn # noqa: E402 +from areal.utils.data import concat_batch # noqa: E402 +from areal.utils.logging import getLogger # noqa: E402 +from areal.utils.network import find_free_ports # noqa: E402 + +# Fixed prompt ratio for synthetic loss mask construction. +_PROMPT_RATIO = 0.3 + +# ----------------------------------------------------------------------------- +# Distributed setup +# ----------------------------------------------------------------------------- + + +def _setup_distributed_environment() -> tuple[int, int]: + """Initialize the global process group using torchrun env vars.""" + if dist.is_initialized(): + return dist.get_rank(), dist.get_world_size() + + os.environ.setdefault("RANK", "0") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", str(find_free_ports(1)[0])) + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + dist.init_process_group( + backend="nccl", + init_method=(f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"), + world_size=world_size, + rank=rank, + ) + current_platform.set_device(int(os.environ["LOCAL_RANK"])) + return rank, world_size + + 
+# ----------------------------------------------------------------------------- +# Data loading / trajectory construction +# ----------------------------------------------------------------------------- + + +def _list_step_files(data_dir: str) -> list[str]: + """Sort .pt files in ``data_dir`` lexicographically (ascii dict order).""" + if not os.path.isdir(data_dir): + raise FileNotFoundError(f"data_dir does not exist or is not a dir: {data_dir}") + files = sorted(glob.glob(os.path.join(data_dir, "*.pt"))) + if not files: + raise FileNotFoundError(f"No .pt files under {data_dir}") + return files + + +def _load_sequences(pt_path: str) -> list[torch.Tensor]: + seqs = torch.load(pt_path, map_location="cpu", weights_only=True) + if not isinstance(seqs, list) or not seqs: + raise ValueError( + f"Expected non-empty list[Tensor] in {pt_path}, got {type(seqs)}" + ) + for i, s in enumerate(seqs): + if not isinstance(s, torch.Tensor) or s.ndim != 1: + raise ValueError( + f"Entry {i} of {pt_path} is not a 1-D tensor: " + f"type={type(s)}, ndim={getattr(s, 'ndim', None)}" + ) + return seqs + + +def _synthetic_advantages(seq_len: int, global_idx: int) -> torch.Tensor: + """Deterministic per-token advantages for tests (CPU float32). + + Depends on ``seq_len`` and ``global_idx`` so that changing truncation / + ``max_tokens_per_mb`` or which sequence is loaded changes targets in a + structured, reproducible way (unlike i.i.d. Gaussian noise). + """ + if seq_len <= 0: + raise ValueError("seq_len must be positive.") + t = torch.arange(seq_len, dtype=torch.float32) + denom = max(float(seq_len), 1.0) + phase = 2.0 * math.pi * (t + 0.5) / denom + seq_phase = 2.0 * math.pi * float(global_idx % 997) / 997.0 + return (torch.sin(phase + seq_phase) * 0.5).unsqueeze(0) + + +def _synthetic_old_values(seq_len: int, global_idx: int) -> torch.Tensor: + """Deterministic per-token old values (``input_data['values']``) for tests. + + Replaces i.i.d. 
``randn`` so ``returns = values + advantages`` is a smooth + function of position and ``global_idx`` only, making critic targets easier + to reason about and reproduce across DTA / batching. + """ + if seq_len <= 0: + raise ValueError("seq_len must be positive.") + t = torch.arange(seq_len, dtype=torch.float32) + denom = max(float(seq_len), 1.0) + # Different phase from :func:`_synthetic_advantages` so the sum is not redundant. + phase = 2.0 * math.pi * (1.5 * t + 0.25) / denom + seq_phase = 2.0 * math.pi * float((global_idx * 2 + 1) % 991) / 991.0 + return (torch.cos(phase + seq_phase) * 0.4).unsqueeze(0) + + +def _build_trajectory( + input_ids: torch.Tensor, + global_idx: int, + base_seed: int, + max_tokens: int, + device: torch.device, +) -> dict[str, Any]: + """Wrap one 1-D token sequence as a GRPO-ready trajectory dict. + + The per-sequence RNG seed is derived from ``global_idx`` for logprobs. + ``advantages`` and ``values`` (hence ``returns``) are **non-random** (see + :func:`_synthetic_advantages` and :func:`_synthetic_old_values`) so + critic/actor tests behave reproducibly when sequence length or batch + composition changes. 
+ """ + assert input_ids.ndim == 1 + seq_len = int(min(int(input_ids.numel()), int(max_tokens))) + if seq_len <= 0: + raise ValueError(f"Sequence at idx {global_idx} has non-positive length.") + + ids = input_ids[:seq_len].long().unsqueeze(0).contiguous() + attention_mask = torch.ones(1, seq_len, dtype=torch.long) + loss_mask = torch.zeros(1, seq_len) + prompt_len = max(1, int(seq_len * _PROMPT_RATIO)) + loss_mask[:, prompt_len:] = 1.0 + + gen = torch.Generator(device="cpu").manual_seed(int(base_seed) + int(global_idx)) + logprobs = torch.randn(1, seq_len, generator=gen) * 0.5 - 2.0 + old_logprobs = logprobs.clone() + advantages = _synthetic_advantages(seq_len, global_idx) + rewards = torch.randint(0, 2, (1,), generator=gen).float() + values = _synthetic_old_values(seq_len, global_idx) + returns = values + advantages + + traj = { + "input_ids": ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "logprobs": logprobs, + "old_logprobs": old_logprobs, + "advantages": advantages, + "rewards": rewards, + "values": values, + "returns": returns, + "prox_logp": old_logprobs.clone(), + } + return { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in traj.items() + } + + +def _build_local_trajectories( + seqs: list[torch.Tensor], + dp_rank: int, + dp_world_size: int, + base_seed: int, + max_tokens: int, + device: torch.device, +) -> list[dict[str, Any]]: + """Each rank owns a disjoint stride of the sequence list. + + Per-rank sequence counts may differ (e.g. when ``len(seqs)`` is not a + multiple of ``dp_world_size``). No all-gather / load-balancing redistribution. + """ + if len(seqs) < dp_world_size: + raise ValueError( + f"Need at least dp_world_size={dp_world_size} sequences, got {len(seqs)}." 
+ ) + out: list[dict[str, Any]] = [] + for global_i in range(dp_rank, len(seqs), dp_world_size): + out.append( + _build_trajectory( + input_ids=seqs[global_i], + global_idx=global_i, + base_seed=base_seed, + max_tokens=max_tokens, + device=device, + ) + ) + return out + + +# ----------------------------------------------------------------------------- +# Loss function / engine patching +# ----------------------------------------------------------------------------- + + +# Reasonable defaults mirroring ``tests/experimental/archon/test_grpo.py``. +_GRPO_KW: dict[str, Any] = dict( + eps_clip=0.2, + eps_clip_higher=None, + c_clip=None, + importance_sampling_level="token", + current_version=1, + prox_logp_method="recompute", + use_sapo_loss=False, + use_decoupled_loss=False, +) + + +def _loss_weight_fn(input_data: dict[str, Any]) -> torch.Tensor: + mask = input_data["loss_mask"] + return mask.count_nonzero() + + +def _make_loss_fn(cfg: ArchonTrainingTestConfig): + """Build test loss with optional entropy regularization.""" + if cfg.test_config.is_critic: + return functools.partial(ppo_critic_loss_fn, eps_clip=3.0) + + base_loss_fn = functools.partial(grpo_loss_fn, **_GRPO_KW) + entropy_coef = float(cfg.test_config.entropy_coef) + entropy_mode = str(cfg.test_config.entropy_mode) + if entropy_coef <= 0: + return base_loss_fn + + def _loss_fn(logprobs, entropy, input_data, **kwargs): + base_loss = base_loss_fn(logprobs, entropy, input_data, **kwargs) + loss_mask = input_data["loss_mask"].bool() + valid_entropy = entropy.float().masked_select(loss_mask) + if valid_entropy.numel() == 0: + entropy_term = base_loss * 0.0 + elif entropy_mode == "mean": + entropy_term = -valid_entropy.mean() + else: + entropy_term = -valid_entropy.sum() + return base_loss + entropy_coef * entropy_term + + return _loss_fn + + +def _patch_engine_for_test( + engine: ArchonLMEngine, + disable_optimizer: bool, + *, + dump_last_grads: bool = False, + num_training_steps: int = 0, + last_grad_holder: 
dict[str, Any] | None = None, + save_full_last_grad_tensors_fp32: bool = True, +) -> None: + """Inject optional optimizer no-ops onto the engine.""" + if not disable_optimizer: + return + + noop_step_counter = {"n": 0} + + def _noop_zero_grad(self): + for p in self._get_all_parameters(): + if p.grad is not None: + p.grad = None + + def _noop_step(self): + grad_norm = 0.0 + for p in self._get_all_parameters(): + if p.grad is not None: + grad_norm += float(p.grad.detach().float().norm().item()) ** 2 + grad_norm = grad_norm**0.5 + noop_step_counter["n"] += 1 + if ( + dump_last_grads + and last_grad_holder is not None + and noop_step_counter["n"] == int(num_training_steps) + ): + payload = _build_grad_snapshot_payload( + self, + save_full_grad_tensors_fp32=save_full_last_grad_tensors_fp32, + ) + if payload is not None: + last_grad_holder.clear() + last_grad_holder.update(payload) + _noop_zero_grad(self) + return { + "update_successful": 1.0, + "grad_norm": grad_norm, + "lr": 0.0, + } + + engine.optimizer_zero_grad = types.MethodType(_noop_zero_grad, engine) + engine.optimizer_step = types.MethodType(_noop_step, engine) + + +# ----------------------------------------------------------------------------- +# Engine lifecycle +# ----------------------------------------------------------------------------- + + +def _resolve_test_hf_export_dir(cfg: ArchonTrainingTestConfig) -> str: + """Return absolute export dir, or ``\"\"`` when ``save_hf_checkpoint_dir`` is None or blank.""" + raw = cfg.test_config.save_hf_checkpoint_dir + if raw is None: + return "" + s = str(raw).strip() + if not s: + return "" + return os.path.abspath(os.path.expanduser(s)) + + +def _create_engine(cfg: ArchonTrainingTestConfig) -> ArchonLMEngine: + """Construct + initialize an ``ArchonLMEngine`` from the test config.""" + parallel_strategy = cfg.parallel.to_parallel_strategy() + + engine_cfg = cfg.engine + critic = bool(cfg.test_config.is_critic) + if critic != bool(engine_cfg.is_critic) and
int(os.environ.get("RANK", "0")) == 0: + getLogger("[ArchonTrainingTest]").warning( + "test_config.is_critic=%s overrides actor.is_critic=%s for engine " + "(model head + checkpoint layout must match critic vs actor).", + critic, + engine_cfg.is_critic, + ) + replace_kw: dict[str, Any] = {"is_critic": critic} + if cfg.test_config.disable_optimizer: + # Skip optimizer creation entirely so no Adam state is allocated. + replace_kw["optimizer"] = None + engine_cfg = dataclasses.replace(engine_cfg, **replace_kw) + + engine = ArchonLMEngine(engine_cfg) + engine.create_process_group(parallel_strategy=parallel_strategy) + + ft_spec = FinetuneSpec( + total_train_epochs=1, + dataset_size=max(1, int(cfg.test_config.step)), + train_batch_size=1, + ) + engine.initialize(addr=None, ft_spec=ft_spec) + + hf_export = _resolve_test_hf_export_dir(cfg) + if hf_export: + meta = SaveLoadMeta( + path=hf_export, + weight_format="hf", + with_optim=False, + tokenizer=engine.tokenizer, + ) + engine.save(meta) + + return engine + + +def _destroy_engine(engine: ArchonLMEngine | None) -> None: + if engine is not None: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() + + +# ----------------------------------------------------------------------------- +# Parameter diff dump +# ----------------------------------------------------------------------------- + + +def _materialize_full_param(param: torch.Tensor) -> torch.Tensor: + """Return a full (unsharded) tensor for one parameter.""" + from torch.distributed.tensor import DTensor + + if isinstance(param, DTensor): + return param.full_tensor() + return param + + +def _to_dump_name_tensors( + engine: ArchonLMEngine, raw_name: str, tensor: torch.Tensor +) -> list[tuple[str, torch.Tensor]]: + """Convert one Archon parameter into dump-name/tensor pairs. + + Prefer HuggingFace keys when ``state_dict_adapter`` is available; otherwise + use wrapper-stripped Archon keys. 
def _build_grad_snapshot_payload(
    engine: ArchonLMEngine,
    *,
    save_full_grad_tensors_fp32: bool = True,
) -> dict[str, Any] | None:
    """Collect per-parameter gradient statistics (all ranks must call).

    Optionally embeds full CPU fp32 gradient tensors under ``grad_tensors_fp32``
    (same role as ``delta_tensors_fp32`` in ``diff.pt``).

    Returns a payload dict on rank 0 only; other ranks return ``None`` after
    participating in any required ``full_tensor()`` collectives.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    if dist.is_initialized():
        dist.barrier(group=engine.cpu_group)

    if rank == 0:
        # Rank 0 owns the real accumulators and (optionally) full tensors.
        per_param: dict[str, dict[str, float]] = {}
        full_grad_tensors_fp32: dict[str, torch.Tensor] | None = (
            {} if save_full_grad_tensors_fp32 else None
        )
        global_numel = 0.0
        global_abs_sum = 0.0
        global_l2_sq = 0.0
        global_max_abs = 0.0
    else:
        # Non-zero ranks only define placeholders; they still must run the
        # parameter loop below so DTensor gather collectives stay matched.
        per_param = {}
        full_grad_tensors_fp32 = None
        global_numel = 0.0
        global_abs_sum = 0.0
        global_l2_sq = 0.0
        global_max_abs = 0.0

    for raw_name, param in engine.model.named_parameters():
        grad = param.grad
        if grad is None:
            # Skipping is safe on all ranks: every rank sees the same set of
            # params with/without grads, so collective order stays aligned.
            continue
        full_g = _materialize_full_param(grad)
        if rank == 0:
            for dump_name, dump_tensor in _to_dump_name_tensors(
                engine, raw_name, full_g
            ):
                # Stats are computed on a CPU fp32 copy; l2 uses float64
                # accumulation for numerical stability.
                tensor_f = dump_tensor.detach().to(device="cpu", dtype=torch.float32)
                numel = float(tensor_f.numel())
                if numel <= 0:
                    continue
                abs_t = tensor_f.abs()
                abs_sum = float(abs_t.sum().item())
                l2_sq = float(tensor_f.double().pow(2).sum().item())
                max_abs = float(abs_t.max().item())
                l2 = math.sqrt(max(l2_sq, 0.0))
                if dump_name in per_param:
                    raise ValueError(
                        f"Duplicate grad dump key '{dump_name}' from raw param '{raw_name}'."
                    )
                per_param[dump_name] = {
                    "numel": numel,
                    "mean_abs": abs_sum / numel,
                    "max_abs": max_abs,
                    "l2": l2,
                    "requires_grad": bool(param.requires_grad),
                }
                global_numel += numel
                global_abs_sum += abs_sum
                global_l2_sq += l2_sq
                global_max_abs = max(global_max_abs, max_abs)
                if full_grad_tensors_fp32 is not None:
                    full_grad_tensors_fp32[dump_name] = tensor_f.clone()
        # Free the gathered copy before gathering the next parameter, keeping
        # extra GPU memory bounded by a single full gradient tensor.
        del full_g

    if rank == 0:
        g_l2 = math.sqrt(max(global_l2_sq, 0.0))
        payload: dict[str, Any] = {
            "schema_version": 2,
            "aggregation": "full_tensor_grad_one_param_peak",
            "params": per_param,
            "global": {
                "num_params": len(per_param),
                "numel": global_numel,
                # max(..., 1.0) guards the no-gradients case (division by 0).
                "mean_abs": global_abs_sum / max(global_numel, 1.0),
                "max_abs": global_max_abs,
                "l2": g_l2,
            },
        }
        if full_grad_tensors_fp32 is not None:
            payload["grad_tensors_fp32"] = full_grad_tensors_fp32

        # Record which named parameters are frozen, to help diagnose dumps
        # where some params have no gradient.
        rg_false: list[str] = []
        rg_true = 0
        for rn, p in engine.model.named_parameters():
            if p.requires_grad:
                rg_true += 1
            else:
                rg_false.append(rn)
        payload["requires_grad_meta"] = {
            "num_named_requires_grad_true": rg_true,
            "num_named_requires_grad_false": len(rg_false),
            # Capped at 128 names to keep the payload bounded.
            "named_requires_grad_false": rg_false[:128],
        }
        return payload
    return None
def _save_diff_snapshot(
    engine: ArchonLMEngine,
    initial_params: dict[str, torch.Tensor] | None,
    dump_dir: str,
    filename: str,
    save_full_diff_tensors_fp32: bool = False,
) -> str | None:
    """Save ``diff.pt`` with per-parameter update metrics (rank 0 only).

    Compares the engine's current full parameters against ``initial_params``
    (captured by ``_snapshot_initial_full_params``). Returns the written path
    on rank 0, ``None`` on other ranks.
    """
    out_path: str | None = None
    rank = dist.get_rank() if dist.is_initialized() else 0
    if dist.is_initialized():
        dist.barrier(group=engine.cpu_group)

    if rank == 0 and initial_params is None:
        raise RuntimeError("Missing initial params on rank 0 for diff snapshot.")

    # Every rank participates in full_tensor() to ensure collective safety.
    if rank == 0:
        assert initial_params is not None
        per_param: dict[str, dict[str, float]] = {}
        full_delta_tensors_fp32: dict[str, torch.Tensor] | None = (
            {} if save_full_diff_tensors_fp32 else None
        )
        global_numel = 0.0
        global_abs_sum = 0.0
        global_l2_sq = 0.0
        global_ref_l2_sq = 0.0
        global_max_abs = 0.0
    else:
        # NOTE: full_delta_tensors_fp32 is deliberately only bound on rank 0;
        # all uses below are inside rank-0-only branches.
        per_param = {}
        global_numel = 0.0
        global_abs_sum = 0.0
        global_l2_sq = 0.0
        global_ref_l2_sq = 0.0
        global_max_abs = 0.0

    for raw_name, param in engine.model.named_parameters():
        full = _materialize_full_param(param)
        if rank == 0:
            assert initial_params is not None
            for dump_name, dump_tensor in _to_dump_name_tensors(engine, raw_name, full):
                if dump_name not in initial_params:
                    raise KeyError(
                        f"Missing initial parameter for dump key '{dump_name}' "
                        f"(raw='{raw_name}')."
                    )
                initial = initial_params[dump_name]
                current = dump_tensor.detach().to(device="cpu", dtype=torch.float32)
                if current.shape != initial.shape:
                    raise ValueError(
                        f"Shape mismatch for '{dump_name}': current={tuple(current.shape)} "
                        f"vs initial={tuple(initial.shape)}"
                    )
                # Delta and stats in CPU fp32; l2 sums in float64 for stability.
                delta = current - initial
                abs_delta = delta.abs()
                numel = float(delta.numel())
                abs_sum = float(abs_delta.sum().item())
                l2_sq = float(delta.double().pow(2).sum().item())
                ref_l2_sq = float(initial.double().pow(2).sum().item())
                max_abs = float(abs_delta.max().item()) if delta.numel() > 0 else 0.0
                l2 = math.sqrt(max(l2_sq, 0.0))
                ref_l2 = math.sqrt(max(ref_l2_sq, 0.0))
                if dump_name in per_param:
                    raise ValueError(
                        f"Duplicate final dump key '{dump_name}' from raw param '{raw_name}'."
                    )
                per_param[dump_name] = {
                    "numel": numel,
                    "mean_abs_update": abs_sum / max(numel, 1.0),
                    "max_abs_update": max_abs,
                    "l2_update": l2,
                    # rel_l2_update is relative to the initial weight norm;
                    # 1e-12 guards zero-initialized parameters.
                    "rel_l2_update": l2 / max(ref_l2, 1e-12),
                }
                if full_delta_tensors_fp32 is not None:
                    full_delta_tensors_fp32[dump_name] = delta.clone()
                global_numel += numel
                global_abs_sum += abs_sum
                global_l2_sq += l2_sq
                global_ref_l2_sq += ref_l2_sq
                global_max_abs = max(global_max_abs, max_abs)
        # Drop the gathered copy before the next gather to bound peak memory.
        del full

    if rank == 0:
        payload = {
            "schema_version": 1,
            "aggregation": "full_tensor_one_param_peak",
            "params": per_param,
            "global": {
                "num_params": len(per_param),
                "numel": global_numel,
                "mean_abs_update": global_abs_sum / max(global_numel, 1.0),
                "max_abs_update": global_max_abs,
                "l2_update": math.sqrt(max(global_l2_sq, 0.0)),
                "rel_l2_update": math.sqrt(max(global_l2_sq, 0.0))
                / max(math.sqrt(max(global_ref_l2_sq, 0.0)), 1e-12),
            },
        }
        if full_delta_tensors_fp32 is not None:
            payload["delta_tensors_fp32"] = full_delta_tensors_fp32

        os.makedirs(dump_dir, exist_ok=True)
        out_path = os.path.join(dump_dir, filename)
        torch.save(payload, out_path)

    # Final barrier so no rank tears down while rank 0 is still writing.
    if dist.is_initialized():
        dist.barrier(group=engine.cpu_group)
    return out_path
def _dump_forward_outputs(
    *,
    engine: ArchonLMEngine,
    cfg: ArchonTrainingTestConfig,
    step_idx: int,
    step_file: str,
    device: torch.device,
    dump_dir: str,
) -> dict[str, Any]:
    """Run forward_batch and dump valid-token outputs for this step.

    Actor outputs are token logprobs; critic outputs are token values.

    Writes ``forward.step{step_idx}.summary.pt`` (rank 0 only) and returns a
    small summary dict of the globally reduced counters on every rank.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    batch = _build_step_batch(
        engine=engine,
        cfg=cfg,
        step_idx=step_idx,
        step_file=step_file,
        device=device,
    )
    out = engine.forward_batch(input_=batch)
    if not torch.is_tensor(out):
        raise TypeError(
            f"forward dump expects dict input and tensor output, got {type(out)}."
        )
    out_cpu = out.detach().to(device="cpu", dtype=torch.float32)
    valid_mask_cpu = batch["attention_mask"].detach().to(device="cpu").bool()
    local_input_tokens = int(valid_mask_cpu.sum().item())
    # Only extract per-token values when the output is shaped exactly like the
    # attention mask; otherwise record empty tensors and flag the mismatch.
    mask_matches_output = tuple(valid_mask_cpu.shape) == tuple(out_cpu.shape)
    if mask_matches_output:
        valid_positions = valid_mask_cpu.nonzero(as_tuple=False).to(torch.int32)
        valid_forward = out_cpu[valid_mask_cpu].contiguous()
    else:
        valid_positions = torch.empty((0, 2), dtype=torch.int32)
        valid_forward = torch.empty(0, dtype=torch.float32)
    # Reduced in float64; SUM over the 0/1 mismatch flag counts mismatched ranks.
    mismatch_t = torch.tensor(
        [0.0 if mask_matches_output else 1.0],
        dtype=torch.float64,
        device=device,
    )
    token_reduce_t = torch.tensor(
        [
            float(local_input_tokens),
            float(valid_forward.numel()),
        ],
        dtype=torch.float64,
        device=device,
    )
    if dist.is_initialized():
        dist.all_reduce(mismatch_t, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_reduce_t, op=dist.ReduceOp.SUM)
    # round() before int() avoids float64 representation jitter on large sums.
    global_mask_mismatch_ranks = int(round(float(mismatch_t.item())))
    global_input_tokens = int(round(float(token_reduce_t[0].item())))
    global_valid_output_numel = int(round(float(token_reduce_t[1].item())))

    local_summary = {
        "rank": int(rank),
        "local_input_tokens": local_input_tokens,
        "local_valid_output_numel": int(valid_forward.numel()),
        "local_padded_output_numel": int(out_cpu.numel()),
        "shape": list(out_cpu.shape),
        "local_valid_mask_matches_output": bool(mask_matches_output),
        "global_input_tokens": global_input_tokens,
        "global_valid_output_numel": global_valid_output_numel,
        "global_mask_mismatch_ranks": global_mask_mismatch_ranks,
        "valid_token_positions": valid_positions,
        "valid_forward_fp32": valid_forward,
    }
    # Gather every rank's summary (CPU tensors included) onto all ranks;
    # rank 0 then persists the combined payload.
    if dist.is_initialized():
        all_summaries: list[dict[str, Any] | None] = [
            None for _ in range(dist.get_world_size())
        ]
        dist.all_gather_object(all_summaries, local_summary)
    else:
        all_summaries = [local_summary]
    if rank == 0:
        summary_payload = {
            "schema_version": 1,
            "step": int(step_idx),
            "file": os.path.abspath(step_file),
            "tree_training_mode": str(engine.config.tree_training_mode),
            "is_critic": bool(engine.config.is_critic),
            "output_kind": "value" if engine.config.is_critic else "logprob",
            "global_input_tokens": global_input_tokens,
            "global_valid_output_numel": global_valid_output_numel,
            "global_mask_mismatch_ranks": global_mask_mismatch_ranks,
            "per_rank": all_summaries,
        }
        summary_path = os.path.join(dump_dir, f"forward.step{step_idx}.summary.pt")
        torch.save(summary_payload, summary_path)
    if dist.is_initialized():
        dist.barrier(group=engine.cpu_group)
    return {
        "step": int(step_idx),
        "file": os.path.abspath(step_file),
        "global_input_tokens": global_input_tokens,
        "global_valid_output_numel": global_valid_output_numel,
        "global_mask_mismatch_ranks": global_mask_mismatch_ranks,
    }
float(result.get("grad_norm", float("nan"))) + grad_norm_local_for_max = ( + grad_norm_local if math.isfinite(grad_norm_local) else float("-inf") + ) + grad_norm_local_for_min = ( + grad_norm_local if math.isfinite(grad_norm_local) else float("inf") + ) + lr_local = float(result.get("lr", 0.0)) + update_successful_local = float(result.get("update_successful", 0.0)) + + # train_batch(return_loss=True) returns each rank's contribution to the + # globally normalized objective. The correct global loss is the SUM across + # DP ranks (not a second token-weighted average). + loss_contrib_local = float(step_loss) if math.isfinite(step_loss) else 0.0 + loss_valid_local = 1.0 if math.isfinite(step_loss) else 0.0 + + reduce_sum = torch.tensor( + [ + loss_contrib_local, + loss_valid_local, + float(num_local_seqs), + float(num_local_tokens), + ], + dtype=torch.float64, + device=device, + ) + reduce_max = torch.tensor( + [ + float(elapsed_s), + float(peak_mem_mib), + float(grad_norm_local_for_max), + float(lr_local), + ], + dtype=torch.float64, + device=device, + ) + reduce_min = torch.tensor( + [float(update_successful_local), float(grad_norm_local_for_min)], + dtype=torch.float64, + device=device, + ) + + if dist.is_initialized(): + dist.all_reduce(reduce_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(reduce_max, op=dist.ReduceOp.MAX) + dist.all_reduce(reduce_min, op=dist.ReduceOp.MIN) + + global_loss_valid_count = int(round(float(reduce_sum[1].item()))) + global_loss = ( + float(reduce_sum[0].item()) if global_loss_valid_count > 0 else float("nan") + ) + global_grad_norm = float(reduce_max[2].item()) + if not math.isfinite(global_grad_norm): + global_grad_norm = float("nan") + global_grad_norm_min = float(reduce_min[1].item()) + if not math.isfinite(global_grad_norm_min): + global_grad_norm_min = float("nan") + + return { + "step": int(step_idx), + "file": os.path.abspath(step_file), + "world_size": int(dist.get_world_size()) if dist.is_initialized() else 1, + "dp_world_size": 
int(dp_world_size), + "num_global_sequences": int(round(float(reduce_sum[2].item()))), + "num_global_tokens": int(round(float(reduce_sum[3].item()))), + "elapsed_s_max": float(reduce_max[0].item()), + "peak_mem_mib_max": float(reduce_max[1].item()), + "loss": float(global_loss), + "loss_source": f"{loss_source}_global_dp_sum", + "grad_norm_max": float(global_grad_norm), + "grad_norm_min": float(global_grad_norm_min), + "update_successful": float(reduce_min[0].item()), + "lr_max": float(reduce_max[3].item()), + } + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> None: + cfg, config_path = load_training_test_config(argv) + + rank, world_size = _setup_distributed_environment() + device = torch.device(current_platform.device_type) + logger = getLogger(f"[ArchonTrainingTest Rank {rank}]") + + dump_dir = ensure_dump_dir(cfg, rank=rank) + stats_path = os.path.join(dump_dir, "stats.jsonl") + diff_path = os.path.join(dump_dir, "diff.pt") + last_grads_path = os.path.join(dump_dir, "last_grads.pt") + if rank == 0: + # Truncate any prior stats file. + open(stats_path, "w").close() + # Remove stale diff snapshot so failed runs never expose old results. 
+ if cfg.test_config.save_diff and os.path.exists(diff_path): + os.remove(diff_path) + if cfg.test_config.dump_last_grads and os.path.exists(last_grads_path): + os.remove(last_grads_path) + + if rank == 0: + logger.info( + "config=%s dump_dir=%s world_size=%s", + config_path, + dump_dir, + world_size, + ) + + step_files = _list_step_files(cfg.test_config.data_dir) + if rank == 0: + logger.info( + "Found %d .pt files in %s", + len(step_files), + cfg.test_config.data_dir, + ) + + engine: ArchonLMEngine | None = None + + try: + engine = _create_engine(cfg) + last_grad_holder: dict[str, Any] = {} + _patch_engine_for_test( + engine, + disable_optimizer=cfg.test_config.disable_optimizer, + dump_last_grads=cfg.test_config.dump_last_grads, + num_training_steps=int(cfg.test_config.step), + last_grad_holder=( + last_grad_holder + if ( + cfg.test_config.disable_optimizer + and cfg.test_config.dump_last_grads + ) + else None + ), + save_full_last_grad_tensors_fp32=( + cfg.test_config.save_full_last_grad_tensors_fp32 + ), + ) + if rank == 0 and cfg.test_config.save_params: + logger.warning( + "test_config.save_params is deprecated in low-memory mode and " + "ignored. 
Use diff.pt.", + ) + if rank == 0 and cfg.test_config.save_initial_params: + logger.warning( + "test_config.save_initial_params is ignored in low-memory mode.", + ) + + initial_params: dict[str, torch.Tensor] | None = None + if cfg.test_config.save_diff: + initial_params = _snapshot_initial_full_params(engine) + + loss_fn = _make_loss_fn(cfg) + num_steps = int(cfg.test_config.step) + for step_idx in range(num_steps): + file_idx = step_idx % len(step_files) + step_file = step_files[file_idx] + if rank == 0: + logger.info( + "Starting training step %d/%d (0-based index %d), data file=%s", + step_idx + 1, + num_steps, + step_idx, + os.path.abspath(step_file), + ) + + if cfg.test_config.dump_forward_compare: + forward_record = _dump_forward_outputs( + engine=engine, + cfg=cfg, + step_idx=step_idx, + step_file=step_file, + device=device, + dump_dir=dump_dir, + ) + if rank == 0: + logger.info( + "Forward output dumped: step=%d file=%s output_kind=%s " + "global_tokens=%d global_valid_outputs=%d mask_mismatch_ranks=%d", + forward_record["step"], + os.path.basename(forward_record["file"]), + "value" if engine.config.is_critic else "logprob", + forward_record["global_input_tokens"], + forward_record["global_valid_output_numel"], + forward_record["global_mask_mismatch_ranks"], + ) + + record = _run_single_step( + engine=engine, + cfg=cfg, + step_idx=step_idx, + step_file=step_file, + device=device, + loss_fn=loss_fn, + ) + + if rank == 0: + with open(stats_path, "a") as fp: + fp.write(json.dumps(record) + "\n") + logger.info( + "Step %03d done: file=%s loss=%.6f grad_norm(min/max)=%.4f/%.4f " + "elapsed(max)=%.2fs peak_mem(max)=%.1fMiB", + step_idx, + os.path.basename(step_file), + record["loss"], + record["grad_norm_min"], + record["grad_norm_max"], + record["elapsed_s_max"], + record["peak_mem_mib_max"], + ) + + if cfg.test_config.dump_last_grads: + if cfg.test_config.disable_optimizer: + grad_payload: dict[str, Any] = dict(last_grad_holder) + else: + built = 
_build_grad_snapshot_payload( + engine, + save_full_grad_tensors_fp32=( + cfg.test_config.save_full_last_grad_tensors_fp32 + ), + ) + grad_payload = built if built is not None else {} + if rank == 0: + if not grad_payload: + logger.warning( + "dump_last_grads: empty payload (no .grad tensors?)." + ) + else: + os.makedirs(dump_dir, exist_ok=True) + torch.save(grad_payload, last_grads_path) + gmeta = grad_payload.get("global", {}) + gtensors = grad_payload.get("grad_tensors_fp32") + n_full = len(gtensors) if isinstance(gtensors, dict) else 0 + rg_meta = grad_payload.get("requires_grad_meta") or {} + n_rg_t = rg_meta.get("num_named_requires_grad_true") + n_rg_f = rg_meta.get("num_named_requires_grad_false") + in_dump_rg_f = sum( + 1 + for v in grad_payload.get("params", {}).values() + if isinstance(v, dict) and not v.get("requires_grad", True) + ) + logger.info( + "Wrote %s: num_params=%s global_l2=%s full_grad_tensors=%s " + "named_requires_grad_true/false=%s/%s " + "(in this dump, params with requires_grad=False: %s)", + last_grads_path, + gmeta.get("num_params"), + gmeta.get("l2"), + n_full, + n_rg_t, + n_rg_f, + in_dump_rg_f, + ) + if n_rg_f: + sample = (rg_meta.get("named_requires_grad_false") or [])[:8] + logger.warning( + "Some named_parameters have requires_grad=False (sample raw names): %s", + sample, + ) + if dist.is_initialized(): + dist.barrier(group=engine.cpu_group) + + if cfg.test_config.save_diff: + _save_diff_snapshot( + engine, + initial_params, + dump_dir, + "diff.pt", + save_full_diff_tensors_fp32=( + cfg.test_config.save_full_diff_tensors_fp32 + ), + ) + if initial_params is not None: + initial_params.clear() + finally: + _destroy_engine(engine) + + +if __name__ == "__main__": + main() diff --git a/tests/experimental/archon/torchrun/run_dynamic_nccl_p2p.py b/tests/experimental/archon/torchrun/run_dynamic_nccl_p2p.py new file mode 100644 index 0000000000..3766410c55 --- /dev/null +++ b/tests/experimental/archon/torchrun/run_dynamic_nccl_p2p.py @@ 
# Wire protocol: every message begins with a HEADER_SIZE-long int64 header
# [op, req_id, shard_id, numel]; PULL_PARAM/PUSH_GRAD are followed by a
# float32 payload of `numel` elements.
HEADER_SIZE = 4
OWNER_RANK = 0
PULL_PARAM = 0
PUSH_GRAD = 1
DONE = 2


def init_distributed() -> None:
    """Initialize the NCCL process group from torchrun's environment variables."""
    if dist.is_initialized():
        return
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    current_platform.set_device(int(os.environ["LOCAL_RANK"]))


def request_plan(rank: int) -> list[tuple[int, float]]:
    """Per-worker (shard_id, delay_s) schedule for ranks 1 and 2.

    The mismatched delays force the two workers to reach the owner in a
    dynamic (but reproducible) interleaving.
    """
    schedule: dict[int, list[tuple[int, float]]] = {
        1: [(0, 0.15), (1, 0.03)],
        2: [(1, 0.02), (0, 0.10)],
    }
    return schedule[rank]


def build_owner_shards(device: torch.device) -> dict[int, torch.Tensor]:
    """Owner-side parameter shards, keyed by shard id."""
    shard0 = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    shard1 = torch.tensor([10.0, 20.0, 30.0, 40.0], device=device)
    return {0: shard0, 1: shard1}
def run_owner(world_size: int, device: torch.device) -> None:
    """Owner-side mailbox loop serving parameter pulls and grad pushes.

    Keeps exactly one outstanding header ``irecv`` per worker and polls them
    in a round-robin, so requests are serviced in whatever order workers
    actually arrive (the point of the test). Exits once every worker has sent
    DONE, then verifies the accumulated gradients.
    """
    shards = build_owner_shards(device)
    grad_accum = {shard_id: torch.zeros_like(t) for shard_id, t in shards.items()}
    event_log: list[str] = []

    # One pre-posted header receive per worker rank (the "mailbox").
    header_bufs = {
        src: torch.empty(HEADER_SIZE, device=device, dtype=torch.int64)
        for src in range(1, world_size)
    }
    header_works = {
        src: dist.irecv(header_bufs[src], src=src) for src in range(1, world_size)
    }
    done_workers: set[int] = set()

    while len(done_workers) < world_size - 1:
        progressed = False
        for src in range(1, world_size):
            if src in done_workers:
                continue
            work = header_works[src]
            if not work.is_completed():
                continue

            work.wait()
            # Clone before re-posting an irecv into the same buffer below.
            header = header_bufs[src].clone()
            op, req_id, shard_id, numel = [int(x) for x in header.tolist()]
            progressed = True

            if op == PULL_PARAM:
                event_log.append(f"pull:worker={src},req={req_id},shard={shard_id}")
                payload = shards[shard_id]
                assert payload.numel() == numel
                dist.isend(payload, dst=src).wait()
                # Re-arm this worker's mailbox for its next header.
                header_works[src] = dist.irecv(header_bufs[src], src=src)
            elif op == PUSH_GRAD:
                event_log.append(f"grad:worker={src},req={req_id},shard={shard_id}")
                grad = torch.empty(numel, device=device, dtype=torch.float32)
                dist.irecv(grad, src=src).wait()
                grad_accum[shard_id].add_(grad.view_as(grad_accum[shard_id]))
                header_works[src] = dist.irecv(header_bufs[src], src=src)
            elif op == DONE:
                # No further receive is posted for a finished worker.
                done_workers.add(src)
            else:
                raise ValueError(f"Unexpected op={op} from worker {src}")

        if not progressed:
            # Nothing completed this sweep; back off briefly before re-polling.
            time.sleep(0.001)

    # Each worker r pushes grad = param * r once per shard (see run_worker and
    # request_plan), so with workers 1 and 2 the accumulated grad is
    # shard * (1 + 2) per shard.
    expected = {
        0: shards[0] * (1.0 + 2.0),
        1: shards[1] * (1.0 + 2.0),
    }
    torch.testing.assert_close(grad_accum[0], expected[0], atol=0.0, rtol=0.0)
    torch.testing.assert_close(grad_accum[1], expected[1], atol=0.0, rtol=0.0)

    print("Observed dynamic event order:", flush=True)
    for event in event_log:
        print(f" {event}", flush=True)
    print("Dynamic NCCL P2P mailbox test passed", flush=True)
+ step: 4 + data_dir: /path/to/data_dir + disable_optimizer: false + save_diff: true +``` + +OmegaConf-style dotlist overrides are supported on the CLI, eg:: + + torchrun --nproc_per_node=2 run_archon_training_test.py \ + --config config.yaml \ + test_config.step=4 test_config.disable_optimizer=true +""" + +from __future__ import annotations + +import argparse +import dataclasses +import getpass +import os +import re +import types +import typing +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Union + +from omegaconf import DictConfig, OmegaConf + +from areal.api.alloc_mode import ParallelStrategy +from areal.api.alloc_mode import _AllocationMode as AllocationMode +from areal.api.cli_args import TrainEngineConfig +from areal.utils.logging import getLogger + +_LOGGER = getLogger("TrainingTestConfig") + + +def _log_config_info(msg: str, *args: object) -> None: + """Log once per node before process group init (torchrun sets LOCAL_RANK).""" + if int(os.environ.get("LOCAL_RANK", "0")) != 0: + return + _LOGGER.info(msg, *args) + + +@dataclass +class TestOnlyConfig: + """Test-only settings not inherited from AReaL configs.""" + + step: int = -1 + data_dir: str = "" + disable_optimizer: bool = False + save_diff: bool = True + save_params: bool = False + save_initial_params: bool = False + seed: int = 42 + entropy_coef: float = 0.0 + entropy_mode: str = "sum" + save_full_diff_tensors_fp32: bool = False + # After the final training step, write ``last_grads.pt`` (per-param grad stats). + # With ``disable_optimizer=true``, grads are cleared inside the patched step, so + # the runner captures them in that hook on the last step only. + dump_last_grads: bool = False + # When ``dump_last_grads``, also store full CPU fp32 tensors under + # ``grad_tensors_fp32`` (like ``delta_tensors_fp32`` in diff.pt). Set false to + # save disk/memory (stats-only). 
+ save_full_last_grad_tensors_fp32: bool = True + is_critic: bool = False + # After load, take at most this many sequences per .pt (0 = no cap). + max_sequences_per_pt: int = 0 + # After ``ArchonLMEngine.initialize`` (HF load + buffers), export weights with + # ``ArchonEngine.save`` / ``save_model_to_hf`` into this directory. ``None`` = skip. + save_hf_checkpoint_dir: str | None = None + # Optional forward parity check: compare current tree_training_mode output with + # a baseline engine forced to ``tree_training_mode=disabled`` at one step. + dump_forward_compare: bool = False + + def __post_init__(self) -> None: + if self.step is None or int(self.step) < 0: + raise ValueError( + f"test_config.step must be a non-negative integer, got {self.step}." + ) + if not self.data_dir: + raise ValueError( + "test_config.data_dir is required and must be a non-empty path." + ) + if float(self.entropy_coef) < 0: + raise ValueError( + f"test_config.entropy_coef must be >= 0, got {self.entropy_coef}." + ) + valid_entropy_modes = {"mean", "sum"} + if self.entropy_mode not in valid_entropy_modes: + raise ValueError( + f"test_config.entropy_mode must be one of " + f"{sorted(valid_entropy_modes)}, got '{self.entropy_mode}'." + ) + if int(self.max_sequences_per_pt) < 0: + raise ValueError( + "test_config.max_sequences_per_pt must be >= 0 " + f"(0 = no cap), got {self.max_sequences_per_pt}." 
+ ) + + +@dataclass +class TestParallelConfig: + """Subset of ParallelStrategy fields exposed to the test YAML.""" + + data_parallel_size: int = 1 + tensor_parallel_size: int = 1 + pipeline_parallel_size: int = 1 + context_parallel_size: int = 1 + expert_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + + def to_parallel_strategy(self) -> ParallelStrategy: + return ParallelStrategy( + tensor_parallel_size=self.tensor_parallel_size, + pipeline_parallel_size=self.pipeline_parallel_size, + data_parallel_size=self.data_parallel_size, + context_parallel_size=self.context_parallel_size, + expert_parallel_size=self.expert_parallel_size, + expert_tensor_parallel_size=self.expert_tensor_parallel_size, + ) + + def to_compact_tag(self) -> str: + """Compact path-friendly tag, e.g. ``d8t2c2``.""" + parts = [f"d{int(self.data_parallel_size)}"] + if int(self.pipeline_parallel_size) > 1: + parts.append(f"p{int(self.pipeline_parallel_size)}") + if int(self.tensor_parallel_size) > 1: + parts.append(f"t{int(self.tensor_parallel_size)}") + if int(self.context_parallel_size) > 1: + parts.append(f"c{int(self.context_parallel_size)}") + if int(self.expert_parallel_size) > 1: + parts.append(f"e{int(self.expert_parallel_size)}") + if int(self.expert_tensor_parallel_size) > 1: + parts.append(f"et{int(self.expert_tensor_parallel_size)}") + return "".join(parts) + + +@dataclass +class ArchonTrainingTestConfig: + """Top-level container combining AReaL engine config + test knobs.""" + + engine: TrainEngineConfig = field(default_factory=TrainEngineConfig) + parallel: TestParallelConfig = field(default_factory=TestParallelConfig) + test_config: TestOnlyConfig = field(default_factory=TestOnlyConfig) + fileroot: str = "" + + @staticmethod + def _safe_token(value: str, *, fallback: str) -> str: + s = (value or "").strip() + if not s: + s = fallback + s = s.replace(os.sep, "_") + if os.altsep: + s = s.replace(os.altsep, "_") + s = re.sub(r"[^A-Za-z0-9._-]+", "_", s) + s = s.strip("._-") 
+ return s or fallback + + @staticmethod + def _expand_path(path: str) -> str: + return os.path.expanduser(os.path.expandvars(path)) + + def resolve_dump_dir(self) -> str: + """Pick a compact dump_dir under regular training log roots.""" + exp = self._safe_token( + str(self.engine.experiment_name or "archon_train_test"), + fallback="archon_train_test", + ) + trial = self._safe_token( + str(self.engine.trial_name or "trial0"), + fallback="trial0", + ) + tree_mode = self._safe_token( + str(getattr(self.engine, "tree_training_mode", "unknown") or "unknown"), + fallback="unknown", + ) + parallel_tag = self._safe_token(self.parallel.to_compact_tag(), fallback="d1") + model_name = self._safe_token( + Path(str(self.engine.path or "")).name, fallback="model" + ) + leaf = f"{tree_mode}_{parallel_tag}_{model_name}" + + if self.fileroot: + # Align with AReaL StatsLogger layout: + # /logs/// + base = ( + Path(self._expand_path(self.fileroot)) + / "logs" + / getpass.getuser() + / exp + / trial + ) + else: + base = Path.cwd() / exp / trial + return str(base / leaf) + + +def _merge_yaml_and_overrides( + yaml_path: str, + overrides: list[str], +) -> DictConfig: + yaml_cfg = OmegaConf.load(yaml_path) + if not isinstance(yaml_cfg, DictConfig): + raise ValueError( + f"Top-level YAML at {yaml_path} must be a mapping, got {type(yaml_cfg)}." + ) + override_cfg = OmegaConf.from_dotlist(list(overrides)) + return OmegaConf.merge(yaml_cfg, override_cfg) + + +def _as_dict(section: Any) -> dict[str, Any]: + """Resolve an OmegaConf node into a plain ``dict``.""" + if section is None: + return {} + if isinstance(section, DictConfig): + return OmegaConf.to_container(section, resolve=True) # type: ignore[return-value] + if isinstance(section, dict): + return dict(section) + raise TypeError(f"Expected mapping-like config section, got {type(section)}") + + +def _coerce_value(tp: Any, value: Any) -> Any: + """Best-effort coercion of ``value`` to the dataclass field type ``tp``. 
+ + Handles ``Optional[X]``, nested dataclasses, and ``list[DataClass]`` / + ``tuple[DataClass, ...]``. Other annotations (``Literal``, ``int``, ``str``, + ``dict``, ...) pass through unchanged so OmegaConf primitives continue to + work. + """ + if value is None: + return None + + origin = typing.get_origin(tp) + args = typing.get_args(tp) + + if origin is Union or origin is types.UnionType: + # Try the non-None variants in order; first one that accepts the value + # wins. Primitives pass through since ``_coerce_value`` is a no-op for + # non-dataclass leaf types. + non_none = [a for a in args if a is not type(None)] + for alt in non_none: + if dataclasses.is_dataclass(alt) and isinstance(value, dict): + return _build_dataclass(alt, value) + return value + + if dataclasses.is_dataclass(tp): + if isinstance(value, dict): + return _build_dataclass(tp, value) + return value + + if origin in (list, tuple) and args: + inner = args[0] + if dataclasses.is_dataclass(inner) and isinstance(value, (list, tuple)): + coerced = [_coerce_value(inner, v) for v in value] + return tuple(coerced) if origin is tuple else coerced + + return value + + +def _build_dataclass(cls: type, data: dict[str, Any]) -> Any: + """Instantiate ``cls`` from ``data``, recursively coercing nested fields. + + Fields not present in ``data`` fall back to their dataclass defaults so + partial YAML sections are allowed. Unknown keys raise ``TypeError`` to + surface typos early. + """ + assert dataclasses.is_dataclass(cls), f"{cls} is not a dataclass" + hints = typing.get_type_hints(cls) + init_kwargs: dict[str, Any] = {} + known_names = {f.name for f in dataclasses.fields(cls) if f.init} + for key, value in data.items(): + if key not in known_names: + raise TypeError( + f"Unknown field '{key}' for {cls.__name__}; " + f"valid fields: {sorted(known_names)[:30]}..." 
+ ) + for f in dataclasses.fields(cls): + if not f.init: + continue + if f.name not in data: + continue + tp = hints.get(f.name, f.type) + init_kwargs[f.name] = _coerce_value(tp, data[f.name]) + return cls(**init_kwargs) + + +def _build_engine_config_from_actor(actor_section: Any) -> TrainEngineConfig: + """Project actor config onto ``TrainEngineConfig`` fields only.""" + actor_data = _as_dict(actor_section) + train_fields = {f.name for f in dataclasses.fields(TrainEngineConfig) if f.init} + engine_data = {k: v for k, v in actor_data.items() if k in train_fields} + return _build_dataclass(TrainEngineConfig, engine_data) + + +def _build_engine_config( + actor_section: Any, merged_cfg: DictConfig +) -> TrainEngineConfig: + """Build ``TrainEngineConfig`` from top-level ``actor`` section only.""" + top_exp = merged_cfg.get("experiment_name") if merged_cfg is not None else None + top_trial = merged_cfg.get("trial_name") if merged_cfg is not None else None + default_exp = str(top_exp or "archon_train_test") + default_trial = str(top_trial or "trial0") + + if actor_section is None: + raise ValueError("Missing required top-level 'actor' section in config YAML.") + actor_data = _as_dict(actor_section) + actor_data.setdefault("experiment_name", default_exp) + actor_data.setdefault("trial_name", default_trial) + cfg = _build_engine_config_from_actor(actor_data) + + _log_config_info( + "Resolved TrainEngineConfig from top-level 'actor' section.", + ) + return cfg + + +def _parallel_strategy_to_test_config(strategy: ParallelStrategy) -> TestParallelConfig: + """Convert ``ParallelStrategy`` to ``TestParallelConfig``.""" + return TestParallelConfig( + data_parallel_size=int(strategy.data_parallel_size), + tensor_parallel_size=int(strategy.tensor_parallel_size), + pipeline_parallel_size=int(strategy.pipeline_parallel_size), + context_parallel_size=int(strategy.context_parallel_size), + expert_parallel_size=int(strategy.expert_parallel_size), + 
expert_tensor_parallel_size=int(strategy.expert_tensor_parallel_size), + ) + + +def _build_parallel_config(section: Any) -> TestParallelConfig: + """Build ``TestParallelConfig`` from mapping or allocation-mode string. + + Supported forms: + - Mapping (legacy): + parallel: + data_parallel_size: 2 + tensor_parallel_size: 1 + ... + - String (reuses AllocationMode parser): + parallel: archon:d8 + parallel: sglang:d16+archon:d8 + """ + if section is None: + return TestParallelConfig() + + if isinstance(section, str): + mode = AllocationMode.from_str(section) + train_allocs = [ + a for a in mode.allocations if a.backend in ("fsdp", "megatron", "archon") + ] + if len(train_allocs) != 1: + raise ValueError( + "parallel string must resolve to exactly one training allocation " + f"(got {len(train_allocs)}): {section}" + ) + alloc = train_allocs[0] + if alloc.backend != "archon": + raise ValueError( + "Only archon backend is supported by this test runner. " + f"Got training backend '{alloc.backend}' from parallel='{section}'." 
+ ) + if alloc.parallel is None: + raise ValueError( + f"Resolved archon allocation has no parallel strategy: {section}" + ) + return _parallel_strategy_to_test_config(alloc.parallel) + + return TestParallelConfig(**_as_dict(section)) + + +def _resolve_output_fileroot(merged: DictConfig) -> str: + """Resolve output fileroot with priority: stats_logger > cluster.""" + stats_logger = _as_dict(merged.get("stats_logger") if merged else None) + stats_logger_fileroot = stats_logger.get("fileroot") + if stats_logger_fileroot: + return str(stats_logger_fileroot) + + cluster = _as_dict(merged.get("cluster") if merged else None) + cluster_fileroot = cluster.get("fileroot") + if cluster_fileroot: + return str(cluster_fileroot) + return "" + + +def load_training_test_config( + argv: list[str] | None = None, +) -> tuple[ArchonTrainingTestConfig, str]: + """Parse CLI and return a resolved ``ArchonTrainingTestConfig``.""" + parser = argparse.ArgumentParser( + description="Run ArchonEngine training-side test under torchrun." + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the YAML config file.", + ) + args, overrides = parser.parse_known_args(argv) + config_path = Path(args.config).expanduser().resolve() + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + merged = _merge_yaml_and_overrides(str(config_path), overrides) + + if merged and merged.get("engine") is not None: + raise ValueError( + "This runner only supports regular AReaL YAML + test_config. " + "Do not provide top-level 'engine'; use top-level 'actor' instead." + ) + if merged and merged.get("parallel") is not None: + raise ValueError( + "This runner derives parallel strategy from actor.backend. " + "Do not provide top-level 'parallel'." 
+ ) + + engine_cfg = _build_engine_config(merged.get("actor") if merged else None, merged) + + actor_data = _as_dict(merged.get("actor") if merged else None) + actor_backend = actor_data.get("backend") + if not actor_backend: + raise ValueError("actor.backend is required and must be a non-empty string.") + parallel_section = actor_backend + _log_config_info("Resolved 'parallel' from actor.backend=%s", actor_backend) + parallel_cfg = _build_parallel_config(parallel_section) + + test_cfg = TestOnlyConfig(**_as_dict(merged.get("test_config") if merged else None)) + fileroot = _resolve_output_fileroot(merged) + + cfg = ArchonTrainingTestConfig( + engine=engine_cfg, + parallel=parallel_cfg, + test_config=test_cfg, + fileroot=fileroot, + ) + return cfg, str(config_path) + + +def ensure_dump_dir(cfg: ArchonTrainingTestConfig, rank: int) -> str: + """Create (on rank 0) and return the resolved dump_dir.""" + dump_dir = cfg.resolve_dump_dir() + if rank == 0: + os.makedirs(dump_dir, exist_ok=True) + return dump_dir diff --git a/tests/experimental/archon/utils.py b/tests/experimental/archon/utils.py index 9c31f49383..f47d2d3a67 100644 --- a/tests/experimental/archon/utils.py +++ b/tests/experimental/archon/utils.py @@ -2,13 +2,16 @@ import os import subprocess +from collections.abc import Callable from dataclasses import dataclass +from types import SimpleNamespace from typing import Any import pytest import torch import torch.distributed as dist from datasets import load_dataset +from torch.distributed.tensor import DTensor from transformers import AutoModelForCausalLM from areal.api import FinetuneSpec, ParallelStrategy @@ -54,6 +57,15 @@ "create_grpo_batch", "DualEngineFixture", "dual_engines", + "create_archon_engine", + "create_fsdp_engine", + "destroy_test_engine", + "create_dta_batch", + "load_pt_batch", + "dta_dummy_loss_fn", + "dta_loss_weight_fn", + "snapshot_module_parameters", + "strip_wrapper_prefixes", ] @@ -72,6 +84,11 @@ def get_model_path_for_type(model_type: 
str) -> str | None: DATASET_PATH = get_dataset_path("/storage/openpsi/data/gsm8k", "openai/gsm8k") +def strip_wrapper_prefixes(name: str) -> str: + """Drop wrapper-generated path segments from parameter names.""" + return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "") + + @dataclass class ComparisonMetrics: """Metrics for comparing two tensors.""" @@ -468,3 +485,262 @@ def dual_engines(): fixture.setup() yield fixture fixture.teardown() + + +# ============================================================================= +# DTA Engine Testing Utilities +# ============================================================================= + + +def create_dta_batch( + batch_size: int = 4, + seq_len: int = 64, + shared_prefix_len: int = 20, + vocab_size: int = 151936, + device: torch.device | None = None, +) -> dict[str, Any]: + """Build a synthetic batch whose sequences share a common prefix. + + Returns a dict compatible with ``ArchonEngine.train_batch`` (GRPO-style + fields included so the default loss path works). + + Args: + batch_size: Number of sequences. + seq_len: Length of each sequence. + shared_prefix_len: Length of the common prefix across all sequences. + vocab_size: Vocabulary size for random token generation. + device: Target device (defaults to current platform device). 
+ """ + if device is None: + device = torch.device(current_platform.device_type) + + torch.manual_seed(42) + + prefix = torch.randint(100, vocab_size - 100, (shared_prefix_len,)) + rows = [] + for _ in range(batch_size): + suffix = torch.randint(100, vocab_size - 100, (seq_len - shared_prefix_len,)) + rows.append(torch.cat([prefix, suffix])) + input_ids = torch.stack(rows).to(device) + + attention_mask = torch.ones_like(input_ids) + loss_mask = torch.ones(batch_size, seq_len, device=device) + loss_mask[:, :10] = 0.0 + + logprobs = torch.randn(batch_size, seq_len, device=device) * 0.5 - 2.0 + old_logprobs = logprobs.clone() + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "logprobs": logprobs, + "old_logprobs": old_logprobs, + "advantages": torch.randn(batch_size, seq_len, device=device), + "rewards": torch.randint(0, 2, (batch_size,), device=device).float(), + "values": torch.zeros(batch_size, seq_len, device=device), + "prox_logp": old_logprobs.clone(), + } + + +def load_pt_batch( + test_config: Any, + prompt_ratio: float = 0.3, + device: torch.device | None = None, +) -> dict[str, Any]: + """Load all token sequences from a ``.pt`` file at full length. + + Each ``.pt`` file contains ``list[Tensor]`` where every tensor is a 1-D + ``int64`` sequence with no padding. All sequences are kept at their + original length and right-padded to the longest one. + + GRPO fields (``loss_mask``, ``logprobs``, ``advantages``, …) are filled + with synthetic values so the batch works with ``train_batch``. + + Args: + test_config: Test config carrying ``dta_data``, ``max_tokens_per_mb``, and optional ``dta_limit``. + prompt_ratio: Fraction of each sequence treated as prompt (loss_mask=0). + device: Target device (defaults to current platform device). 
+ """ + if device is None: + device = torch.device(current_platform.device_type) + # print(f"loadbatch on device: {device}") + + pt_path = str(test_config.dta_data) + assert pt_path is not None, "dta_data is required but got None" + seqs: list[torch.Tensor] = torch.load( + pt_path, map_location="cpu", weights_only=True + ) + assert isinstance(seqs, list) and len(seqs) > 0, ( + f"Expected list[Tensor], got {type(seqs)}" + ) + dta_limit = int(getattr(test_config, "dta_limit", -1)) + if dta_limit >= 0: + seqs = seqs[:dta_limit] + assert len(seqs) > 0, "No sequences available after applying dta_limit." + + bs = len(seqs) + max_tokens_per_mb = int(test_config.max_tokens_per_mb) + lengths = [min(s.numel(), max_tokens_per_mb) for s in seqs] + padded_len = max(lengths) + + input_ids = torch.zeros(bs, padded_len, dtype=torch.long) + attention_mask = torch.zeros(bs, padded_len, dtype=torch.long) + loss_mask = torch.zeros(bs, padded_len) + + for i, (s, length) in enumerate(zip(seqs, lengths)): + input_ids[i, :length] = s[:length] + attention_mask[i, :length] = 1 + prompt_len = max(1, int(length * prompt_ratio)) + loss_mask[i, prompt_len:length] = 1.0 + + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + loss_mask = loss_mask.to(device) + + logprobs = torch.randn(bs, padded_len, device=device) * 0.5 - 2.0 + old_logprobs = logprobs.clone() + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "logprobs": logprobs, + "old_logprobs": old_logprobs, + "advantages": torch.randn(bs, padded_len, device=device), + "rewards": torch.randint(0, 2, (bs,), device=device).float(), + "values": torch.zeros(bs, padded_len, device=device), + "prox_logp": old_logprobs.clone(), + } + + +def dta_dummy_loss_fn(logprobs, entropy, input_data, **kwargs): + """Minimal loss for DTA smoke tests.""" + loss_mask = input_data.get("loss_mask") + if loss_mask is None: + return -logprobs.sum() + min_len = min(logprobs.shape[-1], 
loss_mask.shape[-1]) + logprobs = logprobs[..., :min_len] + loss_mask = loss_mask[..., :min_len] + return -(logprobs * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + +def dta_loss_weight_fn(input_data): + """Loss weight function for DTA smoke tests.""" + lm = input_data.get("loss_mask") + if lm is not None: + return lm.sum() + return torch.tensor(1.0) + + +def snapshot_module_parameters( + module: torch.nn.Module, + to_cpu: bool = False, + param_filter: Callable[[str, torch.nn.Parameter], bool] | None = None, +) -> dict[str, torch.Tensor]: + """Snapshot (clone) selected named parameters for later delta comparisons. + + This is intentionally lightweight to reuse the same comparison pattern + across tests (similar to how `test_grpo.py` compares weight deltas). + """ + snapshots: dict[str, torch.Tensor] = {} + for name, param in module.named_parameters(): + if param_filter is not None and not param_filter(name, param): + continue + t = param.full_tensor() if isinstance(param, DTensor) else param + t = t.detach().clone() + if to_cpu: + t = t.cpu() + snapshots[name] = t + return snapshots + + +def create_archon_engine( + test_config: SimpleNamespace, + model_path: str | None = None, +) -> ArchonLMEngine: + """Create and initialize a single Archon engine for tests.""" + setup_distributed_environment() + model_path = model_path or MODEL_PATHS["qwen2"] + world_size = dist.get_world_size() if dist.is_initialized() else 1 + parallel_strategy = ParallelStrategy(data_parallel_size=world_size) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=4, train_batch_size=4) + max_tokens_per_mb = int(test_config.max_tokens_per_mb) + + config = create_engine_config( + model_path, + "archon_dta" if test_config.tree_training_mode == "dta" else "archon", + ) + config.mb_spec = MicroBatchSpec.new( + config.mb_spec, max_tokens_per_mb=max_tokens_per_mb + ) + config.tree_training_mode = test_config.tree_training_mode + if os.environ.get("AREAL_DISABLE_TORCH_COMPILE", "").lower() 
in (
+        "1",
+        "true",
+        "yes",
+    ):
+        config.archon.enable_compile = False
+    config.path = test_config.model_path
+
+    engine = ArchonLMEngine(config)
+    engine.create_process_group(parallel_strategy=parallel_strategy)
+    engine.initialize(addr=None, ft_spec=ft_spec)
+
+    if test_config.use_hf:
+        # Clean up original engine.model to avoid memory leaks (residual GPU memory)
+        if hasattr(engine, "model") and engine.model is not None:
+            try:
+                # Call .cpu() + del + torch.cuda.empty_cache for safety
+                engine.model.cpu()
+            except Exception:
+                pass
+            del engine.model
+            import gc
+
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        # Use the traditional HuggingFace transformer model for DTA smoke tests
+        from transformers import AutoModelForCausalLM
+
+        engine.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16,
+            device_map=torch.device(current_platform.device_type),
+        )
+
+    return engine
+
+
+def create_fsdp_engine(
+    test_config: SimpleNamespace,
+    model_path: str | None = None,
+) -> FSDPLMEngine:
+    """Create and initialize a single FSDP engine for tests."""
+    setup_distributed_environment()
+    model_path = model_path or MODEL_PATHS["qwen2"]
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    parallel_strategy = ParallelStrategy(data_parallel_size=world_size)
+    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=4, train_batch_size=4)
+    max_tokens_per_mb = int(test_config.max_tokens_per_mb)
+
+    config = create_engine_config(model_path, "fsdp")
+    config.mb_spec = MicroBatchSpec.new(
+        config.mb_spec, max_tokens_per_mb=max_tokens_per_mb
+    )
+    config.path = test_config.model_path
+
+    engine = FSDPLMEngine(config)
+    engine.create_process_group(parallel_strategy=parallel_strategy)
+    engine.initialize(addr=None, ft_spec=ft_spec)
+    return engine
+
+
+def destroy_test_engine(engine: FSDPLMEngine | ArchonLMEngine | None) -> None:
+    """Destroy a test engine and tear down the process group."""
+    if engine is not 
None: + engine.destroy() + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/experimental/dta/test_dp.py b/tests/experimental/dta/test_dp.py new file mode 100644 index 0000000000..e2404b5e14 --- /dev/null +++ b/tests/experimental/dta/test_dp.py @@ -0,0 +1,144 @@ +from types import SimpleNamespace + +import torch + +from areal.experimental.dta.dp import ( + LB_by_DFS_and_TM, + LB_by_n_tokens, + LB_by_TM, + pred_time, + try_divide, +) +from areal.experimental.dta.token_trie import TokenTrie +from areal.experimental.dta.trie import CompressedTrie + + +class ConstantTimeModel: + def pred(self, stats: dict) -> float: + return 1.0 + + +class TreeTokenTimeModel: + def pred(self, stats: dict) -> float: + return float(stats["n_tree_tokens"]) + + +def _make_seqs() -> list[torch.Tensor]: + return [ + torch.tensor([1, 2, 3, 4], dtype=torch.long), + torch.tensor([1, 2, 9], dtype=torch.long), + torch.tensor([7, 8], dtype=torch.long), + torch.tensor([7, 8, 9, 10, 11], dtype=torch.long), + ] + + +def _assert_partition_valid(bins: list[list[int]], n_items: int, k: int) -> None: + assert len(bins) == k + flat = [idx for bucket in bins for idx in bucket] + assert sorted(flat) == list(range(n_items)) + + +def test_lb_by_n_tokens_assigns_all_sequences_once(): + """LB_by_n_tokens should output a valid partition of original indices.""" + token_seqs = _make_seqs() + bins = LB_by_n_tokens(token_seqs, K=2) + + _assert_partition_valid(bins, n_items=len(token_seqs), k=2) + + +def test_lb_by_tm_assigns_all_sequences_once(): + """LB_by_TM should map leaf buckets back to original sequence ids.""" + token_seqs = _make_seqs() + config = SimpleNamespace(K=2, mode="backward", block_size=2) + bins = LB_by_TM(token_seqs, ConstantTimeModel(), config) + + _assert_partition_valid(bins, n_items=len(token_seqs), k=2) + + +def test_try_divide_more_strict_limit_requires_more_partitions(): + """Lower cost_limit should produce at least as many divisions.""" + token_trie = 
TokenTrie(_make_seqs()) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + config = SimpleNamespace(K=2, mode="backward", block_size=2) + n_seqs = len(token_trie.inputs) + divL = [0] * (config.K + 1) + divR = [n_seqs] * (config.K + 1) + model = TreeTokenTimeModel() + + strict_divs = try_divide( + compressed_trie, n_seqs, config, divL, divR, model, cost_limit=1.0 + ) + loose_divs = try_divide( + compressed_trie, n_seqs, config, divL, divR, model, cost_limit=1000.0 + ) + + assert len(strict_divs) >= len(loose_divs) + + +def test_lb_by_dfs_and_tm_assigns_all_sequences_once(): + """LB_by_DFS_and_TM should output a valid partition of original sequence ids.""" + token_seqs = _make_seqs() + config = SimpleNamespace(K=2, mode="backward", block_size=2) + bins = LB_by_DFS_and_TM(token_seqs, TreeTokenTimeModel(), config) + + _assert_partition_valid(bins, n_items=len(token_seqs), k=2) + + +def test_lb_by_dfs_and_tm_k1_returns_single_non_empty_bin(): + """K=1 should place all sequences in the only bucket.""" + token_seqs = _make_seqs() + config = SimpleNamespace(K=1, mode="backward", block_size=2) + bins = LB_by_DFS_and_TM(token_seqs, TreeTokenTimeModel(), config) + + assert len(bins) == 1 + assert sorted(bins[0]) == list(range(len(token_seqs))) + + +def test_lb_by_dfs_and_tm_empty_returns_k_empty_bins(): + """Empty input should return K empty bins without entering search.""" + config = SimpleNamespace(K=3, mode="backward", block_size=2) + bins = LB_by_DFS_and_TM([], TreeTokenTimeModel(), config) + + assert bins == [[], [], []] + + +def test_pred_time_rejects_unsupported_mode(): + """pred_time should fail fast on unknown scheduling mode.""" + token_trie = TokenTrie(_make_seqs()) + compressed_trie = CompressedTrie(token_trie.lens, token_trie.lcp_lens) + + try: + pred_time(compressed_trie, ConstantTimeModel(), mode="invalid", block_size=None) + except ValueError as exc: + assert "Unsupported mode" in str(exc) + else: + raise AssertionError("pred_time 
should raise ValueError for invalid mode") + + +def test_lb_by_n_tokens_empty_returns_k_empty_bins(): + """LB_by_n_tokens on an empty input should yield K empty bins.""" + bins = LB_by_n_tokens([], K=3) + assert bins == [[], [], []] + + +def test_lb_by_tm_empty_returns_k_empty_bins(): + """LB_by_TM should survive empty inputs now that CompressedTrie does.""" + config = SimpleNamespace(K=3, mode="backward", block_size=2) + bins = LB_by_TM([], ConstantTimeModel(), config) + assert bins == [[], [], []] + + +def test_pred_time_on_empty_trie_returns_finite_value(): + """pred_time should work on an empty compressed trie.""" + trie = TokenTrie([]) + compressed_trie = CompressedTrie(trie.lens, trie.lcp_lens) + + forward_time = pred_time( + compressed_trie, ConstantTimeModel(), mode="forward", block_size=None + ) + backward_time = pred_time( + compressed_trie, ConstantTimeModel(), mode="backward", block_size=2 + ) + + assert isinstance(forward_time, float) + assert isinstance(backward_time, float) diff --git a/tests/experimental/dta/test_dta_engine.py b/tests/experimental/dta/test_dta_engine.py new file mode 100644 index 0000000000..82c691eda6 --- /dev/null +++ b/tests/experimental/dta/test_dta_engine.py @@ -0,0 +1,37 @@ +import torch + +from areal.experimental.dta.dta_engine import DTAEngine + + +class _DummyConfig: + num_hidden_layers = 1 + num_key_value_heads = 1 + hidden_size = 8 + num_attention_heads = 1 + + +class _FailIfCalledModel: + def __call__(self, *args, **kwargs): + raise AssertionError("Model must not be called when new_tokens is empty.") + + +def test_push_forward_only_empty_segment_does_not_call_model(): + engine = DTAEngine( + model_config=_DummyConfig(), + device=torch.device("cpu"), + dtype=torch.float32, + max_seq_len=8, + forward_only=True, + ) + engine.model = _FailIfCalledModel() + engine.returns = [None] + + empty_tokens = torch.empty(0, dtype=torch.long) + engine.push_forward_only( + empty_tokens, + attach_list=[({"_sequence_batch_id": 0}, 0)], + ) 
+ + assert engine.cur_len == 0 + assert isinstance(engine.returns[0], torch.Tensor) + assert engine.returns[0].numel() == 0 diff --git a/tests/experimental/dta/test_token_trie.py b/tests/experimental/dta/test_token_trie.py new file mode 100644 index 0000000000..c8710eb81d --- /dev/null +++ b/tests/experimental/dta/test_token_trie.py @@ -0,0 +1,175 @@ +"""Regression tests for the empty-input path in TokenTrie and CompressedTrie. + +Empty inputs (``inputs == []``) are a legal degenerate state: the trie +contains only the root node and produces empty traversal orders. These +tests pin down that contract so the three permute methods stop raising +``len(lcp_lens) must be ...`` from ``CompressedTrie.__init__``. +""" + +import torch + +from areal.experimental.dta.token_trie import TokenTrie, _leafization +from areal.experimental.dta.trie import CompressedTrie + + +def test_token_trie_empty_inputs_construction_is_legal(): + """TokenTrie([]) should build an empty but valid trie.""" + trie = TokenTrie([]) + + assert trie.inputs == [] + assert trie.attach_lists == [] + assert trie.lens == [] + assert trie.lcp_lens == [] + assert trie.n_sequences == 0 + assert trie.n_tokens == 0 + + +def test_token_trie_empty_inputs_get_stats_returns_zeros(): + """get_stats on an empty trie returns a well-formed, all-zero summary.""" + trie = TokenTrie([]) + + for mode in ("forward", "backward"): + stats = trie.get_stats(mode=mode, block_size=2) + assert stats["n_sequences"] == 0 + assert stats["n_tokens"] == 0 + assert stats["n_leaf_sequences"] == 0 + assert stats["n_tree_tokens"] == 0 + assert stats["sum_prefix_len"] == 0 + assert stats["sum_depth"] == 0 + if mode == "backward": + assert stats["n_f1_tokens"] == 0 + + +def test_token_trie_empty_inputs_permute_is_noop(): + """forward/backward/random permute are no-ops on empty tries, not raises.""" + for permute_name in ("forward_permute", "backward_permute", "random_permute"): + trie = TokenTrie([]) + getattr(trie, permute_name)() + assert 
trie.inputs == [] + assert trie.lens == [] + assert trie.lcp_lens == [] + + +def test_token_trie_explicit_permute_accepts_empty_order(): + """permute([]) must keep an empty trie in the same empty state.""" + trie = TokenTrie([]) + trie.permute([]) + assert trie.inputs == [] + assert trie.attach_lists == [] + assert trie.lens == [] + assert trie.lcp_lens == [] + + +def test_leafization_accepts_empty_lists(): + """_leafization([], []) should return three empty lists without raising.""" + inputs, attach_lists, lcp_lens = _leafization([], []) + assert inputs == [] + assert attach_lists == [] + assert lcp_lens == [] + + +def test_compressed_trie_empty_inputs_construction_has_only_root(): + """CompressedTrie([], []) is legal and contains just the root node.""" + ct = CompressedTrie([], []) + assert len(ct.nodes) == 1 + + root = ct.nodes[0] + assert root.depth == 0 + assert root.seq_id == -1 + assert root.child_ids == [] + + +def test_compressed_trie_empty_inputs_rejects_mismatched_lcp_lens(): + """Keep the invariant |lcp_lens| == max(|lens| - 1, 0) enforced.""" + try: + CompressedTrie([], [0]) + except ValueError as exc: + assert "len(lcp_lens)" in str(exc) + else: + raise AssertionError( + "CompressedTrie([], [0]) should raise on invariant mismatch" + ) + + +def test_compressed_trie_empty_inputs_dfs_orders_are_empty(): + """DFS order outputs on an empty trie must all be empty tuples.""" + ct = CompressedTrie([], []) + + order_fwd, lens_fwd, lcp_fwd = ct.get_order_forward() + assert order_fwd == [] + assert lens_fwd == [] + assert lcp_fwd == [] + + order_bwd, lens_bwd, lcp_bwd = ct.get_order_backward() + assert order_bwd == [] + assert lens_bwd == [] + assert lcp_bwd == [] + + order_rnd = ct.get_order_random() + assert order_rnd == [] + + +def test_token_trie_single_sequence_permute_roundtrip(): + """Sanity check: n == 1 path still works and does not touch ``child_ids[0]``.""" + seq = torch.tensor([1, 2, 3], dtype=torch.long) + trie = TokenTrie([seq]) + + 
trie.forward_permute() + assert len(trie.inputs) == 1 + assert torch.equal(trie.inputs[0], seq) + assert trie.lens == [3] + assert trie.lcp_lens == [] + + +def test_leafization_absorbs_leading_empty_into_next_leaf(): + """Empty sequences must be absorbed into the next non-empty leaf with length=0. + + Contract for downstream ``DTAEngine``: + - Leaves themselves are never empty (unless *all* inputs are empty). + - Empty source sequences survive as ``(attach, 0)`` entries inside the + merged leaf's ``attach_list`` so we can later decide whether they + contribute to loss / logprob returns. + """ + empty = torch.tensor([], dtype=torch.long) + non_empty = torch.tensor([1, 2, 3], dtype=torch.long) + + empty_att = {"_sequence_batch_id": 0} + non_empty_att = {"_sequence_batch_id": 1} + + input_ids_leafed, attach_lists, lcp_lens = _leafization( + [empty, non_empty], [empty_att, non_empty_att] + ) + + assert len(input_ids_leafed) == 1 + assert torch.equal(input_ids_leafed[0], non_empty) + assert lcp_lens == [] + + assert len(attach_lists) == 1 + leaf_attaches = attach_lists[0] + assert len(leaf_attaches) == 2 + + # The leading empty sequence shows up as a length-0 attachment on the leaf. + (att0, len0), (att1, len1) = leaf_attaches + assert att0 is empty_att + assert len0 == 0 + assert att1 is non_empty_att + assert len1 == 3 + + +def test_token_trie_mixed_empty_keeps_length_zero_entry(): + """TokenTrie public API preserves the (_, 0) attach for empty inputs.""" + empty = torch.tensor([], dtype=torch.long) + non_empty = torch.tensor([5, 6], dtype=torch.long) + + trie = TokenTrie([empty, non_empty]) + + assert len(trie.inputs) == 1 + assert trie.lens == [2] + assert trie.n_sequences == 2 + assert trie.n_tokens == 2 + + lengths = [length for _, length in trie.attach_lists[0]] + assert 0 in lengths, ( + "Empty source sequence must be preserved as a length-0 entry in " + "attach_list so DTAEngine can opt it out of loss/logprob computation." 
+ ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 03bfb46c36..a90234c344 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -850,7 +850,7 @@ def test_tau2(tmp_path_factory): "econfig.max_steps=3", # Limit steps for faster testing f"econfig.user_llm_base_url={user_llm_base_url}", "econfig.user_llm=openai/self-hosted-qwen3", - "actor.enable_tree_training=false", # Disable tree training for simpler test + "actor.tree_training_mode=disabled", # Disable tree training for simpler test "scheduler.type=local", "stats_logger.wandb.mode=disabled", timeout=600, diff --git a/tests/test_tree_training.py b/tests/test_tree_training.py index e88c358033..0485586d1a 100644 --- a/tests/test_tree_training.py +++ b/tests/test_tree_training.py @@ -166,7 +166,7 @@ def _check_nan_params(params: dict[str, torch.Tensor], label: str) -> list[str]: def _create_engine( engine_type: str, - enable_tree_training: bool = False, + tree_training_mode: str = "disabled", port: str = "7777", experiment_name: str = "test", max_tokens_per_mb: int = 256, @@ -194,7 +194,7 @@ def _create_engine( path=MODEL_PATH, mb_spec=MicroBatchSpec(**mb_spec_kwargs), optimizer=OptimizerConfig(), - enable_tree_training=enable_tree_training, + tree_training_mode=tree_training_mode, pad_to_maximum=True, ) @@ -245,7 +245,7 @@ def test_tree_training_forward(engine_type, tree_attn_backend): inputs = mock_tree_input() tree_engine = _create_engine( engine_type, - enable_tree_training=True, + tree_training_mode="sparse", port="7778", ) tree_engine.eval() @@ -347,7 +347,7 @@ def loss_weight_fn(input_data): inputs = mock_tree_input() tree_engine = _create_engine( engine_type, - enable_tree_training=True, + tree_training_mode="sparse", port="7778", experiment_name="test_tree", )