Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions miles/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from miles.utils.structured_log import log_structured
from miles.utils.test_utils.ft_test_actions import FTTestActionActorExecutor
from miles.utils.witness.allocator import WitnessInfo
from miles.utils.witness.module import witness_dump_and_clear_stale

from ...utils.misc import filter_keys
from ..training_utils.ci_utils import check_grad_norm, check_kl
from ..training_utils.data import DataIterator, get_batch
from ..training_utils.log_utils import aggregate_forward_results, aggregate_train_losses, log_train_step
Expand Down Expand Up @@ -275,6 +277,7 @@ def forward_step(
"total_lengths",
"response_lengths",
"max_seq_lens",
"witness_ids",
],
args.data_pad_size_multiplier,
args.qkv_format,
Expand All @@ -293,6 +296,7 @@ def forward_step(
labels=None,
packed_seq_params=packed_seq_params,
loss_mask=batch["full_loss_masks"],
**(filter_keys(batch, ["witness_ids"]) if args.enable_witness else {}),
**(batch["multimodal_train_inputs"] if batch["multimodal_train_inputs"] is not None else {}),
)

Expand Down Expand Up @@ -431,6 +435,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"returns",
"rollout_log_probs",
"max_seq_lens",
"witness_ids",
"opd_reverse_kl",
],
args.data_pad_size_multiplier,
Expand All @@ -446,6 +451,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p

if return_schedule_plan:
assert not args.enable_mtp_training, "MTP training should not be enabled when using combined 1f1b"
assert not args.enable_witness, "Witness is not supported with combined 1f1b (build_schedule_plan)"
output_tensor = model.build_schedule_plan(
input_ids=batch["tokens"],
position_ids=None,
Expand All @@ -462,6 +468,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"labels": None,
"packed_seq_params": get_packed_seq_params(batch, args),
"loss_mask": batch["full_loss_masks"],
**(filter_keys(batch, ["witness_ids"]) if args.enable_witness else {}),
}

if args.enable_mtp_training:
Expand Down Expand Up @@ -558,6 +565,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p

if outcome == TrainStepOutcome.NORMAL:
dump_local_weight_checksums(args=args, model=model, optimizer=optimizer)
if args.enable_witness:
witness_dump_and_clear_stale(model=model, witness_info=witness_info, optimizer=optimizer)

if mpu.is_pipeline_last_stage(ignore_virtual=True):
loss_reduced = (
Expand Down
Loading