-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
158 lines (119 loc) · 4.84 KB
/
Copy pathtrain.py
File metadata and controls
158 lines (119 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Unified TorchTitan launcher for NanoGPT and DINOv3 SSL training."""
from __future__ import annotations
import logging
import sys
from typing import Any, Callable
# Accepted command-line spellings of the model-name option.
MODEL_NAME_FLAGS = {
    "--model.name",
    "--model-name",
    "--model_name",
}
# Accepted command-line spellings of the model-flavor option.
MODEL_FLAVOR_FLAGS = {
    "--model.flavor",
    "--model-flavor",
    "--model_flavor",
}
# Accepted command-line spellings of the custom job-config module option.
CUSTOM_CONFIG_FLAGS = {
    "--job.custom_config_module",
    "--job.custom-config-module",
}
# Model names this launcher can dispatch to.
SUPPORTED_TARGETS = (
    "nanogpt",
    "dinov3",
)
def _read_last_flag_value(args: list[str], flags: set[str]) -> tuple[bool, str | None]:
    """Report the last occurrence of any option in *flags* within *args*.

    Both the ``--flag=value`` and the ``--flag value`` spellings are
    recognised. Later occurrences override earlier ones.

    Args:
        args: Command-line tokens to scan.
        flags: Option spellings that count as a match.

    Returns:
        A ``(found, value)`` pair. ``found`` is True when any spelling in
        *flags* appeared. ``value`` is the value attached to the last
        occurrence, or None when that occurrence had no value (it was the
        final token, or the following token starts with ``--``).
    """
    seen = False
    last_value: str | None = None
    total = len(args)
    pos = 0
    while pos < total:
        current = args[pos]
        name, sep, inline = current.partition("=")
        if sep:
            # ``--flag=value`` form. A token containing "=" is never
            # considered as a bare flag, whether or not its key matches.
            if name in flags:
                seen, last_value = True, inline
            pos += 1
        elif current not in flags:
            pos += 1
        else:
            seen = True
            follower = pos + 1
            if follower < total and not args[follower].startswith("--"):
                # ``--flag value`` form: consume the value token as well.
                last_value = args[follower]
                pos += 2
            else:
                # Bare flag with nothing usable after it resets the value.
                last_value = None
                pos += 1
    return seen, last_value
def _has_flag(args: list[str], flags: set[str]) -> bool:
    """Return True when any spelling in *flags* appears in *args*.

    The flag counts as present even if it carries no value.
    """
    return _read_last_flag_value(args, flags)[0]
def _resolve_target(args: list[str]) -> str:
    """Determine which training target the command line selects.

    Args:
        args: Command-line tokens (without the program name).

    Returns:
        ``"nanogpt"`` when no model-name option is present, otherwise the
        requested name if it is one of ``SUPPORTED_TARGETS``.

    Raises:
        ValueError: If the model-name option is present without a value
            (or with an empty one), or names an unsupported model.
    """
    was_given, requested = _read_last_flag_value(args, MODEL_NAME_FLAGS)

    # No explicit choice: NanoGPT is the default target.
    if not was_given:
        return "nanogpt"

    if not requested:
        raise ValueError(
            "Missing value for --model.name/--model-name/--model_name. "
            f"Supported model names: {', '.join(SUPPORTED_TARGETS)}"
        )

    if requested in SUPPORTED_TARGETS:
        return requested

    raise ValueError(
        f"Unsupported model.name '{requested}'. "
        f"Supported model names: {', '.join(SUPPORTED_TARGETS)}"
    )
def _inject_defaults(args: list[str], target: str) -> list[str]:
    """Return a copy of *args* with target-specific defaults prepended.

    A default is added only when none of the spellings of the corresponding
    option is already present in *args*. The input list is not modified.

    Args:
        args: Command-line tokens supplied by the user.
        target: ``"nanogpt"`` or ``"dinov3"``.

    Returns:
        A new list: the missing defaults followed by the original tokens.

    Raises:
        ValueError: If *target* is neither ``"nanogpt"`` nor ``"dinov3"``.
    """
    supplied = list(args)

    # Each entry pairs the spellings of an option with the default token to
    # inject when the option is absent; order determines output order.
    if target == "nanogpt":
        candidates = (
            (MODEL_NAME_FLAGS, "--model.name=nanogpt"),
            (MODEL_FLAVOR_FLAGS, "--model.flavor=gpt2_small"),
        )
    elif target == "dinov3":
        candidates = (
            (MODEL_FLAVOR_FLAGS, "--model.flavor=default"),
            (CUSTOM_CONFIG_FLAGS, "--job.custom_config_module=dinov3.job_config"),
        )
    else:
        raise ValueError(f"Unknown target '{target}'")

    missing = [
        default_token
        for spellings, default_token in candidates
        if not _has_flag(supplied, spellings)
    ]
    return missing + supplied
def _silence_known_torchtitan_warnings() -> None:
    """Attach a filter to the root logger that drops one known noisy message.

    A record is rejected only when its formatted message contains both
    ``"HF assets path"`` and ``"does not exist!"``; everything else passes.

    NOTE(review): a filter attached to a logger is consulted only for
    records emitted through that logger itself, so this relies on the
    message being logged via the root logger -- confirm against TorchTitan.
    """

    class _HFAssetsWarningFilter(logging.Filter):
        """Reject records reporting a missing HF assets path."""

        def filter(self, record: logging.LogRecord) -> bool:
            text = record.getMessage()
            is_known_noise = "HF assets path" in text and "does not exist!" in text
            return not is_known_noise

    root_logger = logging.getLogger()
    root_logger.addFilter(_HFAssetsWarningFilter())
def _register_train_spec(name: str, spec_factory: Callable[[], Any]) -> None:
    """Build a train spec and register it with TorchTitan under *name*.

    Args:
        name: Name passed to ``register_train_spec``.
        spec_factory: Zero-argument callable producing the spec object.

    Raises:
        ValueError: Re-raised from ``register_train_spec`` unless its message
            contains ``"already registered"``, in which case it is ignored.
    """
    from torchtitan.protocols.train_spec import register_train_spec

    train_spec = spec_factory()
    try:
        register_train_spec(name, train_spec)
    except ValueError as err:
        # A duplicate registration is treated as success; anything else is
        # a genuine error and propagates.
        if "already registered" in str(err):
            return
        raise
def _patch_config_manager_parse_args() -> None:
    """Wrap ``ConfigManager.parse_args`` so it reads ``sys.argv`` at call time.

    The replacement forwards to the original method, substituting
    ``sys.argv[1:]`` (evaluated when the method is called) whenever ``args``
    is None. Explicitly supplied ``args`` are passed through untouched.

    The wrapper is tagged with ``_dinov3_launcher_patched`` so repeated calls
    do not stack wrappers, and ``functools.wraps`` keeps the original
    method's name and docstring (and exposes it as ``__wrapped__``) so the
    patched method stays introspectable.
    """
    import functools

    from torchtitan.config.manager import ConfigManager

    original = ConfigManager.parse_args
    if getattr(original, "_dinov3_launcher_patched", False):
        # Already wrapped by an earlier call; nothing to do.
        return

    @functools.wraps(original)
    def _parse_args(self: ConfigManager, args: list[str] | None = None) -> Any:
        # Look up sys.argv here rather than in a default argument, so argv
        # changes made before the call are honoured.
        return original(self, sys.argv[1:] if args is None else args)

    # Set after wraps() so the marker is present on the wrapper itself.
    setattr(_parse_args, "_dinov3_launcher_patched", True)
    ConfigManager.parse_args = _parse_args  # type: ignore[method-assign]
def _launch_nanogpt() -> None:
    """Register the NanoGPT train spec and run TorchTitan with ``Trainer``.

    Imports are deferred to call time so that importing this module does not
    pull in ``nanogpt`` or ``torchtitan``.
    """
    from nanogpt import get_train_spec as build_nanogpt_spec
    from torchtitan.train import Trainer
    from torchtitan.train import main as run_torchtitan

    _register_train_spec("nanogpt", build_nanogpt_spec)
    run_torchtitan(Trainer)
def _launch_dinov3() -> None:
    """Register the DINOv3 train spec and run TorchTitan with ``DinoV3Trainer``.

    Before handing control to TorchTitan, ``ConfigManager.parse_args`` is
    patched and the ``dinov3`` train spec is registered.
    """
    from dinov3 import get_train_spec as build_dinov3_spec
    from dinov3.trainer import DinoV3Trainer

    _patch_config_manager_parse_args()
    _register_train_spec("dinov3", build_dinov3_spec)

    # Deliberately imported late, after argv has been rewritten:
    # ConfigManager.parse_args has a default argument that is bound at
    # import time.
    from torchtitan.train import main as run_torchtitan

    run_torchtitan(DinoV3Trainer)
def main(argv: list[str] | None = None) -> None:
    """Entry point: pick a target, inject defaults, and launch training.

    Args:
        argv: Command-line tokens without the program name. When None,
            ``sys.argv[1:]`` is used.

    Raises:
        ValueError: If the target cannot be resolved or is unknown.

    Side effects:
        Rewrites ``sys.argv`` to the program name followed by the
        default-augmented arguments, and installs a filter on the root
        logger.
    """
    cli_args = sys.argv[1:] if argv is None else argv
    target = _resolve_target(cli_args)

    # The launchers take no arguments, so the augmented command line is
    # handed over through sys.argv.
    sys.argv = [sys.argv[0], *_inject_defaults(cli_args, target)]
    _silence_known_torchtitan_warnings()

    launchers = {
        "nanogpt": _launch_nanogpt,
        "dinov3": _launch_dinov3,
    }
    launch = launchers.get(target)
    if launch is None:
        raise ValueError(f"Unknown target '{target}'")
    launch()
# Run the launcher only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()