feat(archon): add ZERO1 DTA path with configs and tests #1287
```diff
@@ -3,6 +3,7 @@
 import argparse
 import json
 import os
+import warnings
 from dataclasses import MISSING as dataclass_missing
 from dataclasses import asdict, dataclass, field, fields
 from enum import Enum
```
```diff
@@ -1132,10 +1133,52 @@ class TrainEngineConfig:
         metadata={"help": "peft method type. Only LoRA is supported for now."},
     )
 
-    # Tree training
+    # Tree training (str, not Literal: OmegaConf.structured rejects Literal here)
+    tree_training_mode: str = field(
+        default="disabled",
+        metadata={
+            "help": (
+                "Tree training mode. "
+                "'sparse' enables tree training with the flex attention module, "
+                "'dta' enables Dynamic Tree Attention, "
+                "'disabled' disables tree training."
+            ),
+            "choices": ["disabled", "sparse", "dta"],
+        },
+    )
     enable_tree_training: bool = field(
         default=False,
-        metadata={"help": "Enable tree training with flex attention module."},
+        metadata={
+            "help": (
+                "[DEPRECATED] Use tree_training_mode instead. "
+                "enable_tree_training=True maps to tree_training_mode='sparse'. "
+                "If both are set, tree_training_mode takes precedence."
+            )
+        },
     )
+    dta_block_size: int = field(
+        default=2048,
+        metadata={
+            "help": (
+                "Block size for Dynamic Tree Attention. "
+                "Set to -1 to disable the block-size limit. "
+                "Only effective when tree_training_mode='dta'."
+            )
+        },
+    )
+    packing_algorithm: str = field(
+        default="ffd",
+        metadata={
+            "help": (
+                "Trajectory packing across data-parallel ranks during distributed rollout "
+                "(``redistribute_trajectories``). "
+                "'ffd' / 'kk' balance ranks by total sequence length; 'dta' balances by "
+                "DTA DFS-order n_tree_tokens. "
+                "Not to be confused with ``mb_spec.packing_algorithm``, which only "
+                "controls micro-batch formation (ffd/kk) during training."
+            ),
+            "choices": ["ffd", "kk", "dta"],
+        },
+    )
 
     # Scheduling
```
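For intuition about the rollout-side option, here is a hedged sketch of length-balanced assignment in the spirit of 'ffd'. The helper below is hypothetical and not the actual `redistribute_trajectories` implementation: strictly speaking it is longest-processing-time (LPT) scheduling, whereas classic FFD packs fixed-capacity bins and 'kk' presumably refers to Karmarkar–Karp partitioning.

```python
import heapq


def assign_by_length(seq_lens: list[int], n_ranks: int) -> list[list[int]]:
    """Place the longest remaining trajectory on the currently least-loaded rank."""
    heap = [(0, rank) for rank in range(n_ranks)]  # (total assigned length, rank)
    heapq.heapify(heap)
    buckets: list[list[int]] = [[] for _ in range(n_ranks)]
    for idx in sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i]):
        load, rank = heapq.heappop(heap)
        buckets[rank].append(idx)
        heapq.heappush(heap, (load + seq_lens[idx], rank))
    return buckets


# Five trajectories over two ranks: final loads are 1300 vs. 1200 tokens.
print(assign_by_length([900, 700, 500, 300, 100], n_ranks=2))  # [[0, 3, 4], [1, 2]]
```

A 'dta'-style packer would use the same idea but weight each trajectory by its DFS-order n_tree_tokens rather than raw sequence length.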
```diff
@@ -1208,6 +1251,45 @@ def __post_init__(self):
                 "memory_efficient_load is for loading pretrained weights on CPU, "
                 "but init_from_scratch creates a model without loading any weights."
             )
+        valid_tree_modes = {"disabled", "sparse", "dta"}
+        if self.tree_training_mode not in valid_tree_modes:
+            raise ValueError(
+                f"tree_training_mode must be one of {valid_tree_modes}, got '{self.tree_training_mode}'"
+            )
+        valid_rollout_packing = {"ffd", "kk", "dta"}
+        if self.packing_algorithm not in valid_rollout_packing:
+            raise ValueError(
+                f"packing_algorithm (rollout) must be one of {valid_rollout_packing}, "
+                f"got '{self.packing_algorithm}'"
+            )
+        if self.tree_training_mode == "dta":
+            if self.dta_block_size == 0 or self.dta_block_size < -1:
+                raise ValueError(
+                    f"dta_block_size must be -1 or a positive integer when tree_training_mode='dta', got {self.dta_block_size}."
+                )
+
+        if self.enable_tree_training:
+            warnings.warn(
+                "`enable_tree_training` is deprecated and will be removed in a future "
+                "version. Use `tree_training_mode='sparse'` instead.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if self.tree_training_mode != "disabled":
+                warnings.warn(
+                    f"`tree_training_mode` is already set to '{self.tree_training_mode}', "
+                    "so `enable_tree_training=True` is ignored.",
+                    FutureWarning,
+                    stacklevel=2,
+                )
+            else:
+                self.tree_training_mode = "sparse"
+                warnings.warn(
+                    "`tree_training_mode` is overridden to 'sparse' by the deprecated "
+                    "`enable_tree_training=True`.",
+                    FutureWarning,
+                    stacklevel=2,
+                )
+
         if self._version not in ("v1", "v2"):
             raise ValueError(
                 f"_version must be either 'v1' or 'v2', got '{self._version}'"
```

Collaborator commented on lines +1271 to +1292 (the deprecated `enable_tree_training` fallback block): Remove these.
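As background for both the `str, not Literal` comment in the field definition and the manual choice validation above, here is a minimal standalone sketch; the `ModeDemo` and `LiteralDemo` classes are hypothetical and assume the `omegaconf` package. OmegaConf structured configs support primitives and Enums but not `typing.Literal`, so the mode is typed `str` and its choices are checked in `__post_init__`:

```python
from dataclasses import dataclass
from typing import Literal

from omegaconf import OmegaConf


@dataclass
class ModeDemo:
    # A plain str survives OmegaConf.structured; valid choices are enforced
    # manually, mirroring the __post_init__ checks in the diff above.
    tree_training_mode: str = "disabled"

    def __post_init__(self):
        valid = {"disabled", "sparse", "dta"}
        if self.tree_training_mode not in valid:
            raise ValueError(f"tree_training_mode must be one of {valid}")


@dataclass
class LiteralDemo:
    tree_training_mode: Literal["disabled", "sparse", "dta"] = "disabled"


schema = OmegaConf.structured(ModeDemo)   # accepted: str is a supported type
try:
    OmegaConf.structured(LiteralDemo)     # Literal is not a supported annotation
except Exception as exc:
    print(f"schema rejected: {type(exc).__name__}")

try:
    ModeDemo(tree_training_mode="bogus")  # manual validation catches typos
except ValueError as exc:
    print(f"value rejected: {exc}")
```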
```diff
@@ -1576,6 +1658,22 @@ def __post_init__(self):
             "Please set `actor.use_decoupled_loss=false` in your configuration."
         )
 
+        if self.packing_algorithm == "dta":
+            for norm_name in ["adv_norm", "reward_norm"]:
+                norm_config = getattr(self, norm_name)
+                if norm_config is not None:
+                    if (
+                        norm_config.mean_level == "group"
+                        or norm_config.std_level == "group"
+                    ):
+                        raise ValueError(
+                            f"{norm_name} uses 'group' level normalization, which is incompatible "
+                            "with packing_algorithm='dta'. DTA requires sequence-level independence, "
+                            "but 'group' normalization relies on contiguous group slices. Please use "
+                            "'batch' level normalization or set packing_algorithm='ffd'. "
+                            "(Group-level support for DTA will be provided in a future release.)"
+                        )
+
         super().__post_init__()
```
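To see why the check above matters, here is a toy NumPy illustration (made-up numbers) of the contiguity assumption behind 'group'-level normalization: statistics are computed over consecutive slices of `group_size` trajectories, so any packing that reorders trajectories, as DTA's DFS-order packing does, would mix members of different groups.

```python
import numpy as np

rewards = np.array([1.0, 3.0, 0.0, 4.0])  # two groups of two consecutive trajectories
group_size = 2

# Group-mean centering assumes each group occupies a contiguous slice.
grouped = rewards.reshape(-1, group_size)
centered = (grouped - grouped.mean(axis=1, keepdims=True)).ravel()
print(centered)  # [-1.  1. -2.  2.]

# If packing reorders trajectories, the same reshape mixes members of
# different groups and the per-group statistics become wrong.
reordered = rewards[[0, 2, 1, 3]]
regrouped = reordered.reshape(-1, group_size)
wrong = (regrouped - regrouped.mean(axis=1, keepdims=True)).ravel()
print(wrong)  # [ 0.5 -0.5 -0.5  0.5] -- not the intended group statistics
```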
```diff
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Archon experimental testing helpers exposed under `areal.experimental`."""
```

Collaborator commented on lines +1 to +3: Same as above.
```diff
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Torchrun helpers for Archon experimental tests."""
```

Collaborator commented on lines +1 to +3: Same as above.
```diff
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Small Archon utility helpers shared by experimental test runners."""
+
+
+def strip_wrapper_prefixes(name: str) -> str:
+    """Drop wrapper-generated path segments from parameter names."""
+    return name.replace("._checkpoint_wrapped_module", "").replace("._orig_mod", "")
```
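For illustration, a hypothetical parameter name of the kind this helper is meant to clean: activation checkpointing's `CheckpointWrapper` inserts `._checkpoint_wrapped_module` segments into dotted parameter paths, and `torch.compile` inserts `._orig_mod`.

```python
name = "model._orig_mod.layers.0._checkpoint_wrapped_module.mlp.up_proj.weight"
print(strip_wrapper_prefixes(name))
# -> model.layers.0.mlp.up_proj.weight
```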
Collaborator commented on lines +1 to +8: If these are patches for testing only, can we move them to somewhere under the …
Reply: We can remove this option. Too many code changes are introduced for this fallback; it is better to just make it clear in the help message of the `tree_training_mode` option.