feat(archon): add ZERO1 DTA path with configs and tests #1287
Conversation
Code Review
This pull request introduces Dynamic Tree Attention (DTA) as a new tree training mode, complementing the existing sparse mode. The changes include the core DTA engine, trie-based sequence management, and integration into the Archon engine with Zero-1 optimization. Trajectory redistribution logic is updated to support DTA-based partitioning, and attention modules are enhanced to handle variable sequence lengths and right-aligned KV caches. Feedback highlights critical logic flaws in the redistribution process and an incorrect type assertion in the RL trainer. High-severity concerns were raised regarding potential OOM errors from static KV cache preallocation and the use of hardcoded absolute paths in experimental logging. Additionally, design improvements are needed for fragile relative path calculations in module loading, memory-inefficient mask creation in SDPA, and a potential IndexError when accessing empty KV caches.
```python
batch = broadcast_tensor_container(
    batch,
    src_rank=self.train_engine.current_data_parallel_head(),
```
The redistribution logic appears flawed: `redistribute_trajectories` is called locally on each rank, but if `trajectories` is only provided to the DP head, the other ranks return empty batches. The subsequent `broadcast_tensor_container` call then gives every rank the same data shard from the head, which is not true redistribution.
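A minimal sketch of what head-driven redistribution could look like, assuming `torch.distributed` is initialized and that `trajectories` is populated only on the DP head; `partition_for_dta` is a hypothetical stand-in for the partitioning done inside `redistribute_trajectories`:

```python
import torch.distributed as dist

def scatter_trajectories(trajectories, dp_group, head_rank):
    """Send a distinct shard to each rank instead of broadcasting one shard."""
    world = dist.get_world_size(dp_group)
    rank = dist.get_rank(dp_group)
    shards = None
    if rank == head_rank:
        # The head partitions once, producing exactly one shard per DP rank.
        shards = partition_for_dta(trajectories, num_shards=world)  # hypothetical
    out = [None]
    # scatter_object_list delivers shards[i] to rank i; a broadcast would
    # instead copy the *same* shard to every rank, which is the bug above.
    dist.scatter_object_list(out, shards, src=head_rank, group=dp_group)
    return out[0]
```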
```python
with open(
    "/data/jiarui/dta/AReaL/.cursor/debug-b1b34b.log", "a", encoding="utf-8"
) as f:
    f.write(json.dumps(payload, ensure_ascii=True) + "\n")
```
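Since the review summary flags this hardcoded absolute path, here is a hedged sketch of one portable alternative: gate the debug log behind an environment variable (`AREAL_DTA_DEBUG_LOG` is a hypothetical name) so the code is a no-op on machines where it is unset.

```python
import json
import os

def write_debug_payload(payload: dict) -> None:
    log_path = os.environ.get("AREAL_DTA_DEBUG_LOG")  # hypothetical variable
    if not log_path:
        return  # debug logging disabled by default
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=True) + "\n")
```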
```python
    else model_config.hidden_size // model_config.num_attention_heads
)

kv_buffer_shape = (1, n_kv_heads, max_seq_len, head_dim)
```
Preallocating KV cache buffers for the full max_seq_len across all layers and for both forward and backward (grad_kv) can lead to massive memory consumption. For large sequence lengths and deep models, this will likely cause OOM errors even on high-end GPUs. Consider dynamic allocation or a more memory-efficient stack management strategy.
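A minimal sketch of the dynamic-allocation idea: per-layer buffers that grow in chunks as tokens arrive rather than reserving `max_seq_len` up front. `KVBuffer`, its `append` method, and the chunk size are hypothetical names chosen for illustration, not the PR's stack manager.

```python
import torch

class KVBuffer:
    """Grow-on-demand KV storage; a sketch, not the PR's implementation."""

    def __init__(self, n_kv_heads, head_dim, device, dtype, chunk=1024):
        self.chunk = chunk
        self.len = 0
        # Start with a single chunk instead of reserving max_seq_len.
        self.buf = torch.empty(1, n_kv_heads, chunk, head_dim, device=device, dtype=dtype)

    def append(self, kv):
        # kv: (1, n_kv_heads, new_tokens, head_dim)
        needed = self.len + kv.shape[2]
        if needed > self.buf.shape[2]:
            # Double capacity, rounded up to whole chunks, and copy old content.
            new_cap = 2 * ((needed + self.chunk - 1) // self.chunk) * self.chunk
            new_buf = torch.empty(
                1, self.buf.shape[1], new_cap, self.buf.shape[3],
                device=self.buf.device, dtype=self.buf.dtype,
            )
            new_buf[:, :, : self.len] = self.buf[:, :, : self.len]
            self.buf = new_buf
        self.buf[:, :, self.len : needed] = kv
        self.len = needed

    def view(self):
        # Right-sized view over the valid prefix.
        return self.buf[:, :, : self.len]
```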
```python
):
    values = self.critic.compute_values(rollout_batch)
    if config.actor.packing_algorithm == "dta":
        assert isinstance(values, list), (
```
```python
repo_root = Path(__file__).resolve().parents[4]
impl_path = (
    repo_root
    / "tests"
    / "experimental"
    / "archon"
    / "torchrun"
    / "training_test_config.py"
)
```
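On the fragile `parents[4]` computation flagged in the summary, a sketch of a marker-based lookup that survives file moves; the `.git`/`pyproject.toml` markers are assumptions about the repository layout.

```python
from pathlib import Path

def find_repo_root(start: Path) -> Path:
    """Walk upward until a repository marker is found, instead of parents[4]."""
    for parent in [start, *start.parents]:
        if (parent / ".git").exists() or (parent / "pyproject.toml").exists():
            return parent
    raise FileNotFoundError(f"no repository root found above {start}")

repo_root = find_repo_root(Path(__file__).resolve())
```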
```python
mask = torch.full((q_len, k_len), float("-inf"), device=device, dtype=dtype)
mask = mask.masked_fill(same_seq & causal, 0.0)
```
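On the memory-inefficient mask construction flagged in the summary: `torch.nn.functional.scaled_dot_product_attention` also accepts a boolean `attn_mask` where `True` means "attend", so a sketch of the cheaper form (assuming `q`, `k`, `v` are the projected tensors in scope) skips the dense `float("-inf")` tensor and the `masked_fill` pass entirely.

```python
attn_mask = same_seq & causal  # bool mask: True marks positions allowed to attend
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```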
```python
max_seqlen = int(tokens.shape[1]) + int(
    past_key_values.layers[0].keys.shape[2]
)
```
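On the potential IndexError flagged in the summary: before the first decode step `past_key_values.layers` may be empty, so a guarded sketch of the same computation falls back to a past length of zero.

```python
past_len = (
    int(past_key_values.layers[0].keys.shape[2])
    if past_key_values is not None and len(past_key_values.layers) > 0
    else 0
)
max_seqlen = int(tokens.shape[1]) + past_len
```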
Force-pushed from 825a57b to 15807a1.
Introduce Dynamic Tree Attention for Archon ZERO1 data-parallel training with tree_training_mode wiring, DTA engine and trie stack, and rollout redistribution for DTA packing. FSDP and Megatron reject tree_training_mode=dta; enable_tree_training is deprecated in favor of tree_training_mode.

Key changes:
- Add DTA modules (engine, trie, DP load balance) and Archon model attention paths
- Extend TrainEngineConfig and PPO actor validation for DTA vs group norms
- Add GSM8K/Tau2 example YAMLs, figures, and testing documentation
- Add experimental Archon torchrun tests and DTA unit tests; keep training config under tests/
Force-pushed from 8756d00 to 497e5e2.
```diff
 enable_tree_training: bool = field(
     default=False,
-    metadata={"help": "Enable tree training with flex attention module."},
+    metadata={
+        "help": (
+            "[DEPRECATED] Use tree_training_mode instead. "
+            "enable_tree_training=True maps to tree_training_mode='sparse'. "
+            "If both are set, tree_training_mode takes precedence."
+        )
+    },
 )
```
We can remove this option. Too many code changes are introduced for this fallback; it is better to just make the migration clear in the help message of the tree_training_mode option.
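A sketch of what that could look like, assuming the same dataclass `field` convention as the snippet above; the help-text wording is illustrative, and the mode names are taken from this PR ('disabled', 'sparse', 'dta').

```python
tree_training_mode: str = field(
    default="disabled",
    metadata={
        "help": (
            "Tree training mode: 'disabled', 'sparse', or 'dta'. "
            "Replaces the removed enable_tree_training flag; "
            "enable_tree_training=True corresponded to 'sparse'."
        )
    },
)
```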
```python
if self.enable_tree_training:
    warnings.warn(
        "`enable_tree_training` is deprecated and will be removed in a future version. "
        "Use `tree_training_mode='sparse'` instead.",
        FutureWarning,
        stacklevel=2,
    )
    if self.tree_training_mode != "disabled":
        warnings.warn(
            f"`tree_training_mode` is already set to '{self.tree_training_mode}', "
            "`enable_tree_training=True` is ignored.",
            FutureWarning,
            stacklevel=2,
        )
    else:
        self.tree_training_mode = "sparse"
        warnings.warn(
            "`tree_training_mode` is overridden to 'sparse' from deprecated "
            "`enable_tree_training=True`.",
            FutureWarning,
            stacklevel=2,
        )
```
```python
# SPDX-License-Identifier: Apache-2.0

"""Small Archon utility helpers shared by experimental test runners."""


def strip_wrapper_prefixes(name: str) -> str:
    """Drop wrapper-generated path segments from parameter names."""
    return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "")
```
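An illustrative call, assuming names produced by activation checkpointing and `torch.compile` wrappers:

```python
name = "model._checkpoint_wrapped_module.layers.0._orig_mod.self_attn.q_proj.weight"
assert strip_wrapper_prefixes(name) == "model.layers.0.self_attn.q_proj.weight"
```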
If these are patches for testing only, can we move them somewhere under the tests/ folder?
```python
# SPDX-License-Identifier: Apache-2.0

"""Archon experimental testing helpers exposed under `areal.experimental`."""
```
```python
# SPDX-License-Identifier: Apache-2.0

"""Torchrun helpers for Archon experimental tests."""
```
Instead of putting all these configs in examples/tau2/dta, remove these DTA-specific configs and modify examples/tau2/README.md to tell users how to enable DTA in the tau2 example.
Also, put the reward curve in the results section as a reference.
Please refer to this file to write the integration test for DTA on the Archon engine, instead of rewriting the whole end-to-end training test:
https://github.com/inclusionAI/AReaL/blob/main/tests/experimental/archon/test_qwen3_parallelize.py
This file is not needed. Remove it.
```diff
 | `eps_clip_higher` | float \| None | `None` | Upper clip bound: when set, the ratio is clipped to `[1-eps_clip, 1+eps_clip_higher]` |

-When `eps_clip_higher` is `None`, symmetric clipping is used: $\text{clip}(r, 1-\epsilon, 1+\epsilon)$.
+When `eps_clip_higher` is `None`, symmetric clipping is used: $\\text{clip}(r, 1-\\epsilon, 1+\\epsilon)$.
```
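For reference, the two regimes the table row and sentence describe, written out (this is just a restatement of the documented behavior):

```math
r_{\text{clipped}} =
\begin{cases}
  \operatorname{clip}\left(r,\ 1-\epsilon_{\text{clip}},\ 1+\epsilon_{\text{clip}}\right) & \text{if } \texttt{eps\_clip\_higher} \text{ is } \texttt{None} \\
  \operatorname{clip}\left(r,\ 1-\epsilon_{\text{clip}},\ 1+\epsilon_{\text{clip,higher}}\right) & \text{otherwise}
\end{cases}
```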
Revert this change; it does not belong in this PR.
This figure is not used anywhere. You could write a doc for DTA and include the figure there in a follow-up PR. Just remove it for now.
Description
This PR adds a consolidated ZERO1 + Dynamic Tree Attention (DTA) path for Archon: tree-training mode wiring (tree_training_mode, rollout packing_algorithm), DTA runtime modules (trie/stack engine, DP partitioning), integration with Archon attention and ZERO1 DP, runnable GSM8K/Tau2 YAML examples, docs and figures, and experimental/unit tests (Archon torchrun helpers live under tests/ only).

Related Issue
N/A
Type of Change
Checklist
- Pre-commit checks pass (pre-commit run --all-files)
- Docs build (./docs/build_all.sh) (docs/examples updated; full docs build not re-run in this update step)
- Branch is up to date with main
- PR prepared via the /review-pr command / /create-pr

Breaking Change Details (if applicable):
N/A
Additional Context
- /create-pr: fetched origin/main, rebased (already up to date), squashed to one commit, force-pushed feat/zero1-dta-archon-dp.
- enable_tree_training remains deprecated in favor of tree_training_mode; FSDP/Megatron raise if tree_training_mode=dta.