Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,13 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

# Force the transpose cache to be kept whenever the recipe is MXFP8 / MXFP4,
# regardless of whether we are currently inside an fp8_autocast region or not.
# reset_parameters() would disable columnwise_usage for params constructed inside
# `fp8_model_init` / `quantized_model_init`, leaving `_columnwise_data=None`).
if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4():
self.keep_fp8_weight_transpose_cache = True

if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
Expand All @@ -1092,8 +1099,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_initialized = True

self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4():
self.keep_fp8_weight_transpose_cache = True

_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
Expand Down
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv]
split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE]
# Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4
padding_multiples = [128, 4]
# NOTE: ROCm/HIP backend uses an unpadded scale-inv layout (see `MXFP8Quantizer.make_empty`),
# so applying the padding here would produce a per-shard scale-inv whose dim-0
# does not match the destination scale-inv allocated for the FSDP2 local shard.
padding_multiples = [128, 4] if not IS_HIP_EXTENSION else [1, 1]
Comment thread
alextmagro marked this conversation as resolved.
for scale_inv, scale_split_size, pad_multiple in zip(
scale_invs, split_sizes_for_scale, padding_multiples
):
Expand Down
Loading