@@ -60,6 +60,7 @@ class DeepEPPreCombineResult(PreCombineResult):

class DeepEPCombineResult(CombineResult):
    forward_finished_event: EventOverlap | None
+   backward_previous_event: EventOverlap | None


DeepEPPostCombineResult = PostCombineResult
@@ -144,6 +145,7 @@ def forward(
        handle: DeepEPHandle,
        group: dist.ProcessGroup,
        forward_previous_event: EventOverlap | None = None,
+       backward_previous_event: EventOverlap | None = None,
        backward_finished_event: EventOverlap | None = None,
    ) -> tuple[torch.Tensor, EventOverlap]:
        combined_x, event = combine_forward(x, num_experts, handle, group, forward_previous_event)
@@ -152,17 +154,18 @@ def forward(
        ctx.group = group
        ctx.num_experts = num_experts
        ctx.backward_finished_event = backward_finished_event
+       ctx.backward_previous_event = backward_previous_event
        return combined_x, event

    @staticmethod
    def backward(  # type: ignore[invalid-override]
        ctx, grad_combined_x: torch.Tensor, *args
-   ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], None, None, None, None, None]:
+   ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], None, None, None, None, None, None]:
        # load saved comm handle
        handle = ctx.saved_tensors
-       grad_x, event = combine_backward(grad_combined_x, ctx.num_experts, handle, ctx.group, buffer_capture())
+       grad_x, event = combine_backward(grad_combined_x, ctx.num_experts, handle, ctx.group, ctx.backward_previous_event)
        ctx.backward_finished_event.event = event.event
-       return grad_x, None, None, None, None, None
+       return grad_x, None, None, None, None, None, None


_async_combine = copy_method_signature(DeepEPCombine.forward)(DeepEPCombine.apply)
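Aside: the extra trailing `None` in the backward return mirrors the `backward_previous_event` argument added to `forward`; `torch.autograd.Function.backward` must return one gradient slot per `forward` input. A minimal sketch of that pattern, with illustrative names only (not the xtuner implementation):

import torch

class EventedIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, backward_previous_event=None):
        # Stash the event-like object on ctx for use during the backward pass.
        ctx.backward_previous_event = backward_previous_event
        return x

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.backward_previous_event is not None:
            # e.g. wait for a communication event recorded elsewhere before
            # launching the backward communication kernel
            ctx.backward_previous_event.current_stream_wait()
        # One slot per forward input (x, backward_previous_event), hence the extra None.
        return grad_out, None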
@@ -388,24 +391,30 @@ def combine(
        decoding: bool = False,
    ) -> CombineResult:
        if async_op:
+           backward_previous_event = EventOverlap(None)
            assert pre_combined["forward_finished_event"] is not None, "Please use `async_op=True` for combine!"
            pre_combined["forward_finished_event"].current_stream_wait()
+       else:
+           backward_previous_event = None

        combined_hidden_states, event = _async_combine(
            pre_combined["hidden_states"],
            self._n_routed_experts,
            dispatched["handle"],
            self._process_group,
            pre_combined["forward_finished_event"],
+           backward_previous_event,
            pre_combined["backward_previous_event"],
        )
        if not async_op:
            event.current_stream_wait()

+
        if not decoding:
            return DeepEPCombineResult(
                hidden_states=combined_hidden_states,
                forward_finished_event=event,
+               backward_previous_event=backward_previous_event,
            )
        else:
            raise NotImplementedError
@@ -424,6 +433,17 @@ def combine_postprocess(
        hidden_states = combined["hidden_states"]
        forward_previous_event = combined["forward_finished_event"]

+       hidden_states = hidden_states.view_as(hidden_states)
+
+       if hidden_states.grad_fn is not None:
+           hidden_states.grad_fn.register_hook(
+               get_backward_hook(
+                   backward_finished_event=combined["backward_previous_event"],
+                   name="DeeEPDispatcher.combine_postprocess",
+                   debug=XTUNER_DISPATCHER_DEBUG,
+               )
+           )
+
        if async_op:
            assert forward_previous_event is not None, "Please use `async_op=True` for combine!"
            forward_previous_event.current_stream_wait()
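Aside: `hidden_states.view_as(hidden_states)` appears to insert a dedicated autograd node at this point in the graph, so the registered hook fires exactly when the gradient first flows back into the combine output; presumably `get_backward_hook` then records an event into the shared `backward_previous_event` placeholder that `DeepEPCombine.backward` waits on. A minimal, self-contained sketch of the `grad_fn.register_hook` mechanism, using a hypothetical hook factory rather than the real xtuner helper:

import torch

def make_backward_hook(name, debug=False):
    # Hypothetical stand-in for the get_backward_hook referenced above. A Node
    # hook receives (grad_inputs, grad_outputs) and may return replacement
    # grad_inputs; returning None leaves the gradients untouched.
    def hook(grad_inputs, grad_outputs):
        if debug:
            print(f"[{name}] backward reached this node")
        return None
    return hook

x = torch.randn(4, 8, requires_grad=True)
y = x.view_as(x)                          # fresh ViewBackward node for this point
if y.grad_fn is not None:
    y.grad_fn.register_hook(make_backward_hook("combine_postprocess", debug=True))
y.sum().backward()                        # hook fires as the gradient passes through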