feat(archon): add ZERO1 DTA path with configs and tests #1287

Open

ezoicoder wants to merge 1 commit into areal-project:main from ezoicoder:feat/zero1-dta-archon-dp

Conversation

@ezoicoder (Collaborator) commented Apr 28, 2026

Description

This PR adds a consolidated ZERO1 + Dynamic Tree Attention (DTA) path for Archon: tree-training mode wiring (tree_training_mode, rollout packing_algorithm), DTA runtime modules (trie/stack engine, DP partitioning), integration with Archon attention and ZERO1 DP, runnable GSM8K/Tau2 YAML examples, docs and figures, and experimental/unit tests (Archon torchrun helpers live under tests/ only).

Related Issue

N/A

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality (run targeted suites locally before merge; GPU/multi-GPU tests optional)
  • Documentation updated (if applicable; built with ./docs/build_all.sh): docs/examples updated; the full docs build was not re-run for this update
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

  • Branch refreshed via /create-pr: fetched origin/main, rebased (already up to date), squashed to one commit, force-pushed feat/zero1-dta-archon-dp.
  • enable_tree_training remains deprecated in favor of tree_training_mode; FSDP/Megatron raise if tree_training_mode=dta.
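A minimal sketch of the backend restriction mentioned above, assuming a plain helper function (the name and call site are illustrative; the actual validation lives in this PR's config/engine code):

def check_tree_training_mode(backend: str, tree_training_mode: str) -> None:
    # Only the Archon engine supports DTA; FSDP/Megatron reject it.
    if tree_training_mode == "dta" and backend in ("fsdp", "megatron"):
        raise ValueError(
            "tree_training_mode='dta' is only supported by the Archon engine, "
            f"not by the '{backend}' backend."
        )

check_tree_training_mode("archon", "dta")   # ok
# check_tree_training_mode("fsdp", "dta")   # would raise ValueError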

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces Dynamic Tree Attention (DTA) as a new tree training mode, complementing the existing sparse mode. The changes include the core DTA engine, trie-based sequence management, and integration into the Archon engine with Zero-1 optimization. Trajectory redistribution logic is updated to support DTA-based partitioning, and attention modules are enhanced to handle variable sequence lengths and right-aligned KV caches. Feedback highlights critical logic flaws in the redistribution process and an incorrect type assertion in the RL trainer. High-severity concerns were raised regarding potential OOM errors from static KV cache preallocation and the use of hardcoded absolute paths in experimental logging. Additionally, design improvements are needed for fragile relative path calculations in module loading, memory-inefficient mask creation in SDPA, and a potential IndexError when accessing empty KV caches.

Comment on lines 278 to 280
batch = broadcast_tensor_container(
    batch,
    src_rank=self.train_engine.current_data_parallel_head(),

critical

The redistribution logic appears flawed. redistribute_trajectories is called locally on each rank, but if trajectories is only provided to the DP head, other ranks will return empty batches. Subsequently calling broadcast_tensor_container will result in all ranks receiving the same data shard from the head, which is not true redistribution.
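One possible direction, sketched with standard torch.distributed primitives (the helper below is illustrative and not part of this PR's API): broadcast the full trajectory set from the DP head first, then have every rank keep a distinct, deterministic shard.

import torch.distributed as dist

def broadcast_then_shard(trajectories, head_rank: int = 0):
    # All ranks receive the full trajectory list from the DP head...
    obj = [trajectories if dist.get_rank() == head_rank else None]
    dist.broadcast_object_list(obj, src=head_rank)
    full = obj[0]
    # ...then each rank keeps a different round-robin slice.
    rank, world = dist.get_rank(), dist.get_world_size()
    return full[rank::world]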

Comment on lines +28 to +31
with open(
    "/data/jiarui/dta/AReaL/.cursor/debug-b1b34b.log", "a", encoding="utf-8"
) as f:
    f.write(json.dumps(payload, ensure_ascii=True) + "\n")

high

Hardcoded absolute path for debug logging detected. This should be removed or replaced with a configurable path or standard logging mechanism to avoid issues in different environments.
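A minimal sketch of a configurable alternative; the AREAL_DTA_DEBUG_LOG environment variable is hypothetical and the payload dict is assumed from the diff:

import json
import logging
import os

logger = logging.getLogger(__name__)

def log_debug_payload(payload: dict) -> None:
    # Write to an opt-in file (hypothetical env var) or fall back to standard
    # debug logging; no hardcoded absolute paths.
    line = json.dumps(payload, ensure_ascii=True)
    path = os.environ.get("AREAL_DTA_DEBUG_LOG")
    if path:
        with open(path, "a", encoding="utf-8") as f:
            f.write(line + "\n")
    else:
        logger.debug("dta debug payload: %s", line)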

else model_config.hidden_size // model_config.num_attention_heads
)

kv_buffer_shape = (1, n_kv_heads, max_seq_len, head_dim)

high

Preallocating KV cache buffers for the full max_seq_len across all layers and for both forward and backward (grad_kv) can lead to massive memory consumption. For large sequence lengths and deep models, this will likely cause OOM errors even on high-end GPUs. Consider dynamic allocation or a more memory-efficient stack management strategy.
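A back-of-the-envelope estimate illustrates the scale of the concern; the dimensions below are assumptions for a mid-sized model, not values taken from this PR:

# Static per-layer KV buffers in bf16 (2 bytes per element).
bytes_per_elem = 2
n_kv_heads, head_dim = 8, 128          # assumed model dimensions
max_seq_len, n_layers = 32_768, 32     # assumed sequence length and depth

one_buffer = 1 * n_kv_heads * max_seq_len * head_dim * bytes_per_elem  # one K or V buffer
per_layer = 2 * one_buffer             # keys + values
total = n_layers * per_layer * 2       # forward kv + grad_kv
print(f"{total / 2**30:.1f} GiB")      # ~8.0 GiB before weights, activations, optimizer state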

):
    values = self.critic.compute_values(rollout_batch)
    if config.actor.packing_algorithm == "dta":
        assert isinstance(values, list), (

high

The assertion assert isinstance(values, list) will fail because ArchonLMEngine.forward_batch (and thus compute_values) returns a padded tensor, not a list, even in DTA mode.
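A defensive sketch of how the call site could accept either representation; the helper and the seq_lens argument are hypothetical, and the right fix depends on what compute_values actually returns in DTA mode:

import torch

def as_value_list(values, seq_lens):
    # Normalize either a list of per-sequence tensors or a padded
    # [num_seqs, max_len] tensor into a list of unpadded tensors.
    if isinstance(values, list):
        return values
    assert torch.is_tensor(values), "expected a padded tensor or a list"
    return [values[i, :length] for i, length in enumerate(seq_lens)]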

Comment on lines +36 to +44
repo_root = Path(__file__).resolve().parents[4]
impl_path = (
    repo_root
    / "tests"
    / "experimental"
    / "archon"
    / "torchrun"
    / "training_test_config.py"
)

medium

The _load_impl_module function relies on a fragile relative path calculation (parents[4]) and assumes that test files are present in the package distribution. Library code should generally not depend on files located in the tests/ directory.
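If the module must stay file-loaded for now, a less fragile sketch makes the location explicit and overridable (the AREAL_ARCHON_TEST_CONFIG variable is hypothetical); moving the helper into the package proper, as suggested, remains the cleaner fix:

import importlib.util
import os
from pathlib import Path

def load_impl_module(default_path: Path):
    # Resolve the config from an explicit, overridable location instead of
    # counting parent directories.
    path = Path(os.environ.get("AREAL_ARCHON_TEST_CONFIG", default_path))
    if not path.is_file():
        raise FileNotFoundError(f"Archon test config not found: {path}")
    spec = importlib.util.spec_from_file_location("training_test_config", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module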

Comment on lines +97 to 98
mask = torch.full((q_len, k_len), float("-inf"), device=device, dtype=dtype)
mask = mask.masked_fill(same_seq & causal, 0.0)

medium

Creating a full boolean mask of shape [q_len, k_len] using unsqueeze operations can lead to OOM for long sequences (e.g., 32k tokens results in ~1B elements). Consider using a more memory-efficient approach or specialized kernels for block-causal masks.
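For 32k tokens the dense mask is 32768 x 32768 ≈ 1.07e9 elements, roughly 2 GiB in bf16 or 4 GiB in fp32. A lighter sketch using PyTorch's flex-attention block mask (requires torch >= 2.5; the seq_ids tensor mapping each packed token to its sequence id is an assumption about the surrounding code):

import torch
from torch.nn.attention.flex_attention import create_block_mask

def block_causal_mask(seq_ids: torch.Tensor, q_len: int, k_len: int):
    # Causal attention restricted to tokens of the same sequence, expressed as
    # a block-sparse mask instead of a dense [q_len, k_len] float tensor.
    def mask_mod(b, h, q_idx, kv_idx):
        same_seq = seq_ids[q_idx] == seq_ids[kv_idx]
        causal = q_idx >= kv_idx
        return same_seq & causal
    # B=None / H=None broadcast the mask over batch and heads.
    return create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=k_len, device=seq_ids.device
    )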

Comment on lines +409 to +411
max_seqlen = int(tokens.shape[1]) + int(
    past_key_values.layers[0].keys.shape[2]
)

medium

Accessing past_key_values.layers[0] without checking if layers is non-empty will raise an IndexError during the first step of generation or training when the cache is empty.
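A defensive sketch that tolerates an empty cache on the first step; the .layers[i].keys layout is taken from the diff, not a confirmed API:

def compute_max_seqlen(tokens, past_key_values) -> int:
    # Total length = current tokens plus any cached prefix; an empty or
    # missing cache contributes zero.
    past_len = 0
    if past_key_values is not None and len(past_key_values.layers) > 0:
        past_len = int(past_key_values.layers[0].keys.shape[2])
    return int(tokens.shape[1]) + past_len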

ezoicoder force-pushed the feat/zero1-dta-archon-dp branch 3 times, most recently from 825a57b to 15807a1, on April 30, 2026 at 12:42
ezoicoder changed the title from "feat(archon): add ZERO1 DTA for Archon with DP-only path" to "feat(archon): add ZERO1 DTA path with configs and tests" on Apr 30, 2026
Introduce Dynamic Tree Attention for Archon ZERO1 data-parallel training with tree_training_mode wiring, DTA engine and trie stack, and rollout redistribution for DTA packing. FSDP and Megatron reject tree_training_mode=dta; enable_tree_training is deprecated in favor of tree_training_mode.

Key changes:
- Add DTA modules (engine, trie, DP load balance) and Archon model attention paths
- Extend TrainEngineConfig and PPO actor validation for DTA vs group norms
- Add GSM8K/Tau2 example YAMLs, figures, and testing documentation
- Add experimental Archon torchrun tests and DTA unit tests; keep training config under tests/
ezoicoder force-pushed the feat/zero1-dta-archon-dp branch from 8756d00 to 497e5e2 on April 30, 2026 at 13:21
Comment thread areal/api/cli_args.py
Comment on lines 1149 to +1158
     enable_tree_training: bool = field(
         default=False,
-        metadata={"help": "Enable tree training with flex attention module."},
+        metadata={
+            "help": (
+                "[DEPRECATED] Use tree_training_mode instead. "
+                "enable_tree_training=True maps to tree_training_mode='sparse'. "
+                "If both are set, tree_training_mode takes precedence."
+            )
+        },
     )

We can remove this option. Too many code changes are introduced for this fallback. It is better to just make it clear in the help message of the tree_training_mode option.
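A sketch of the simplified option that suggestion implies; the field name and allowed values come from this PR, while the exact help wording is illustrative:

from dataclasses import dataclass, field

@dataclass
class TrainEngineConfigSketch:
    # Fold the explanation into tree_training_mode's help text and drop the
    # deprecated enable_tree_training fallback entirely.
    tree_training_mode: str = field(
        default="disabled",
        metadata={
            "help": (
                "Tree training backend: 'disabled', 'sparse' (flex-attention "
                "tree training, previously enable_tree_training=True), or "
                "'dta' (Dynamic Tree Attention; Archon engine only)."
            )
        },
    )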

Comment thread areal/api/cli_args.py
Comment on lines +1271 to +1292
if self.enable_tree_training:
    warnings.warn(
        "`enable_tree_training` is deprecated and will be removed in a future version. "
        "Use `tree_training_mode='sparse'` instead.",
        FutureWarning,
        stacklevel=2,
    )
    if self.tree_training_mode != "disabled":
        warnings.warn(
            f"`tree_training_mode` is already set to '{self.tree_training_mode}', "
            "`enable_tree_training=True` is ignored.",
            FutureWarning,
            stacklevel=2,
        )
    else:
        self.tree_training_mode = "sparse"
        warnings.warn(
            "`tree_training_mode` is overridden to 'sparse' from deprecated "
            "`enable_tree_training=True`.",
            FutureWarning,
            stacklevel=2,
        )

Remove these.

Comment on lines +1 to +8
# SPDX-License-Identifier: Apache-2.0

"""Small Archon utility helpers shared by experimental test runners."""


def strip_wrapper_prefixes(name: str) -> str:
"""Drop wrapper-generated path segments from parameter names."""
return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "")

If these are patches for testing only, can we move them to somewhere under the tests/ folder?

Comment on lines +1 to +3
# SPDX-License-Identifier: Apache-2.0

"""Archon experimental testing helpers exposed under `areal.experimental`."""

Same as above.

Comment on lines +1 to +3
# SPDX-License-Identifier: Apache-2.0

"""Torchrun helpers for Archon experimental tests."""

Same as above.

Collaborator left a comment

Instead of putting all these configs in examples/tau2/dta, you should remove these DTA-specific configs and just modify examples/tau2/README.md to tell users how to enable DTA in the tau2 example.

Also, put the reward curve in the result section as a reference.

Collaborator left a comment

Please refer to this file to write the integration test for DTA on the Archon engine, instead of rewriting the whole end-to-end training test.

https://github.com/inclusionAI/AReaL/blob/main/tests/experimental/archon/test_qwen3_parallelize.py

Collaborator left a comment

This file is not needed. Remove it.

| `eps_clip_higher` | float \| None | `None` | Upper clipping bound: when set, the ratio is clipped to `[1-eps_clip, 1+eps_clip_higher]` |

-When `eps_clip_higher` is `None`, symmetric clipping is used: $\text{clip}(r, 1-\epsilon, 1+\epsilon)$.
+When `eps_clip_higher` is `None`, symmetric clipping is used: $\\text{clip}(r, 1-\\epsilon, 1+\\epsilon)$.

Revert this change since it does not belong to this PR.

Collaborator left a comment

This figure is not used anywhere. You could write a doc for DTA and put it in your doc in the next PR. Just remove it for now.
