diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py
index 055210f93a..cc1b5a865c 100644
--- a/tests/pytorch/debug/test_log.py
+++ b/tests/pytorch/debug/test_log.py
@@ -23,6 +23,8 @@
 from transformer_engine.debug.features.utils.stats_computation import (
     compute_max_blockwise_dynamic_range,
     BlockwiseDynamicRangeStat,
+    STATS,
+    stats_to_num,
 )
 import math
 
@@ -61,6 +63,7 @@
     "underflows%",
     "scale_inv_min",
     "scale_inv_max",
+    "scale_inv_std",
     "mse",
 ]
 
@@ -248,6 +251,10 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
         debug_api.step()
 
         dequantized_tensor = quantized_tensor.dequantize()
+        use_single = hasattr(quantized_tensor, "_scale_inv")
+        scale_inv_name = "_scale_inv" if use_single else "_rowwise_scale_inv"
+        scale_inv_rowwise = getattr(quantized_tensor, scale_inv_name).float()
+        scale_inv_rowwise = scale_inv_rowwise[scale_inv_rowwise != 0]  # logger masks zero padding
         output = read_log(log_dir)
 
         for line in output.splitlines():
@@ -267,6 +274,17 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
                     (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
                 )
                 assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
+            # Rowwise scale_inv stats only; logger formats with {:.4f} so abs<1e-4.
+            if "scale_inv_min" in line and "_columnwise" not in line:
+                value = float(line.split("value=")[1])
+                assert value == pytest.approx(scale_inv_rowwise.min().cpu().item(), abs=1e-4)
+            if "scale_inv_max" in line and "_columnwise" not in line:
+                value = float(line.split("value=")[1])
+                assert value == pytest.approx(scale_inv_rowwise.max().cpu().item(), abs=1e-4)
+            if "scale_inv_std" in line and "_columnwise" not in line:
+                value = float(line.split("value=")[1])
+                expected = torch.std(scale_inv_rowwise, unbiased=False).cpu().item()
+                assert value == pytest.approx(expected, abs=1e-4)
 
 
 LOG_HIGH_PRECISION_CONFIG = """
@@ -370,6 +388,111 @@ def test_log_stats_numerics(feature_dirs, tensor_name):
     assert found_dynamic_range, "dynamic_range not found in output"
 
 
+def test_stats_computation_microbatch_reduction():
+    """Reducing per-microbatch stats must equal computing them over the concatenation.
+
+    Sweeps every aux-free (high precision) stat in the registry, using its own
+    compute/combine fns, with microbatches of different means and unequal sizes to
+    stress the parallel variance combine that uniform ~N(0,1) tests miss.
+    """
+    torch.manual_seed(0)
+    microbatches = [
+        0.5 * torch.randn(4, 4).cuda() + 0.0,
+        0.5 * torch.randn(8, 4).cuda() + 8.0,
+        0.5 * torch.randn(6, 4).cuda() - 6.0,
+    ]
+    full = torch.cat([mb.flatten() for mb in microbatches])
+
+    # Aux-free stats compute from the tensor alone; the rest need aux_dict
+    # (quantized tensors) and raise on {}, so they are out of scope here.
+    def is_aux_free(stat):
+        if isinstance(stat, BlockwiseDynamicRangeStat):
+            return False  # parametrized stat, covered by its own direct test
+        try:
+            STATS[stat][0](full, {})
+            return True
+        except Exception:  # pylint: disable=broad-except
+            return False
+
+    aux_free = [stat for stat in stats_to_num if is_aux_free(stat)]
+
+    # Fill a [num_microbatches, num_stats] buffer exactly like _Buffer would, then
+    # check every combinator against its own compute fn over the concatenation.
+    buffers = torch.zeros(len(microbatches), len(stats_to_num)).cuda()
+    for i, mb in enumerate(microbatches):
+        for stat in aux_free:
+            buffers[i, stats_to_num[stat]] = float(STATS[stat][0](mb, {}))
+
+    for stat in aux_free:
+        reduced = float(STATS[stat][1](buffers))
+        expected = float(STATS[stat][0](full, {}))
+        assert reduced == pytest.approx(
+            expected, rel=1e-4, abs=1e-4
+        ), f"{stat}: reduced {reduced}, expected {expected}"
+
+
+@pytest.mark.parametrize(
+    "fp8_recipe, recipe_name",
+    [
+        pytest.param(recipe.MXFP8BlockScaling(), "mxfp8", id="mxfp8"),
+        pytest.param(recipe.Float8BlockScaling(), "fp8_block_scaling", id="fp8_block_scaling"),
+    ],
+)
+def test_scale_inv_std_microbatch_reduction(fp8_recipe, recipe_name):
+    """scale_inv_std reduced across microbatches must equal std over the concatenation.
+
+    Complements test_stats_computation_microbatch_reduction, which only sweeps
+    aux-free stats and therefore skips scale_inv_std. Here we drive the same
+    parallel-variance combine through the aux-dependent path: each microbatch is
+    quantized with a block recipe (per-block scale_inv -> non-trivial variance),
+    fed through the registry's own compute/combine fns, and checked against
+    torch.std over the concatenated scale_inv values (unbiased=False).
+    """
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe_name == "mxfp8" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe_name == "fp8_block_scaling" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+
+    torch.manual_seed(0)
+    # Different means and unequal sizes stress the (mean_i - mean)^2 term.
+    microbatches = [
+        0.5 * torch.randn(256, 1024).cuda() + 0.0,
+        2.0 * torch.randn(512, 1024).cuda() + 8.0,
+        0.1 * torch.randn(384, 1024).cuda() - 6.0,
+    ]
+
+    def rowwise_scale_inv(quantized_tensor):
+        name = "_scale_inv" if hasattr(quantized_tensor, "_scale_inv") else "_rowwise_scale_inv"
+        scale_inv = getattr(quantized_tensor, name).float()
+        return scale_inv[scale_inv != 0]  # zeros are padding; the logger masks them too
+
+    stat_std = f"{recipe_name}_scale_inv_std"
+    stat_var = f"{recipe_name}_scale_inv_variance"
+    stat_numel = f"{recipe_name}_scale_inv_numel"
+    stat_sum = f"{recipe_name}_scale_inv_sum"
+
+    # Fill a [num_microbatches, num_stats] buffer exactly like _Buffer would,
+    # using the registry's own per-tensor compute fns through the aux path.
+    buffers = torch.zeros(len(microbatches), len(stats_to_num)).cuda()
+    scale_invs = []
+    for i, mb in enumerate(microbatches):
+        recipe_state = RecipeState.create(fp8_recipe, mode="forward", num_quantizers=1)
+        quantizer = recipe_state.make_quantizers()[0]
+        quantized_tensor = quantizer(mb)
+        aux_dict = {recipe_name: quantized_tensor}
+        scale_invs.append(rowwise_scale_inv(quantized_tensor).flatten())
+        for stat in (stat_var, stat_numel, stat_sum):
+            buffers[i, stats_to_num[stat]] = float(STATS[stat][0](mb, aux_dict))
+
+    reduced = float(STATS[stat_std][1](buffers))
+    expected = float(torch.std(torch.cat(scale_invs), unbiased=False))
+    assert reduced == pytest.approx(
+        expected, rel=1e-4, abs=1e-4
+    ), f"{stat_std}: reduced {reduced}, expected {expected}"
+
+
 @pytest.mark.parametrize("layer", ["linear", "transformer"])
 def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     if not fp8_available:
@@ -403,7 +526,8 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
 
             with open(
                 os.path.join(
-                    temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
+                    temp_dir,
+                    "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log",
                 ),
                 "r",
             ) as f:
diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py
b/transformer_engine/debug/features/log_fp8_tensor_stats.py
index d26f9ef7f6..32329b289a 100644
--- a/transformer_engine/debug/features/log_fp8_tensor_stats.py
+++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -119,10 +119,13 @@ class LogFp8TensorStats(BaseLogTensorStats):
             - overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling,
             - scale_inv_min - minimum of the inverse of the scaling factors,
             - scale_inv_max - maximum of the inverse of the scaling factors,
+            - scale_inv_std - population standard deviation of the inverse of the scaling factors;
+              useful for spotting clipping that min/max alone can miss (degenerates to 0 for
+              fp8_delayed_scaling / fp8_current_scaling since those use a single scalar scale).
             - mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements,
 
         When collecting stats for the weight tensor with FP8 model parameters enabled,
-        only "scale_inv_min" and "scale_inv_max" are available.
+        only "scale_inv_min", "scale_inv_max" and "scale_inv_std" are available.
         All other statistics require access to the high precision tensor.
 
     tensors/tensors_struct: List[str]
@@ -191,15 +194,15 @@ def check_if_stat_is_supported(
         if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
             raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
 
-        # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
-        # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
+        # inspect_tensor() drops bare stats on NVFP4 layers, but a stat can still resolve to NVFP4
+        # here (explicit "nvfp4_" prefix, direct call); raise, since `assert` vanishes under -O.
         if recipe_from_stat == "nvfp4":
             raise ValueError(
                 f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
                 " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
                 " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
                 " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
             )
 
         if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
             raise ValueError(
@@ -216,7 +219,7 @@ def check_if_stat_is_supported(
         if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
             raise ValueError(f"Stat {stat} needs Blackwell or later GPU.")
 
-        supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"]
+        supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "scale_inv_std", "mse"]
         if stat_without_recipe not in supported_stats:
             raise ValueError(
                 f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}"
@@ -338,6 +341,27 @@ def inspect_tensor(
 
         recipe_name = _get_recipe_name(quantizer)
 
+        # If the layer uses NVFP4, drop bare stats (which would target the NVFP4
+        # recipe that LogFp8TensorStats can't handle) but keep stats explicitly
+        # prefixed with an FP8 recipe (e.g. "mxfp8_mse") for what-if FP8 comparison.
+        if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
+            kept_stats, dropped_stats = [], []
+            for stat in config["stats"]:
+                if any(r in stat for r in ALL_RECIPE_NAMES):
+                    kept_stats.append(stat)
+                else:
+                    dropped_stats.append(stat)
+            if dropped_stats:
+                warnings.warn(
+                    f"[LogFp8TensorStats] Skipping stats {dropped_stats} for layer "
+                    f"'{layer_name}', tensor '{tensor_name}': layer uses NVFP4. Use "
+                    "LogNvfp4TensorStats for NVFP4 stats, or prefix stats with an FP8 "
+                    "recipe name (e.g. 'mxfp8_mse') for what-if FP8 comparisons."
+                )
+            if not kept_stats:
+                return
+            config = {**config, "stats": kept_stats}
+
         for stat in config["stats"]:
             self.check_if_stat_is_supported(
                 stat, recipe_name, high_precision_tensor_provided=tensor is not None
diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py
index 8a76f4edcf..5c98ce49d8 100644
--- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py
+++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py
@@ -45,6 +45,10 @@ class LogNvfp4TensorStats(BaseLogTensorStats):
         List of statistics to collect. Available stats:
             - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
             - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
+            - scale_inv_min - minimum of the inverse of the scaling factors
+            - scale_inv_max - maximum of the inverse of the scaling factors
+            - scale_inv_std - population standard deviation of the inverse of the scaling factors;
+              useful for spotting clipping that min/max alone can miss
 
     tensors/tensors_struct: List[str]
         list of tensors to log
@@ -85,13 +89,23 @@ class LogNvfp4TensorStats(BaseLogTensorStats):
 
     def check_if_stat_is_supported(self, stat: str):
         """Returns True if stat is supported, raises ValueError otherwise."""
-        supported_stats = [
-            "underflows%",
-            "mse",
-        ]
-        if stat not in supported_stats:
+        # Only scale_inv_* stats have a columnwise variant (separate rowwise/columnwise
+        # scale_inv); underflows%/mse are computed from the single quantized tensor and
+        # have no '_columnwise' form, so they must not accept the suffix.
+        columnwise_stats = ["scale_inv_min", "scale_inv_max", "scale_inv_std"]
+        supported_stats = ["underflows%", "mse"] + columnwise_stats
+
+        if stat.endswith("_columnwise"):
+            bare = stat[: -len("_columnwise")]
+            if bare not in columnwise_stats:
+                raise ValueError(
+                    f"Stat {stat} is not supported for NVFP4.
The '_columnwise' suffix is only"
+                    f" valid for {columnwise_stats}."
+                )
+        elif stat not in supported_stats:
             raise ValueError(
                 f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
+                " (scale_inv_* may take an optional '_columnwise' suffix)."
             )
         return True
 
diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py
index 6668400017..ab95a0f5c5 100644
--- a/transformer_engine/debug/features/utils/stats_computation.py
+++ b/transformer_engine/debug/features/utils/stats_computation.py
@@ -121,18 +121,34 @@ def _compute_for_one_orientation(tensor):
 
 
 @torch.compile
-def compute_variance(variances, numels, sums):
-    """Welford algorithm is used for numerically stable distributed variance computation."""
-    mean = torch.sum(sums) / torch.sum(numels)
-    means = sums / numels
-    var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum(numels)
-    return var
+def compute_variance(variances, numels, sums, unbiased=True):
+    """Parallel (Chan/Welford) combination of per-group variances.
+
+    `unbiased` describes both the stored per-group variances and the result:
+    True -> sample (divide by N-1), False -> population (divide by N). The combine
+    is done on M2 (sum of squared deviations), which is convention-agnostic, so
+    both biased and unbiased inputs combine exactly across groups with different
+    means. The stored per-group variances must use the same `unbiased` convention
+    as this call - M2 is reconstructed as variance * (n-1) when unbiased else
+    variance * n, so mixing conventions silently corrupts the result.
+    """
+    total = torch.sum(numels)
+    mean = torch.sum(sums) / total
+    # An empty group has sum 0; clamp its divisor so 0/0 does not inject NaN.
+    means = sums / torch.clamp(numels, min=1)
+    # Reconstruct each group's M2, recombine, then re-divide once at the end.
+    per_group_div = (numels - 1) if unbiased else numels
+    # A non-positive divisor (n=0, or n=1 when unbiased) means the stored variance is
+    # undefined (NaN) while the group's own M2 is exactly 0, so it must not be multiplied in.
+    within = torch.where(per_group_div > 0, variances * per_group_div, torch.zeros_like(variances))
+    m2 = torch.sum(within + numels * torch.pow(means - mean, 2))
+    return m2 / ((total - 1) if unbiased else total)
 
 
 @torch.compile
-def compute_std(variances, numels, sums):
-    """Computates standard deviation."""
-    return torch.sqrt(compute_variance(variances, numels, sums))
+def compute_std(variances, numels, sums, unbiased=True):
+    """Computes standard deviation; see `compute_variance` for `unbiased`."""
+    return torch.sqrt(compute_variance(variances, numels, sums, unbiased=unbiased))
 
 
 def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor):
@@ -335,11 +351,13 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False):
 
 
 def add_scale_inv_stats(recipe_name: str, columnwise: bool = False):
-    """Register *both* scale-inv min and max stats for a given recipe.
+    """Register scale-inv min/max/std stats for a given recipe.
 
-    This replaces the earlier separate helpers and avoids duplicated boilerplate.
+    The std uses Welford's algorithm to combine partial variances across
+    microbatches/ranks, so helper buffers for variance/numel/sum are also
+    registered. Population variance (unbiased=False) is used so single-element
+    scale_inv tensors (delayed/current scaling) yield std=0 rather than NaN.
     """
-    # Determine which attribute holds the scale-inverse tensor.
 
     def get_scale_inv(quantized_tensor, columnwise):
         if hasattr(quantized_tensor, "_scale_inv"):
@@ -348,48 +366,86 @@ def get_scale_inv(quantized_tensor, columnwise):
             return getattr(quantized_tensor, "_columnwise_scale_inv")
         return getattr(quantized_tensor, "_rowwise_scale_inv")
 
-    def nonzero_min(scale_inv):
+    def _prefix():
+        return f"{recipe_name}{'_' if recipe_name != '' else ''}"
+
+    def real_scale_inv(quantized_tensor, columnwise):
         # MXFP8/NVFP4 quantizers round the scale_inv shape up to multiples of
-        # 128 along one axis and 4 along the other and fill the extra slots
-        # with zeros (via torch.nn.functional.pad with the default value=0),
-        # so a plain .min() always returns 0 for shapes that needed padding.
-        # A real scale_inv entry is never 0: compute_scale_from_amax returns
-        # scale=1.0 for all-zero blocks and clamps the inf case to a finite
-        # fallback, so zeros uniquely identify padding and masking them out
-        # gives the true minimum. The empty-after-mask branch is a safety
-        # net for the (in practice unreachable) all-zero tensor.
-        nz = scale_inv[scale_inv != 0]
-        if nz.numel() == 0:
-            return scale_inv.new_zeros(())
-        return nz.min()
+        # 128 along one axis and 4 along the other and fill the extra slots with
+        # zeros (via torch.nn.functional.pad with the default value=0). A real
+        # scale_inv entry is never 0: compute_scale_from_amax returns scale=1.0
+        # for all-zero blocks and clamps the inf case to a finite fallback, so
+        # zeros uniquely identify padding. Every scale_inv stat masks them out;
+        # otherwise the padding deflates the mean and shows up as spurious spread
+        # (min/std) and inflated counts (numel), which would also corrupt the
+        # numel-weighted variance reduction across microbatches/ranks.
+        scale_inv = get_scale_inv(quantized_tensor, columnwise).float()
+        return scale_inv[scale_inv != 0]
+
+    def scale_inv_min(quantized_tensor, columnwise):
+        nz = real_scale_inv(quantized_tensor, columnwise)
+        # Empty only for the (in practice unreachable) all-zero tensor.
+        return nz.min() if nz.numel() > 0 else nz.new_zeros(())
 
     columnwise_suffix = "_columnwise" if columnwise else ""
-    # Prepare stat names.
-    stat_name_min = (
-        f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}"
-    )
-    stat_name_max = (
-        f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}"
-    )
+    stat_name_min = f"{_prefix()}scale_inv_min{columnwise_suffix}"
+    stat_name_max = f"{_prefix()}scale_inv_max{columnwise_suffix}"
+    stat_name_std = f"{_prefix()}scale_inv_std{columnwise_suffix}"
+    stat_name_var = f"{_prefix()}scale_inv_variance{columnwise_suffix}"
+    stat_name_numel = f"{_prefix()}scale_inv_numel{columnwise_suffix}"
+    stat_name_sum = f"{_prefix()}scale_inv_sum{columnwise_suffix}"
 
     # Assign indices in `stats_to_num` (order matters — keep insertion order deterministic).
-    stats_to_num[stat_name_min] = len(stats_to_num)
-    stats_to_num[stat_name_max] = len(stats_to_num)
+    for name in (
+        stat_name_min,
+        stat_name_max,
+        stat_name_std,
+        stat_name_var,
+        stat_name_numel,
+        stat_name_sum,
+    ):
+        stats_to_num[name] = len(stats_to_num)
 
     # Capture the attribute name inside lambdas via default args to avoid late binding.
     STATS[stat_name_min] = (
-        lambda x, aux_dict, _col=columnwise: nonzero_min(
-            get_scale_inv(aux_dict[recipe_name], _col)
-        ),
+        lambda x, aux_dict, _col=columnwise: scale_inv_min(aux_dict[recipe_name], _col),
         lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)),
     )
     STATS[stat_name_max] = (
         lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(),
         lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)),
     )
+    STATS[stat_name_var] = (
+        lambda x, aux_dict, _col=columnwise: torch.var(
+            real_scale_inv(aux_dict[recipe_name], _col), unbiased=False
+        ),
+        lambda buffers, _sv=stat_name_var, _sn=stat_name_numel, _ss=stat_name_sum: compute_variance(
+            _get(buffers, _sv), _get(buffers, _sn), _get(buffers, _ss), unbiased=False
+        ),
+    )
+    STATS[stat_name_numel] = (
+        lambda x, aux_dict, _col=columnwise: real_scale_inv(aux_dict[recipe_name], _col).numel(),
+        lambda buffers, _sn=stat_name_numel: sum(_get(buffers, _sn)),
+    )
+    STATS[stat_name_sum] = (
+        lambda x, aux_dict, _col=columnwise: real_scale_inv(aux_dict[recipe_name], _col).sum(),
+        lambda buffers, _ss=stat_name_sum: sum(_get(buffers, _ss)),
+    )
+    STATS[stat_name_std] = (
+        lambda x, aux_dict, _col=columnwise: torch.std(
+            real_scale_inv(aux_dict[recipe_name], _col), unbiased=False
+        ),
+        lambda buffers, _sv=stat_name_var, _sn=stat_name_numel, _ss=stat_name_sum: compute_std(
+            _get(buffers, _sv), _get(buffers, _sn), _get(buffers, _ss), unbiased=False
+        ),
+    )
 
     DEPENDENCIES[stat_name_min] = {stat_name_min}
     DEPENDENCIES[stat_name_max] = {stat_name_max}
+    DEPENDENCIES[stat_name_numel] = {stat_name_numel}
+    DEPENDENCIES[stat_name_sum] = {stat_name_sum}
+    DEPENDENCIES[stat_name_var] = {stat_name_var, stat_name_numel, stat_name_sum}
+    DEPENDENCIES[stat_name_std] = {stat_name_var, stat_name_numel, stat_name_sum}
 
 
 def add_mse_stats(recipe_name: str, columnwise: bool = False):
@@ -522,3 +578,5 @@ def add_nvfp4_underflows_stats():
 # Register NVFP4 stats
 add_nvfp4_underflows_stats()
 add_mse_stats("nvfp4")  # Reuse existing MSE function
+for _columnwise in [True, False]:
+    add_scale_inv_stats("nvfp4", _columnwise)