-
Notifications
You must be signed in to change notification settings - Fork 495
feat: Support Linear Cross Entropy fuse kernel #1322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c2852f3
fc5211b
e494738
9495141
77ea280
47fd3e2
203a7e4
e2e5d70
fecfcbd
0b5eefe
e69674e
d444cbe
07f03d8
003296b
0020faa
fa82bb9
4e9f8c7
64b91a4
92c4298
c640e03
8b1a243
a6952f3
6346396
0b044e5
5e35bbe
09eff54
825d44c
8aaff93
292dab1
8addb48
c5525d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,11 @@ | |
| from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager | ||
| from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms | ||
| from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper | ||
| from areal.engine.megatron_utils.fused_lce_capture import ( | ||
| FUSED_LCE_HIDDEN_KEY, | ||
| FUSED_LCE_WEIGHT_KEY, | ||
| capture_lm_head_hidden, | ||
| ) | ||
| from areal.engine.megatron_utils.megatron import ( | ||
| all_gather_param, | ||
| convert_to_hf, | ||
|
|
@@ -117,7 +122,12 @@ | |
| split_padded_tensor_dict_into_mb_list, | ||
| unpad_logits, | ||
| ) | ||
| from areal.utils.functional import gather_logprobs, gather_logprobs_entropy | ||
| from areal.utils.functional import ( | ||
| gather_logprobs, | ||
| gather_logprobs_entropy, | ||
| linear_cross_entropy_logprobs, | ||
| linear_cross_entropy_logprobs_entropy, | ||
| ) | ||
| from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer | ||
| from areal.utils.lock import DistributedLock | ||
| from areal.utils.network import find_free_ports, format_host_for_url, gethostip | ||
|
|
@@ -805,6 +815,12 @@ def forward_backward_batch( | |
| ) -> None: | ||
| self._ensure_ready() | ||
|
|
||
| use_fused_lce = ( | ||
| getattr(self.config, "use_fused_linear_ce", False) | ||
| and not self.config.is_critic | ||
| and not self.enable_tree_training | ||
| ) | ||
|
|
||
| def forward_step(batch_iter, model): | ||
| mb_input: MicroBatchItem = next(batch_iter) | ||
|
|
||
|
|
@@ -835,13 +851,32 @@ def forward_step(batch_iter, model): | |
| cp_size = mpu.get_context_parallel_world_size() | ||
| cp_local = cp_size > 1 | ||
|
|
||
| output = packed_context_parallel_forward( | ||
| model, | ||
| mb_input.padded_mb, | ||
| gather_cp_output=not cp_local, | ||
| is_vision_model=self.is_vision_model, | ||
| model_vp_stage_for_capture = getattr(model, "vp_stage", 0) | ||
| should_capture = ( | ||
| use_fused_lce | ||
| and mpu.is_pipeline_last_stage( | ||
| ignore_virtual=False, vp_stage=model_vp_stage_for_capture | ||
| ) | ||
| and not cp_local | ||
| ) | ||
|
|
||
| with capture_lm_head_hidden(model, enabled=should_capture) as capture: | ||
| output = packed_context_parallel_forward( | ||
| model, | ||
| mb_input.padded_mb, | ||
| gather_cp_output=not cp_local, | ||
| is_vision_model=self.is_vision_model, | ||
| ) | ||
|
|
||
| if ( | ||
| capture is not None | ||
| and capture.hidden is not None | ||
| and capture.weight is not None | ||
| ): | ||
| mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = capture.hidden | ||
| mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight | ||
| mb_input.orig_mb["_fused_lce_active"] = True | ||
|
|
||
| # Release tree attention metadata after forward pass | ||
| for key in tree_attn_keys: | ||
| del mb_input.padded_mb[key] | ||
|
|
@@ -877,6 +912,15 @@ def _process_output(input_, output_): | |
| cu_seqlens=cu_seqlens, | ||
| old_cu_seqlens=mb_input.old_cu_seqlens, | ||
| ) | ||
| # Re-align Float16Module's fp32 hidden to lm-head weight dtype. | ||
| if mb_input.orig_mb.get("_fused_lce_active", False): | ||
| fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY) | ||
| if ( | ||
| fused_weight is not None | ||
| and output.dtype != fused_weight.dtype | ||
| ): | ||
| output = output.to(fused_weight.dtype) | ||
| mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output | ||
|
Comment on lines
+917
to
+923
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we usually require fp32 logits, will this downcast operation cause a precision issue? |
||
| return output, functools.partial(_process_output, mb_input.orig_mb) | ||
|
|
||
| forward_backward_func = get_forward_backward_func() | ||
|
|
@@ -2069,58 +2113,84 @@ def _compute_logprobs_and_loss( | |
| labels = cp_local_labels | ||
| else: | ||
| labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) | ||
| logprobs, entropy = gather_logprobs_entropy( | ||
| output, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| vocab_min_logits = output.detach().min(-1).values.float() | ||
| vocab_max_logits = output.detach().max(-1).values.float() | ||
| if cp_padded_cu_seqlens is not None: | ||
| logprobs = reassemble_cp_packed_logprobs( | ||
| logprobs, cp_padded_cu_seqlens | ||
| ) | ||
| entropy = reassemble_cp_packed_logprobs( | ||
| entropy, cp_padded_cu_seqlens | ||
| ) | ||
| vocab_min_logits = reassemble_cp_packed_logprobs( | ||
| vocab_min_logits, cp_padded_cu_seqlens | ||
| ) | ||
| vocab_max_logits = reassemble_cp_packed_logprobs( | ||
| vocab_max_logits, cp_padded_cu_seqlens | ||
| ) | ||
| cp_padding_length = inputs.get("_cp_padding_length", 0) | ||
| cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") | ||
| logprobs = unpad_logits( | ||
| logprobs, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| ) | ||
| entropy = unpad_logits( | ||
| entropy, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| ) | ||
| vocab_min_logits = unpad_logits( | ||
| vocab_min_logits, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| fused_active = inputs.get("_fused_lce_active", False) | ||
| fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) | ||
| fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) | ||
| if ( | ||
| fused_active | ||
| and fused_hidden is not None | ||
| and fused_weight is not None | ||
| ): | ||
|
Comment on lines
+2119
to
+2123
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The if-else indentation is too deep here. Try extracting the whole if-else code into a separate granular method and use early return to eliminate the else. |
||
| logprobs, entropy, vocab_max_logits = ( | ||
| linear_cross_entropy_logprobs_entropy( | ||
| fused_hidden, | ||
| fused_weight, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| return_max_logits=True, | ||
| ) | ||
| ) | ||
| vocab_max_logits = unpad_logits( | ||
| vocab_max_logits, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| # Fused kernel does not track per-token vocab min logits; | ||
| # skip the min telemetry rather than report a misleading | ||
| # proxy. Consumers must guard ``vocab_min_logits`` and | ||
| # ``vocab_max_logits`` independently. | ||
| vocab_min_logits = None | ||
| else: | ||
| logprobs, entropy = gather_logprobs_entropy( | ||
| output, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| inputs = { | ||
| k: v for k, v in inputs.items() if not k.startswith("_cp_") | ||
| } | ||
| vocab_min_logits = output.detach().min(-1).values.float() | ||
| vocab_max_logits = output.detach().max(-1).values.float() | ||
| if cp_padded_cu_seqlens is not None: | ||
| logprobs = reassemble_cp_packed_logprobs( | ||
| logprobs, cp_padded_cu_seqlens | ||
| ) | ||
| entropy = reassemble_cp_packed_logprobs( | ||
| entropy, cp_padded_cu_seqlens | ||
| ) | ||
| vocab_min_logits = reassemble_cp_packed_logprobs( | ||
| vocab_min_logits, cp_padded_cu_seqlens | ||
| ) | ||
| vocab_max_logits = reassemble_cp_packed_logprobs( | ||
| vocab_max_logits, cp_padded_cu_seqlens | ||
| ) | ||
| cp_padding_length = inputs.get("_cp_padding_length", 0) | ||
| cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") | ||
| logprobs = unpad_logits( | ||
| logprobs, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| ) | ||
| entropy = unpad_logits( | ||
| entropy, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| ) | ||
| vocab_min_logits = unpad_logits( | ||
| vocab_min_logits, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| ) | ||
| vocab_max_logits = unpad_logits( | ||
| vocab_max_logits, | ||
| cp_padding_length, | ||
| cp_padded_cu_seqlens, | ||
| cp_old_cu_seqlens, | ||
| ) | ||
| inputs = { | ||
| k: v for k, v in inputs.items() if not k.startswith("_cp_") | ||
| } | ||
| loss = loss_fn( | ||
| logprobs, | ||
| entropy, | ||
|
|
@@ -2157,6 +2227,20 @@ def _compute_forward_result( | |
| ) | ||
| return logprobs | ||
| labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) | ||
| fused_active = inputs.get("_fused_lce_active", False) | ||
| fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) | ||
| fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) | ||
| if fused_active and fused_hidden is not None and fused_weight is not None: | ||
| logprobs = linear_cross_entropy_logprobs( | ||
| fused_hidden, | ||
| fused_weight, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| return logprobs | ||
| logprobs = gather_logprobs( | ||
| output, | ||
| labels, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move this field to MegatronEngineConfig for now.
We should also note that "vocab_min_logits" stats won't be available when enabling this feature. Could you please open an issue about this and call for community's fix?