Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 100 additions & 2 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import json
import os
import warnings
from dataclasses import MISSING as dataclass_missing
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
Expand Down Expand Up @@ -1132,10 +1133,52 @@ class TrainEngineConfig:
metadata={"help": "peft method type. Only LoRA is supported for now."},
)

# Tree training
# Tree training (str, not Literal: OmegaConf.structured rejects Literal here)
tree_training_mode: str = field(
default="disabled",
metadata={
"help": (
"Tree training mode. "
"'sparse' enables tree training with Flex Attention module (flex attention), "
"'dta' enables Dynamic Tree Attention (dynamic tree training), "
"'disabled' disables tree training."
),
"choices": ["disabled", "sparse", "dta"],
},
)
enable_tree_training: bool = field(
default=False,
metadata={"help": "Enable tree training with flex attention module."},
metadata={
"help": (
"[DEPRECATED] Use tree_training_mode instead. "
"enable_tree_training=True maps to tree_training_mode='sparse'. "
"If both are set, tree_training_mode takes precedence."
)
},
)
Comment on lines 1149 to +1158
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this option. Too many code changes are introduced for this fallback. It is better to just make this clear in the help message of the tree_training_mode option.

dta_block_size: int = field(
default=2048,
metadata={
"help": (
"Block size for Dynamic Tree Attention. "
"Set to -1 to disable block-size limit. "
"Only effective when tree_training_mode='dta'."
)
},
)
packing_algorithm: str = field(
default="ffd",
metadata={
"help": (
"Trajectory packing across data-parallel ranks during distributed rollout "
"(``redistribute_trajectories``). "
"'ffd' / 'kk' balance by total sequence length; 'dta' uses DTA DFS-order "
"n_tree_tokens. "
"Not to be confused with ``mb_spec.packing_algorithm``, which only "
"controls micro-batch formation (ffd/kk) during training."
),
"choices": ["ffd", "kk", "dta"],
},
)

# Scheduling
Expand Down Expand Up @@ -1208,6 +1251,45 @@ def __post_init__(self):
"memory_efficient_load is for loading pretrained weights on CPU, "
"but init_from_scratch creates a model without loading any weights."
)
valid_tree_modes = {"disabled", "sparse", "dta"}
if self.tree_training_mode not in valid_tree_modes:
raise ValueError(
f"tree_training_mode must be one of {valid_tree_modes}, got '{self.tree_training_mode}'"
)
valid_rollout_packing = {"ffd", "kk", "dta"}
if self.packing_algorithm not in valid_rollout_packing:
raise ValueError(
f"packing_algorithm (rollout) must be one of {valid_rollout_packing}, "
f"got '{self.packing_algorithm}'"
)
if self.tree_training_mode == "dta":
if self.dta_block_size == 0 or self.dta_block_size < -1:
raise ValueError(
f"dta_block_size must be -1 or a positive integer when tree_training_mode='dta', got {self.dta_block_size}."
)

if self.enable_tree_training:
warnings.warn(
"`enable_tree_training` is deprecated and will be removed in a future version. "
"Use `tree_training_mode='sparse'` instead.",
FutureWarning,
stacklevel=2,
)
if self.tree_training_mode != "disabled":
warnings.warn(
f"`tree_training_mode` is already set to '{self.tree_training_mode}', "
"`enable_tree_training=True` is ignored.",
FutureWarning,
stacklevel=2,
)
else:
self.tree_training_mode = "sparse"
warnings.warn(
"`tree_training_mode` is overridden to 'sparse' from deprecated "
"`enable_tree_training=True`.",
FutureWarning,
stacklevel=2,
)
Comment on lines +1271 to +1292
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these.

if self._version not in ("v1", "v2"):
raise ValueError(
f"_version must be either 'v1' or 'v2', got '{self._version}'"
Expand Down Expand Up @@ -1576,6 +1658,22 @@ def __post_init__(self):
"Please set `actor.use_decoupled_loss=false` in your configuration."
)

if self.packing_algorithm == "dta":
for norm_name in ["adv_norm", "reward_norm"]:
norm_config = getattr(self, norm_name)
if norm_config is not None:
if (
norm_config.mean_level == "group"
or norm_config.std_level == "group"
):
raise ValueError(
f"{norm_name} uses 'group' level normalization, which is incompatible "
"with packing_algorithm='dta'. DTA requires sequence-level independence, "
"but 'group' normalization relies on contiguous group slices. Please use "
"'batch' level normalization or set packing_algorithm='ffd'. "
"(Group-level support for DTA will be provided in a future release.)"
)

super().__post_init__()


Expand Down
27 changes: 16 additions & 11 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,14 @@ def __init__(self, config: TrainEngineConfig):
self.dp_rank: int

self.is_offload: bool = False
self.tree_training_mode: str = self.config.tree_training_mode
if self.tree_training_mode == "dta":
raise ValueError(
"tree_training_mode='dta' is only supported by ArchonEngine. "
"Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'."
)
self._offload_depth: int = 0
self._per_layer_optim_wrapper: PerLayerOptimWrapper | None = None
self.enable_tree_training: bool = self.config.enable_tree_training

@classmethod
def from_pretrained(
Expand Down Expand Up @@ -383,7 +388,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
# Create device model
self._create_device_model()

if self.enable_tree_training and self.parallel_helper.sp_size > 1:
if self.tree_training_mode == "sparse" and self.parallel_helper.sp_size > 1:
raise ValueError(
"Tree training currently cannot be enabled with sp_size > 1."
)
Expand All @@ -394,7 +399,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
shard_vision_across_sp=self.config.fsdp.shard_vision_across_sp,
)
# Monkey patch: replace attention's forward() with tree attention.
patch_fsdp_for_tree_training(enable=self.enable_tree_training)
patch_fsdp_for_tree_training(enable=self.tree_training_mode == "sparse")

if self.config.use_lora:
self._apply_peft_wrapper()
Expand Down Expand Up @@ -732,7 +737,7 @@ def forward_backward_batch(
# module_fsdp.py reads these keys from the **kwargs that transformers
# forwards through.
tree_attn_keys: list[str] = []
if self.enable_tree_training and ctx.trie_node is not None:
if self.tree_training_mode == "sparse" and ctx.trie_node is not None:
padded_size = mb_item.padded_to_length
assert padded_size is not None
tree_kwargs = build_tree_attn_kwargs(
Expand Down Expand Up @@ -880,8 +885,8 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None:
self.forward_backward_batch(mb_list, process_output, forward_only=True)

# Step 4: Aggregate and reorder outputs
if self.enable_tree_training:
result = merge_packed_tree_results(outputs, batch_size)
if self.tree_training_mode == "sparse":
return merge_packed_tree_results(outputs, batch_size)
else:
result = reorder_and_pad_outputs(
outputs, output_seqlens, mb_list, aggregate_fn
Expand Down Expand Up @@ -1589,7 +1594,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
input_ = input_.copy()

# Tree training path
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
mb_list = build_packed_tree_batch(
input_,
mb_spec=self.config.mb_spec,
Expand Down Expand Up @@ -1855,12 +1860,12 @@ def _compute_logprobs_and_loss(
if local_weight == 0:
return logits.mean() * 0.0

if self.config.is_critic and self.enable_tree_training:
if self.config.is_critic and self.tree_training_mode == "sparse":
raise NotImplementedError(
"Tree training with critic model is not supported yet."
)
if not self.config.is_critic:
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
# Handle dummy trie (empty tree for DP synchronization)
# When trie has no sequences, return zero loss with grad connection
if ctx.trie_node is None or not ctx.trie_node.all_sequence_ids:
Expand Down Expand Up @@ -1918,12 +1923,12 @@ def _compute_forward_result(
ctx: FSDPTrainContext,
) -> torch.Tensor | dict[int, torch.Tensor]:
"""Compute forward output (logprobs or values)."""
if self.config.is_critic and self.enable_tree_training:
if self.config.is_critic and self.tree_training_mode == "sparse":
raise NotImplementedError(
"Tree training with critic model is not supported yet."
)
if not self.config.is_critic:
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
result = _gather_packed_tree_logprobs(
logits,
ctx.trie_node,
Expand Down
23 changes: 14 additions & 9 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,13 @@ def __init__(self, config: TrainEngineConfig):
self.seed: int = 0
self.own_global_group: bool = False
self.is_offload: bool = False
self.tree_training_mode: str = self.config.tree_training_mode
if self.tree_training_mode == "dta":
raise ValueError(
"tree_training_mode='dta' is only supported by ArchonEngine. "
"Please use Archon backend or set tree_training_mode to 'disabled'/'sparse'."
)
self._offload_depth: int = 0
self.enable_tree_training: bool = self.config.enable_tree_training
# FP8 configuration
self.fp8_config = self.mcore_config.fp8_config
self.enable_fp8: bool = self.fp8_config is not None
Expand Down Expand Up @@ -323,7 +328,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
self.tokenizer = load_hf_tokenizer(self.config.path)

with patch_bridge_for_tree_training(
self.enable_tree_training and self.bridge_cls == "mbridge"
self.tree_training_mode == "sparse" and self.bridge_cls == "mbridge"
):
self.bridge = self._build_hf_mcore_bridge()

Expand Down Expand Up @@ -807,7 +812,7 @@ def forward_step(batch_iter, model):
# save_for_backward() which can only save torch.Tensor objects;
# BlockMask is recreated inside PytorchFlexAttention.forward().
tree_attn_keys: list[str] = []
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
trie_node = mb_input.padded_mb.get("trie_node", None)
# Ensure trie_node is also in orig_mb for _compute_logprobs_and_loss
if trie_node is not None and "trie_node" not in mb_input.orig_mb:
Expand Down Expand Up @@ -1029,7 +1034,7 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None:
# Step 4: Aggregate, reorder, and broadcast outputs
res = None
if mpu.is_pipeline_last_stage():
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
res = merge_packed_tree_results(outputs, batch_size)
else:
res = reorder_and_pad_outputs(
Expand Down Expand Up @@ -1812,7 +1817,7 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
pp_size = self.parallel_strategy.pipeline_parallel_size
cp_size = self.parallel_strategy.context_parallel_size
tp_size = self.parallel_strategy.tensor_parallel_size
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
assert cp_size == 1, (
"Context parallelism is not supported in tree training."
)
Expand Down Expand Up @@ -1922,12 +1927,12 @@ def _compute_logprobs_and_loss(
if local_weight == 0:
return output.mean() * 0.0

if self.config.is_critic and self.enable_tree_training:
if self.config.is_critic and self.tree_training_mode == "sparse":
raise NotImplementedError(
"Tree training with critic model is not supported yet."
)
if not self.config.is_critic:
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
# Handle dummy trie (empty tree for DP synchronization)
# When trie has no sequences, return zero loss with grad connection
trie_node = inputs.get("trie_node")
Expand Down Expand Up @@ -1987,12 +1992,12 @@ def _compute_forward_result(
output: torch.Tensor,
inputs: dict[str, Any],
) -> torch.Tensor | dict[int, torch.Tensor]:
if self.config.is_critic and self.enable_tree_training:
if self.config.is_critic and self.tree_training_mode == "sparse":
raise NotImplementedError(
"Tree training with critic model is not supported yet."
)
if not self.config.is_critic:
if self.enable_tree_training:
if self.tree_training_mode == "sparse":
logprobs = _gather_packed_tree_logprobs(
output,
inputs["trie_node"],
Expand Down
3 changes: 3 additions & 0 deletions areal/experimental/archon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0

"""Archon experimental testing helpers exposed under `areal.experimental`."""
Comment on lines +1 to +3
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

3 changes: 3 additions & 0 deletions areal/experimental/archon/torchrun/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0

"""Torchrun helpers for Archon experimental tests."""
Comment on lines +1 to +3
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

8 changes: 8 additions & 0 deletions areal/experimental/archon/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

"""Small Archon utility helpers shared by experimental test runners."""


def strip_wrapper_prefixes(name: str) -> str:
    """Remove wrapper-generated path segments from a parameter name.

    Strips the ``._checkpoint_wrapped_module`` and ``._orig_mod`` segments
    that model wrappers splice into dotted parameter paths, so the returned
    name matches the unwrapped model's naming.

    Args:
        name: Dotted parameter path, possibly containing wrapper segments.

    Returns:
        The path with all wrapper segments removed.
    """
    cleaned = name
    for segment in ("._checkpoint_wrapped_module", "._orig_mod"):
        cleaned = cleaned.replace(segment, "")
    return cleaned
Comment on lines +1 to +8
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these are patches for testing only, can we move them to somewhere under the tests/ folder?

Loading
Loading