Refactor communication logic of DeepSeek for extensibility and understandability #6321
Conversation
This reverts commit 6c5c726.
context=self._compute_context(forward_batch),
)

def forward_layer_end(
Suggestion: rename this to forward_post_ffn.
wondering why that is better than layer_end
Indeed, it does the job of handling layer-end transformations (e.g., making the output shape correct for the last layer). For non-last layers it does not cooperate with pre_mlp, but it does cooperate with the next layer's pre_attn. So a verbose name (don't use it; it's only for illustration) would be forward_layer_end_transformation_or_cooperate_with_next_pre_attn.
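To make the naming discussion concrete, here is a minimal, self-contained sketch (hypothetical class and field names, not the actual sglang implementation) of the asymmetry described above: for the last layer, forward_layer_end performs the output-shape transformation; for intermediate layers it simply leaves the tensors in the layout that the next layer's pre_attn step consumes.

from dataclasses import dataclass

@dataclass
class _SketchStates:
    hidden_states: list
    residual: list

class _SketchLayerCommunicator:
    # Hypothetical stand-in used only to illustrate the naming discussion.
    def __init__(self, is_last_layer: bool):
        self.is_last_layer = is_last_layer

    def forward_layer_end(self, states: _SketchStates) -> _SketchStates:
        if self.is_last_layer:
            # Layer-end transformation: produce the final output shape
            # (merging the residual here is just a toy stand-in).
            merged = [h + r for h, r in zip(states.hidden_states, states.residual)]
            return _SketchStates(hidden_states=merged, residual=[0.0] * len(merged))
        # Intermediate layers: nothing to transform here; the next layer's
        # pre_attn is the step that consumes and redistributes these tensors.
        return states

if __name__ == "__main__":
    s = _SketchStates(hidden_states=[1.0, 2.0], residual=[0.5, 0.5])
    print(_SketchLayerCommunicator(is_last_layer=False).forward_layer_end(s))
    print(_SketchLayerCommunicator(is_last_layer=True).forward_layer_end(s))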
- self._enable_moe_dense_fully_dp()
- and (not self.info.is_sparse)
+ enable_moe_dense_fully_dp()
+ and (not self.is_layer_sparse)
The check enable_moe_dense_fully_dp() and (not self.is_layer_sparse) should be handled by the communicator. We can simplify this statement to:
if hidden_states.shape[0] > 0 or self.layer_communicator.require_ffn_sync():
I am a bit confused about this, why is it related to communicators?
I originally wanted to move it inside DeepseekMLP with logic like "if tp_size == 1 and the hidden states are empty, then skip the computation". -> It has now been moved here to make the code clearer (though it does not fit the PR title, and I originally wanted to handle it separately).
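For illustration only, here is a hedged sketch of the reviewer's suggestion above. require_ffn_sync is the name proposed in this thread, not an existing sglang API, and the helper's internals are assumptions:

import torch

class _SketchCommunicator:
    # Hypothetical helper: hides the "dense layer running fully data-parallel"
    # detail behind a single question the model code can ask.
    def __init__(self, moe_dense_fully_dp: bool, is_layer_sparse: bool):
        self._moe_dense_fully_dp = moe_dense_fully_dp
        self._is_layer_sparse = is_layer_sparse

    def require_ffn_sync(self) -> bool:
        # A dense (non-sparse) layer in fully-DP mode can skip the FFN when
        # its local batch is empty; every other configuration must run it.
        return not (self._moe_dense_fully_dp and not self._is_layer_sparse)

def forward_ffn_maybe_skip(hidden_states: torch.Tensor, comm: _SketchCommunicator) -> torch.Tensor:
    if hidden_states.shape[0] > 0 or comm.require_ffn_sync():
        return hidden_states * 2  # stand-in for the real MLP forward
    return hidden_states  # empty local batch and no sync required: skip

if __name__ == "__main__":
    comm = _SketchCommunicator(moe_dense_fully_dp=True, is_layer_sparse=False)
    print(forward_ffn_maybe_skip(torch.empty(0, 8), comm).shape)  # skipped
    print(forward_ffn_maybe_skip(torch.ones(2, 8), comm).shape)   # computed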
self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

self.layer_scatter_modes = LayerScatterModes.init_new(
It can be handled internally in LayerCommunicator. With this change, LayerScatterModes becomes a private class of communicator.py.
I keep it exposed to allow LayerScatterModes to be used in the future by modules other than LayerCommunicator, since LayerScatterModes tells us facts about what the layer input / attn / mlp / layer output shapes look like.
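To illustrate the trade-off (all names below are illustrative, not the actual sglang classes): keeping a public description of the per-stage scatter modes lets other modules consult those shape facts without going through LayerCommunicator.

from dataclasses import dataclass
from enum import Enum, auto

class ScatterMode(Enum):
    SCATTERED = auto()
    FULL = auto()

@dataclass(frozen=True)
class SketchLayerScatterModes:
    # Facts about the tensor layout at each stage of the layer.
    layer_input: ScatterMode
    attn_output: ScatterMode
    mlp_output: ScatterMode
    layer_output: ScatterMode

class SketchLayerCommunicator:
    # One consumer: decides which collectives to issue based on the modes.
    def __init__(self, modes: SketchLayerScatterModes):
        self.modes = modes

class SketchShapeLogger:
    # Another (hypothetical) consumer that only needs the shape facts,
    # which is only possible if the modes class stays public.
    def describe(self, modes: SketchLayerScatterModes) -> str:
        return f"layer input={modes.layer_input.name}, layer output={modes.layer_output.name}"

if __name__ == "__main__":
    modes = SketchLayerScatterModes(
        layer_input=ScatterMode.SCATTERED,
        attn_output=ScatterMode.FULL,
        mlp_output=ScatterMode.FULL,
        layer_output=ScatterMode.SCATTERED,
    )
    SketchLayerCommunicator(modes)
    print(SketchShapeLogger().describe(modes))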
# Conflicts: # python/sglang/srt/models/deepseek_v2.py
Motivation
Make the code clean, less error-prone, and extensible.
Modifications
Checklist