Commit 987ef52

benchmark
1 parent aaf3d45 commit 987ef52

File tree

7 files changed: +44 −14 lines

7 files changed

+44
-14
lines changed

xtuner/v1/float8/fsdp_utils.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -22,7 +22,6 @@
 DEVICE_MODULE = get_torch_device_module()
 
 
-@maybe_compile(fullgraph=True)
 def tensor_to_per_block_fp8_devided_64_scales(
     tensor: "WeightWithDynamicTilewiseFloat8CastTensor",
     reduce_mesh_devided_64: Optional[DeviceMesh] = None,
@@ -224,7 +223,6 @@ def cast_to_per_block_fp8_with_scales(
     return tensor_bits_fp8
 
 
-@maybe_compile(fullgraph=True)
 def cast_to_per_block_fp8_devided_64_with_scales(
     tensor: torch.Tensor,
     scales: torch.Tensor,
```
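
Dropping `@maybe_compile(fullgraph=True)` means the two divided-by-64 FP8 cast helpers now run eagerly rather than through `torch.compile`. The repository's actual `maybe_compile` helper is not part of this diff; the sketch below only illustrates how such a conditional-compile decorator is commonly built, and the `XTUNER_COMPILE` switch is a made-up name.

```python
# Illustrative sketch only (assumed behavior, not the repo's real helper):
# wrap a function with torch.compile when an opt-in switch is set, otherwise
# return it unchanged so the eager path is used.
import os
import torch


def maybe_compile(fullgraph: bool = False):
    def decorator(fn):
        if os.getenv("XTUNER_COMPILE", "0") == "1":  # hypothetical opt-in switch
            return torch.compile(fn, fullgraph=fullgraph)
        return fn
    return decorator
```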

xtuner/v1/model/moe/deepseek_v3.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -6,6 +6,7 @@
 from transformers.models.deepseek_v3 import DeepseekV3Config as HFDeepseekV3Config
 from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig
 from xtuner.v1.module.attention import MLAConfig
+from xtuner.v1.module.router.greedy import GreedyRouterConfig
 from xtuner.v1.module.router.noaux_router import NoAuxRouterConfig
 from xtuner.v1.utils import get_logger
 
@@ -89,7 +90,7 @@ class DeepSeekV3Config(MoEConfig):
     num_experts_per_tok: int = 8
     hidden_factor: float = 1.0
     moe_intermediate_size: int = 2048
-    router: NoAuxRouterConfig = NoAuxRouterConfig(
+    router: NoAuxRouterConfig | GreedyRouterConfig = NoAuxRouterConfig(
         n_group=8,
         topk_group=4,
         scoring_func="sigmoid",
```
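
Widening the `router` field to `NoAuxRouterConfig | GreedyRouterConfig` lets a `DeepSeekV3Config` be built with either router. The following is a minimal, self-contained sketch of that union-typed field; the config classes here are toy stand-ins with far fewer fields than the real ones.

```python
# Toy illustration of a union-typed router field (pydantic-style config classes
# with made-up field sets, not the actual xtuner configs).
from pydantic import BaseModel


class NoAuxRouterConfig(BaseModel):
    n_group: int = 8
    topk_group: int = 4
    scoring_func: str = "sigmoid"


class GreedyRouterConfig(BaseModel):
    scoring_func: str = "softmax"


class DeepSeekV3Config(BaseModel):
    # either router config is now accepted; the default stays NoAux
    router: NoAuxRouterConfig | GreedyRouterConfig = NoAuxRouterConfig()


cfg = DeepSeekV3Config(router=GreedyRouterConfig())
print(type(cfg.router).__name__)  # GreedyRouterConfig
```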

xtuner/v1/model/moe/moe.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -305,6 +305,7 @@ def _micro_batch_forward(
         cat_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
         cat_hidden_states: torch.Tensor | None = None
 
+        moe_forawrd = False
         for idx, decoder_layer in self.layers.items():
             layer_idx = int(idx)
 
@@ -322,8 +323,9 @@
                     seq_ctx=cat_seq_ctx,
                 )
             else:
-                if cat_hidden_states is not None:
+                if cat_hidden_states is not None and not moe_forawrd:
                     hidden_states_list = list(cat_hidden_states.chunk(len(seq_ctx_list), dim=1))
+                    moe_forawrd = True
 
                 layer_results = decoder_layer(
                     *hidden_states_list,
```
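
The new `moe_forawrd` flag makes the concatenated hidden states get chunked into per-micro-batch tensors only once, on the first layer that takes the non-attention branch, instead of being re-chunked on every later layer. A standalone sketch of that chunk-once pattern (placeholder names, not the trainer's real objects):

```python
# Sketch of the "chunk once, then reuse" guard; shapes and names are illustrative.
import torch

cat_hidden_states = torch.randn(2, 8, 16)   # hidden states concatenated along dim=1
seq_ctx_list = [None, None]                 # stands in for two micro-batch contexts
chunked = False
hidden_states_list: list[torch.Tensor] = []

for layer_idx in range(4):
    if cat_hidden_states is not None and not chunked:
        # split back into one tensor per micro-batch, exactly once
        hidden_states_list = list(cat_hidden_states.chunk(len(seq_ctx_list), dim=1))
        chunked = True
    # ... decoder_layer(*hidden_states_list) would run here, reusing the same chunks ...
```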

xtuner/v1/module/dispatcher/deepep.py

Lines changed: 23 additions & 3 deletions
```diff
@@ -60,6 +60,7 @@ class DeepEPPreCombineResult(PreCombineResult):
 
 class DeepEPCombineResult(CombineResult):
     forward_finished_event: EventOverlap | None
+    backward_previous_event: EventOverlap | None
 
 
 DeepEPPostCombineResult = PostCombineResult
@@ -144,6 +145,7 @@ def forward(
         handle: DeepEPHandle,
         group: dist.ProcessGroup,
         forward_previous_event: EventOverlap | None = None,
+        backward_previous_event: EventOverlap | None = None,
         backward_finished_event: EventOverlap | None = None,
     ) -> tuple[torch.Tensor, EventOverlap]:
         combined_x, event = combine_forward(x, num_experts, handle, group, forward_previous_event)
@@ -152,17 +154,18 @@
         ctx.group = group
         ctx.num_experts = num_experts
         ctx.backward_finished_event = backward_finished_event
+        ctx.backward_previous_event = backward_previous_event
         return combined_x, event
 
     @staticmethod
     def backward(  # type: ignore[invalid-override]
         ctx, grad_combined_x: torch.Tensor, *args
-    ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], None, None, None, None, None]:
+    ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], None, None, None, None, None, None]:
         # load saved comm handle
         handle = ctx.saved_tensors
-        grad_x, event = combine_backward(grad_combined_x, ctx.num_experts, handle, ctx.group, buffer_capture())
+        grad_x, event = combine_backward(grad_combined_x, ctx.num_experts, handle, ctx.group, ctx.backward_previous_event)
         ctx.backward_finished_event.event = event.event
-        return grad_x, None, None, None, None, None
+        return grad_x, None, None, None, None, None, None
 
 
 _async_combine = copy_method_signature(DeepEPCombine.forward)(DeepEPCombine.apply)
@@ -388,24 +391,30 @@ def combine(
         decoding: bool = False,
     ) -> CombineResult:
         if async_op:
+            backward_previous_event = EventOverlap(None)
             assert pre_combined["forward_finished_event"] is not None, "Please use `async_op=True` for combine!"
             pre_combined["forward_finished_event"].current_stream_wait()
+        else:
+            backward_previous_event = None
 
         combined_hidden_states, event = _async_combine(
             pre_combined["hidden_states"],
             self._n_routed_experts,
             dispatched["handle"],
             self._process_group,
             pre_combined["forward_finished_event"],
+            backward_previous_event,
             pre_combined["backward_previous_event"],
         )
         if not async_op:
             event.current_stream_wait()
 
+
         if not decoding:
             return DeepEPCombineResult(
                 hidden_states=combined_hidden_states,
                 forward_finished_event=event,
+                backward_previous_event=backward_previous_event,
             )
         else:
             raise NotImplementedError
@@ -424,6 +433,17 @@ def combine_postprocess(
         hidden_states = combined["hidden_states"]
         forward_previous_event = combined["forward_finished_event"]
 
+        hidden_states = hidden_states.view_as(hidden_states)
+
+        if hidden_states.grad_fn is not None:
+            hidden_states.grad_fn.register_hook(
+                get_backward_hook(
+                    backward_finished_event=combined["backward_previous_event"],
+                    name="DeeEPDispatcher.combine_postprocess",
+                    debug=XTUNER_DISPATCHER_DEBUG,
+                )
+            )
+
         if async_op:
             assert forward_previous_event is not None, "Please use `async_op=True` for combine!"
             forward_previous_event.current_stream_wait()
```
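
The dispatcher changes thread a `backward_previous_event` through the combine path: `combine()` creates the event when `async_op` is set, `DeepEPCombine.forward` stashes it on `ctx`, `backward` passes it to `combine_backward` instead of capturing the buffer, and `combine_postprocess` re-views the hidden states (apparently to get a dedicated autograd node to hook) and registers a hook on `grad_fn`. Below is a minimal, DeepEP-free sketch of the two mechanics involved: storing an object on `ctx` for use in backward, and hooking the output's `grad_fn`; a plain dict stands in for the `EventOverlap` handle.

```python
# Minimal sketch (not the real dispatcher): an autograd Function that carries an
# "event" object from forward to backward, plus a grad_fn hook on its output.
import torch


class Combine(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, backward_previous_event):
        ctx.backward_previous_event = backward_previous_event  # stash for backward
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        # in the real code this is where combine_backward would wait on the event
        print("backward sees:", ctx.backward_previous_event)
        return grad_out * 2, None


event_placeholder = {"name": "backward_previous_event"}
x = torch.randn(4, requires_grad=True)
y = Combine.apply(x, event_placeholder)

# mirror of combine_postprocess: attach a hook to the output's grad_fn
if y.grad_fn is not None:
    y.grad_fn.register_hook(lambda grad_inputs, grad_outputs: print("combine grad ready"))

y.sum().backward()
```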

xtuner/v1/module/router/noaux_router.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -59,9 +59,6 @@ def __init__(
         )
 
     def forward(self, logits) -> RouterResults:
-        if os.getenv("XTUNER_ROUTER_DEBUG") == "true":
-            noise = torch.randn_like(logits) * 50
-            logits = logits + noise
 
         if self.scoring_func == "sigmoid":
             scores = logits.sigmoid()
@@ -71,6 +68,10 @@ def forward(self, logits) -> RouterResults:
 
         scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0)
 
+        if os.getenv("XTUNER_ROUTER_DEBUG") == "true":
+            noise = torch.randn_like(scores) * 50
+            scores_for_choice = scores + noise
+
         # select top-k experts
         # (only applicable when ep_size >= 64. when ep_size=32 (4 nodes), there is no need to employ this strategy)
         _, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1)
```
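
The debug-only routing noise now perturbs the scores after the correction bias is computed, rather than the raw logits. Note that on the debug path the commit rebuilds `scores_for_choice` from `scores + noise`, so `e_score_correction_bias` is not included there. A self-contained sketch of the env-gated injection, with placeholder shapes:

```python
# Standalone illustration of the env-gated routing noise; tensor shapes and the
# zero bias are placeholders, not the module's real state.
import os
import torch

scores = torch.rand(4, 8)                      # per-token expert scores
bias = torch.zeros(8)                          # stands in for e_score_correction_bias
scores_for_choice = scores + bias.unsqueeze(0)

if os.getenv("XTUNER_ROUTER_DEBUG") == "true":
    noise = torch.randn_like(scores) * 50      # large noise to scramble expert choice
    scores_for_choice = scores + noise         # note: bias term is dropped on this path

_, topk_idx = torch.topk(scores_for_choice, k=2, dim=-1)
```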

xtuner/v1/ops/comm/deepep_op.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -20,7 +20,8 @@
 _low_latency_buffer: Optional[Buffer] = None
 # Set the number of SMs to use
 # NOTES: this is a static variable
-Buffer.set_num_sms(24)
+# Buffer.set_num_sms(24)
+Buffer.set_num_sms(20)
 
 
 # You may call this function at the framework initialization
```
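
`Buffer.set_num_sms` is a static DeepEP setting, so lowering it from 24 to 20 shrinks the SM budget reserved for communication kernels in every buffer created afterwards. If one wanted to tune this per run instead of hard-coding it, a possible approach is sketched below; `XTUNER_DEEPEP_NUM_SMS` is a hypothetical variable name, not something the repository defines.

```python
# Hedged sketch: make the SM budget overridable at launch time (requires DeepEP).
import os
from deep_ep import Buffer

num_sms = int(os.getenv("XTUNER_DEEPEP_NUM_SMS", "20"))  # hypothetical override
Buffer.set_num_sms(num_sms)  # static setting shared by all subsequently created buffers
```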

xtuner/v1/train/trainer.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -43,6 +43,7 @@
     record_git_info,
 )
 from xtuner.v1.utils.device import get_device, get_torch_device_module
+import gc
 
 from .toy_tokenizer import UTF8ByteTokenizer
 
@@ -142,7 +143,7 @@ class TrainerConfig(BaseModel):
     hf_interval: int | None = None
     hf_max_keep: int | None = None
     exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl"
-    profile_step: int | None = None
+    profile_step: list | int | None = None
     profile_time: bool = True
     profile_memory: bool = False
     intra_layer_micro_batch: int = 1
@@ -237,7 +238,7 @@ def __init__(
         hf_interval: int | None = None,
         hf_max_keep: int | None = None,
         exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl",
-        profile_step: int | None = None,
+        profile_step: list | None = None,
         profile_time: bool = True,
         profile_memory: bool = False,
         intra_layer_micro_batch: int = 1,
@@ -257,6 +258,8 @@
 
         self._micro_batch_size: int | None = None
 
+        if type(profile_step) is int:
+            profile_step = [profile_step]
         self._profile_step = profile_step
         self._profile_time = profile_time
         self._profile_memory = profile_memory
@@ -492,6 +495,9 @@ def fit(self):
 
         time_before_get_data = time.time()
 
+        if self.cur_step % 50 == 0:
+            gc.collect()
+
     @property
     def world_size(self) -> int:
         """Get the total number of processes in the distributed training group.
@@ -944,7 +950,7 @@ def _init_xtuner_meta(self, work_dir: Path, auto_resume: bool) -> XTunerMeta:
     @contextmanager
     def _maybe_profiling(self):
         """Check if profiling is enabled and perform profiling if necessary."""
-        if self._profile_step is not None and self._cur_step == self._profile_step:
+        if self._profile_step is not None and self._cur_step in self._profile_step:
             with contextlib.ExitStack() as stack:
                 if self._profile_time:
                     time_dir = self.exp_dir / self._PROFILE_TIME_PATH / f"step-{self._cur_step}"
@@ -1177,6 +1183,7 @@ def _resume_dataloader(self, dataloader_path: Path):
         self._dataloader.load_state_dict(dataloader_state)
 
     def _setup_env(self):
+        gc.disable()
         os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
         log_str = "\n============XTuner Training Environment============\n"
```
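
Taken together, the trainer changes are: `profile_step` may now be a list of steps (a bare int is normalized to a one-element list and the profiling check becomes a membership test), Python's cyclic GC is disabled during environment setup, and an explicit `gc.collect()` runs every 50 steps inside `fit()`. A condensed, standalone sketch of that behavior, using a stand-in `Trainer` class rather than the real one:

```python
# Stand-in Trainer showing the profile-step normalization and explicit GC policy.
import gc


class Trainer:
    def __init__(self, profile_step: list | int | None = None):
        if type(profile_step) is int:          # a bare int becomes a one-step list
            profile_step = [profile_step]
        self._profile_step = profile_step

    def _setup_env(self):
        gc.disable()                           # rely on the explicit collections below

    def should_profile(self, cur_step: int) -> bool:
        return self._profile_step is not None and cur_step in self._profile_step

    def fit_step(self, cur_step: int):
        if cur_step % 50 == 0:                 # periodic manual collection
            gc.collect()


t = Trainer(profile_step=3)
assert t.should_profile(3) and not t.should_profile(4)
```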
