8 changes: 8 additions & 0 deletions areal/api/cli_args.py
@@ -1143,6 +1143,14 @@ class TrainEngineConfig:
default=False,
metadata={"help": "Enable tree training with flex attention module."},
)
use_fused_linear_ce: bool = field(
default=False,
Comment on lines +1146 to +1147 (Collaborator):
Let's move this field to MegatronEngineConfig for now.

We should also note that the "vocab_min_logits" stats won't be available when this feature is enabled. Could you please open an issue about this and call for a community fix?
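A rough sketch of the suggested relocation (the field body is copied from this diff; the existing MegatronEngineConfig fields are elided and the exact placement is illustrative only):

```python
from dataclasses import dataclass, field

# Sketch of the suggested move only; existing Megatron-specific fields elided.
@dataclass
class MegatronEngineConfig:
    ...  # existing Megatron-specific fields

    use_fused_linear_ce: bool = field(
        default=False,
        metadata={
            "help": "Fuse the linear projection with cross-entropy so that the "
            "[num_tokens, vocab_size] logits tensor is never materialised."
        },
    )
```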

metadata={
"help": "Fuse the linear projection with cross-entropy so that the "
"[num_tokens, vocab_size] logits tensor is never materialised. "
"Only effective for the Megatron actor backend with parallel_output=True."
},
)

# Scheduling
scheduling_spec: tuple[SchedulingSpec, ...] = field(
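For intuition about the new flag: the help text above says the [num_tokens, vocab_size] logits tensor is never materialised. The plain-PyTorch sketch below illustrates that idea only, computing per-token label log-probs chunk by chunk from the hidden states and the lm-head weight; it is not the AReaL/Megatron kernel, the helper name is hypothetical, and tensor parallelism and temperature scaling are ignored.

```python
import torch

def chunked_linear_ce_logprobs(hidden, lm_head_weight, labels, chunk_size=1024):
    """Per-token log-probs of `labels` without holding all logits at once."""
    num_tokens = hidden.shape[0]
    out = torch.empty(num_tokens, dtype=torch.float32, device=hidden.device)
    for start in range(0, num_tokens, chunk_size):
        h = hidden[start : start + chunk_size]           # [c, hidden_dim]
        logits = (h @ lm_head_weight.T).float()          # [c, vocab] exists only for this chunk
        log_z = torch.logsumexp(logits, dim=-1)          # per-token log partition
        picked = logits.gather(
            -1, labels[start : start + chunk_size].unsqueeze(-1)
        ).squeeze(-1)
        out[start : start + chunk_size] = picked - log_z
    return out

# Toy usage: hidden [num_tokens, d], weight [vocab, d], labels [num_tokens].
hidden = torch.randn(4096, 512)
weight = torch.randn(32000, 512)
labels = torch.randint(0, 32000, (4096,))
logprobs = chunked_linear_ce_logprobs(hidden, weight, labels)
```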
196 changes: 140 additions & 56 deletions areal/engine/megatron_engine.py
@@ -64,6 +64,11 @@
from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms
from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper
from areal.engine.megatron_utils.fused_lce_capture import (
FUSED_LCE_HIDDEN_KEY,
FUSED_LCE_WEIGHT_KEY,
capture_lm_head_hidden,
)
from areal.engine.megatron_utils.megatron import (
all_gather_param,
convert_to_hf,
@@ -117,7 +122,12 @@
split_padded_tensor_dict_into_mb_list,
unpad_logits,
)
from areal.utils.functional import gather_logprobs, gather_logprobs_entropy
from areal.utils.functional import (
gather_logprobs,
gather_logprobs_entropy,
linear_cross_entropy_logprobs,
linear_cross_entropy_logprobs_entropy,
)
from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer
from areal.utils.lock import DistributedLock
from areal.utils.network import find_free_ports, format_host_for_url, gethostip
@@ -805,6 +815,12 @@ def forward_backward_batch(
) -> None:
self._ensure_ready()

use_fused_lce = (
getattr(self.config, "use_fused_linear_ce", False)
and not self.config.is_critic
and not self.enable_tree_training
)

def forward_step(batch_iter, model):
mb_input: MicroBatchItem = next(batch_iter)

@@ -835,13 +851,32 @@ def forward_step(batch_iter, model):
cp_size = mpu.get_context_parallel_world_size()
cp_local = cp_size > 1

output = packed_context_parallel_forward(
model,
mb_input.padded_mb,
gather_cp_output=not cp_local,
is_vision_model=self.is_vision_model,
)
model_vp_stage_for_capture = getattr(model, "vp_stage", 0)
should_capture = (
use_fused_lce
and mpu.is_pipeline_last_stage(
ignore_virtual=False, vp_stage=model_vp_stage_for_capture
)
and not cp_local
)

with capture_lm_head_hidden(model, enabled=should_capture) as capture:
output = packed_context_parallel_forward(
model,
mb_input.padded_mb,
gather_cp_output=not cp_local,
is_vision_model=self.is_vision_model,
)

if (
capture is not None
and capture.hidden is not None
and capture.weight is not None
):
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = capture.hidden
mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight
mb_input.orig_mb["_fused_lce_active"] = True

# Release tree attention metadata after forward pass
for key in tree_attn_keys:
del mb_input.padded_mb[key]
@@ -877,6 +912,15 @@ def _process_output(input_, output_):
cu_seqlens=cu_seqlens,
old_cu_seqlens=mb_input.old_cu_seqlens,
)
# Re-align Float16Module's fp32 hidden to lm-head weight dtype.
if mb_input.orig_mb.get("_fused_lce_active", False):
fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY)
if (
fused_weight is not None
and output.dtype != fused_weight.dtype
):
output = output.to(fused_weight.dtype)
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output
Comment on lines +917 to +923 (Collaborator):
Since we usually require fp32 logits, will this downcast operation cause a precision issue?
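One way to get a rough offline feel for that concern (synthetic shapes, plain PyTorch, hypothetical names; not the AReaL code path): compare label log-probs computed with a bf16 hidden/weight matmul against an fp32 reference.

```python
import torch

# Offline illustration: how much do label log-probs move when the
# hidden/weight matmul runs in bf16 instead of fp32?
torch.manual_seed(0)
hidden = torch.randn(128, 1024)
weight = torch.randn(32000, 1024)
labels = torch.randint(0, 32000, (128,))

def label_logprobs(h, w, dtype):
    # Matmul in the given dtype, reductions in fp32.
    logits = (h.to(dtype) @ w.to(dtype).T).float()
    return logits.gather(-1, labels.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, -1)

ref = label_logprobs(hidden, weight, torch.float32)
low = label_logprobs(hidden, weight, torch.bfloat16)
print("max abs diff:", (ref - low).abs().max().item())
```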

return output, functools.partial(_process_output, mb_input.orig_mb)

forward_backward_func = get_forward_backward_func()
@@ -2069,58 +2113,84 @@ def _compute_logprobs_and_loss(
labels = cp_local_labels
else:
labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1)
logprobs, entropy = gather_logprobs_entropy(
output,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
vocab_min_logits = output.detach().min(-1).values.float()
vocab_max_logits = output.detach().max(-1).values.float()
if cp_padded_cu_seqlens is not None:
logprobs = reassemble_cp_packed_logprobs(
logprobs, cp_padded_cu_seqlens
)
entropy = reassemble_cp_packed_logprobs(
entropy, cp_padded_cu_seqlens
)
vocab_min_logits = reassemble_cp_packed_logprobs(
vocab_min_logits, cp_padded_cu_seqlens
)
vocab_max_logits = reassemble_cp_packed_logprobs(
vocab_max_logits, cp_padded_cu_seqlens
)
cp_padding_length = inputs.get("_cp_padding_length", 0)
cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens")
logprobs = unpad_logits(
logprobs,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
entropy = unpad_logits(
entropy,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
vocab_min_logits = unpad_logits(
vocab_min_logits,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
vocab_max_logits = unpad_logits(
vocab_max_logits,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
inputs = {
k: v for k, v in inputs.items() if not k.startswith("_cp_")
}
fused_active = inputs.get("_fused_lce_active", False)
fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY)
fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY)
if (
fused_active
and fused_hidden is not None
and fused_weight is not None
):
Comment on lines +2119 to +2123 (Collaborator):
The if-else indentation is too deep here. Try extracting the whole if-else block into a separate, more granular method and use an early return to eliminate the else.
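A rough sketch of that extraction (hypothetical method name; it reuses the symbols already imported in this file and omits the CP reassembly/unpadding for brevity):

```python
def _gather_logprobs_entropy_stats(self, inputs, output, labels):
    """Return (logprobs, entropy, vocab_min_logits, vocab_max_logits)."""
    tp_group = (
        mpu.get_tensor_model_parallel_group()
        if mpu.get_tensor_model_parallel_world_size() > 1
        else None
    )
    fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY)
    fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY)
    if (
        inputs.get("_fused_lce_active", False)
        and fused_hidden is not None
        and fused_weight is not None
    ):
        logprobs, entropy, vocab_max_logits = linear_cross_entropy_logprobs_entropy(
            fused_hidden,
            fused_weight,
            labels,
            temperature=self.config.temperature,
            tp_group=tp_group,
            return_max_logits=True,
        )
        # Fused path: per-token vocab min logits are not tracked.
        return logprobs, entropy, None, vocab_max_logits

    logprobs, entropy = gather_logprobs_entropy(
        output,
        labels,
        temperature=self.config.temperature,
        tp_group=tp_group,
    )
    vocab_min_logits = output.detach().min(-1).values.float()
    vocab_max_logits = output.detach().max(-1).values.float()
    return logprobs, entropy, vocab_min_logits, vocab_max_logits
```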

logprobs, entropy, vocab_max_logits = (
linear_cross_entropy_logprobs_entropy(
fused_hidden,
fused_weight,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
return_max_logits=True,
)
)
)
# Fused kernel does not track per-token vocab min logits;
# skip the min telemetry rather than report a misleading
# proxy. Consumers must guard ``vocab_min_logits`` and
# ``vocab_max_logits`` independently.
vocab_min_logits = None
else:
logprobs, entropy = gather_logprobs_entropy(
output,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
vocab_min_logits = output.detach().min(-1).values.float()
vocab_max_logits = output.detach().max(-1).values.float()
if cp_padded_cu_seqlens is not None:
logprobs = reassemble_cp_packed_logprobs(
logprobs, cp_padded_cu_seqlens
)
entropy = reassemble_cp_packed_logprobs(
entropy, cp_padded_cu_seqlens
)
vocab_min_logits = reassemble_cp_packed_logprobs(
vocab_min_logits, cp_padded_cu_seqlens
)
vocab_max_logits = reassemble_cp_packed_logprobs(
vocab_max_logits, cp_padded_cu_seqlens
)
cp_padding_length = inputs.get("_cp_padding_length", 0)
cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens")
logprobs = unpad_logits(
logprobs,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
entropy = unpad_logits(
entropy,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
vocab_min_logits = unpad_logits(
vocab_min_logits,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
vocab_max_logits = unpad_logits(
vocab_max_logits,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
inputs = {
k: v for k, v in inputs.items() if not k.startswith("_cp_")
}
loss = loss_fn(
logprobs,
entropy,
@@ -2157,6 +2227,20 @@ def _compute_forward_result(
)
return logprobs
labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1)
fused_active = inputs.get("_fused_lce_active", False)
fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY)
fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY)
if fused_active and fused_hidden is not None and fused_weight is not None:
logprobs = linear_cross_entropy_logprobs(
fused_hidden,
fused_weight,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
return logprobs
logprobs = gather_logprobs(
output,
labels,