
Conversation

@Fridah-nv (Contributor) commented Nov 12, 2025

What does this PR do?

Type of change: new feature

Overview:

Add a per-tensor/per-channel MSE calibrator.

Usage

The calibrator can be enabled via the "algorithm" field in quantization configs:

"algorithm": {"method": "mse", "num_steps": 20, "stop_multiplier": 8.0},

Testing

Unit test for the MseCalibrator,
E2E test with NVFP4 and INT8,

Results:
start_multiplier=0.25
stop_multiplier=4.0
num_steps=20

Qwen3-8B MMLU (BF16 baseline: 72.94):

Calib Algo   NVFP4   FP8     INT8
MSE          70.88   72.65   55.46
MAX          70.83   72.70   24.52

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

TODOs for a follow-up PR:

  • TP sync for HF models
  • Calculate weight quantizer only once

Summary by CodeRabbit

Release Notes

  • New Features

    • Added MSE-based quantization calibration supporting per-tensor and per-channel optimization with configurable parameters (step count, multiplier ranges).
  • Tests

    • Added comprehensive test coverage for MSE calibration functionality.
  • Documentation

    • Updated changelog to reflect MSE calibrator support.


@Fridah-nv Fridah-nv self-assigned this Nov 12, 2025
@Fridah-nv Fridah-nv requested a review from a team as a code owner November 12, 2025 01:28
@Fridah-nv Fridah-nv changed the title Add per tensor/per channel MSE calibrator [OMNIML-3015]Add per tensor/per channel MSE calibrator Nov 13, 2025
codecov bot commented Nov 14, 2025

Codecov Report

❌ Patch coverage is 94.63087% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.64%. Comparing base (38550b0) to head (2fca0db).
⚠️ Report is 2 commits behind head on main.

Files with missing lines                       Patch %   Lines
modelopt/torch/quantization/calib/mse.py       94.11%    4 Missing ⚠️
modelopt/torch/quantization/model_calib.py     92.50%    3 Missing ⚠️
modelopt/torch/quantization/utils.py           95.45%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #540      +/-   ##
==========================================
+ Coverage   74.45%   74.64%   +0.19%     
==========================================
  Files         182      183       +1     
  Lines       18250    18389     +139     
==========================================
+ Hits        13588    13727     +139     
  Misses       4662     4662              

Signed-off-by: Fridah-nv <[email protected]>
@realAsma (Contributor) commented:
@coderabbitai review

coderabbitai bot commented Nov 20, 2025

✅ Actions performed

Review triggered.


coderabbitai bot commented Nov 20, 2025

Walkthrough

This pull request introduces MSE-based calibration support for per-tensor and per-channel quantization. It adds a new MseCalibrator class, configuration structures, model calibration integration, and utility functions, along with comprehensive unit tests. The feature enables optimal amax selection through mean-squared-error minimization during the calibration phase.
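
The core idea, stripped to per-tensor form, is a candidate search over scaled amax values. A minimal sketch (loss accumulation across batches and the per-channel reduction are omitted; quant_func stands for the fake-quantization callback):

import torch

def mse_search_amax(x, initial_amax, quant_func,
                    num_steps=10, start_multiplier=0.25, stop_multiplier=4.0):
    """Return the candidate amax that minimizes quantization MSE on x."""
    multipliers = torch.linspace(start_multiplier, stop_multiplier, num_steps)
    best_amax, best_loss = initial_amax, float("inf")
    for m in multipliers:
        candidate = initial_amax * m
        xq = quant_func(x, candidate)  # fake-quantize x with this candidate amax
        loss = ((x.float() - xq.float()) ** 2).mean().item()
        if loss < best_loss:
            best_amax, best_loss = candidate, loss
    return best_amax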

Changes

  • Core MSE Calibrator Implementation (modelopt/torch/quantization/calib/mse.py, modelopt/torch/quantization/calib/__init__.py):
    New MseCalibrator class performing per-tensor/per-channel MSE-based calibration. Collects candidate amax values and computes optimal amax via error minimization. Module exposed via package __init__.py.
  • Configuration & Registration (modelopt/torch/quantization/config.py, modelopt/torch/quantization/mode.py):
    Added MseCalibConfig with tunable parameters (num_steps, start/stop multipliers, distributed_sync). Introduced MseCalibrateModeDescriptor registered in the calibration mode registry.
  • Model Calibration Integration (modelopt/torch/quantization/model_calib.py):
    New mse_calibrate function orchestrating MSE-based amax calibration. Replaces calibrators, collects data, and loads optimal amax. Enhanced finish_stats_collection to handle multiple calibrator types.
  • Infrastructure & Utilities (modelopt/torch/quantization/utils.py, modelopt/torch/quantization/nn/modules/tensor_quantizer.py):
    Added reduce_sum function, plus enable_quant and disable_calib context managers. Added is_static_block_quant property to TensorQuantizer for cleaner static block detection.
  • Documentation & Testing (CHANGELOG.rst, tests/unit/torch/quantization/test_mse_calibrator.py):
    Updated changelog. Comprehensive unit tests covering per-tensor/per-channel calibration, reset behavior, custom error functions, and multi-collection scenarios.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant mse_calibrate
    participant TensorQuantizer
    participant MseCalibrator
    participant ForwardLoop

    User->>mse_calibrate: mse_calibrate(model, forward_loop, ...)
    mse_calibrate->>mse_calibrate: Initial max calibration
    mse_calibrate->>TensorQuantizer: Create MseCalibrator per quantizer
    mse_calibrate->>TensorQuantizer: Replace existing calibrator
    
    alt With forward_loop
        mse_calibrate->>ForwardLoop: Iterate over batches
        ForwardLoop->>TensorQuantizer: Run forward passes
    else Without forward_loop
        mse_calibrate->>MseCalibrator: Manual collect() calls
    end
    
    TensorQuantizer->>MseCalibrator: collect(x) - test candidates
    MseCalibrator->>MseCalibrator: Compute losses for amax variants
    MseCalibrator->>MseCalibrator: Accumulate statistics
    
    mse_calibrate->>MseCalibrator: compute_amax()
    MseCalibrator->>MseCalibrator: Select optimal amax (min loss)
    MseCalibrator-->>mse_calibrate: Best amax value
    
    mse_calibrate->>TensorQuantizer: Load optimal amax
    mse_calibrate-->>User: Calibrated model
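In code, the manual-collection branch of the diagram looks roughly like this (constructor arguments follow the walkthrough and review notes; exact signatures are assumptions, and initial_amax, quant_func, and calib_batches are placeholders):

calibrator = MseCalibrator(
    amax=initial_amax,          # starting amax from the initial max pass
    quant_func=quant_func,      # callback that fake-quantizes x at a given amax
    num_steps=10,
    start_multiplier=0.25,
    stop_multiplier=4.0,
)
for batch in calib_batches:
    calibrator.collect(batch)   # accumulate per-candidate MSE losses
best_amax = calibrator.compute_amax()  # candidate with the lowest loss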

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Areas requiring extra attention:

  • modelopt/torch/quantization/model_calib.py — Core calibration logic with calibrator replacement, quant_func context management, and error handling. Verify static block quant validation and distributed sync placeholder.
  • modelopt/torch/quantization/calib/mse.py — MSE loss computation, candidate amax generation, per-channel vs. per-tensor axis reduction, and averaging logic in compute_amax.
  • tests/unit/torch/quantization/test_mse_calibrator.py — Validate test coverage of edge cases (signed/unsigned, multi-collection, custom error functions) and correctness of per-channel amax shapes.
  • Context manager interactions in enable_quant and disable_calib — Ensure proper state restoration in all execution paths.


Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 62.22%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The pull request title '[OMNIML-3015] Add per tensor/per channel MSE calibrator' clearly and specifically describes the main change: adding MSE calibrator functionality with per-tensor and per-channel support.

coderabbitai bot left a comment
Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/unit/torch/quantization/test_mse_calibrator.py (1)

26-26: Consider extracting the quant_func wrapper to reduce duplication.

The TODO comment on Line 26 correctly identifies code duplication. The quant_func wrapper pattern (Lines 51-67, 98-114, and repeated in all subsequent tests) could be extracted as a test fixture or helper method to improve maintainability.

Example refactoring (a sketch; assumes enable_fake_quant is imported from the quantization utilities):

def create_quant_func(tq):
    """Create a quant_func wrapper for the given TensorQuantizer."""
    def quant_func(x, amax):
        original_amax = tq._amax.clone() if hasattr(tq, "_amax") else None
        was_quant_enabled = tq._if_quant
        was_calib_enabled = tq._if_calib

        tq._amax = amax
        tq._if_quant = True
        tq._if_calib = False

        with enable_fake_quant(tq):
            xq = tq(x)

        if original_amax is not None:
            tq._amax = original_amax
        tq._if_quant = was_quant_enabled
        tq._if_calib = was_calib_enabled
        return xq
    return quant_func

Also applies to: 51-67, 98-114

modelopt/torch/quantization/model_calib.py (1)

188-267: Clarify distributed sync semantics and harden _amax handling in mse_calibrate

The overall two‑stage flow (initial max_calibrate → swap in MseCalibrator → second stats pass → finish_stats_collection(method="mse")) looks sound, and the per‑quantizer quant_func wrapper is a nice way to reuse existing quantization logic.

Two follow‑ups are worth considering:

  1. distributed_sync only affects the max phase right now
  • distributed_sync is passed to max_calibrate (Step 1), but the final MSE‑refined amax values are never synchronized; there is only a # TODO: Sync amax across distributed processes comment at Line 267.
  • As written, different DP/TP ranks can end up with different final amax values even when distributed_sync=True, which diverges from the docstring expectation.

Consider either:

  • Wiring the same DP/TP sync logic used in max_calibrate to run after finish_stats_collection(model, method="mse"), or
  • Explicitly documenting that distributed_sync currently applies only to the initial max pass until full sync is implemented.
  2. Make quant_func more robust if _amax is missing or None
  • initial_amax = module._amax.clone().detach() and original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None assume _amax is always a tensor whenever the attribute exists.
  • If a quantizer somehow reaches this code path without a valid tensor amax (e.g., not hit during the initial forward_loop), this will raise.

A small defensive tweak keeps the behavior the same in the common case while avoiding a hard failure in edge cases:

-                initial_amax = module._amax.clone().detach()
+                initial_amax = getattr(module, "_amax", None)
+                if initial_amax is None:
+                    # Skip MSE search for quantizers without an initial amax
+                    continue
+                initial_amax = initial_amax.clone().detach()
@@
-                def quant_func(x, amax, quantizer=module):
-                    original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
+                def quant_func(x, amax, quantizer=module):
+                    original_amax = getattr(quantizer, "_amax", None)
+                    if original_amax is not None:
+                        original_amax = original_amax.clone()
                     quantizer._amax = amax
@@
-                    if original_amax is not None:
-                        quantizer._amax = original_amax
+                    if original_amax is not None:
+                        quantizer._amax = original_amax

This keeps mse_calibrate resilient if some quantizers never obtain an initial amax in the first pass.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between b1fc1fe and 3de3593.

📒 Files selected for processing (9)
  • CHANGELOG.rst (1 hunks)
  • modelopt/torch/quantization/calib/__init__.py (1 hunks)
  • modelopt/torch/quantization/calib/mse.py (1 hunks)
  • modelopt/torch/quantization/config.py (1 hunks)
  • modelopt/torch/quantization/mode.py (3 hunks)
  • modelopt/torch/quantization/model_calib.py (3 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/unit/torch/quantization/test_mse_calibrator.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
Repo: NVIDIA/TensorRT-Model-Optimizer PR: 332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.

Applied to files:

  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
🔇 Additional comments (20)
CHANGELOG.rst (1)

16-16: LGTM!

The changelog entry clearly documents the new MSE calibrator feature.

modelopt/torch/quantization/utils.py (3)

45-45: LGTM!

The addition of reduce_sum to __all__ correctly exposes the new utility function.


191-218: LGTM!

The reduce_sum function is well-implemented:

  • Properly decorated with @torch.no_grad() for calibration use
  • Consistent API with reduce_amax
  • Handles all axis configurations correctly
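
For reference, a minimal sketch of such a helper with a reduce_amax-style signature (the actual implementation may differ):

import torch

@torch.no_grad()
def reduce_sum(input, axis=None, keepdims=True):
    """Sum `input` along the dimensions in `axis` (all dimensions if None)."""
    if axis is None:
        return torch.sum(input)
    return torch.sum(input, dim=axis, keepdim=keepdims)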

633-661: LGTM!

The context managers enable_quant and disable_calib are well-implemented:

  • Follow standard context manager patterns
  • Properly restore original state in finally blocks
  • Clear docstrings explaining their purpose
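
The pattern being praised here, sketched minimally (attribute names follow the test snippet earlier in this review; the repository's implementation may differ):

from contextlib import contextmanager

@contextmanager
def disable_calib(quantizer):
    """Temporarily disable calibration on a quantizer, restoring state on exit."""
    was_calib = quantizer._if_calib
    quantizer._if_calib = False
    try:
        yield quantizer
    finally:
        quantizer._if_calib = was_calib  # restore even if the body raises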
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

428-435: LGTM!

The is_static_block_quant property is a good refactoring that:

  • Centralizes the logic for detecting static block quantization
  • Improves code readability and maintainability
  • Has a clear implementation checking the necessary conditions

638-638: LGTM!

The usage of is_static_block_quant property correctly replaces inline checks and improves code clarity. The comment update on Line 638 also helps distinguish between dynamic and static block quantization paths.

Also applies to: 908-912, 944-945

modelopt/torch/quantization/calib/mse.py (4)

32-66: LGTM!

The __init__ method is well-structured:

  • Parameters are clearly documented
  • Uses lists to track per-step statistics which is appropriate for typical num_steps values
  • Properly inherits from _Calibrator base class

68-108: LGTM!

The collect method is well-implemented:

  • Good defensive programming with the quant_func validation (Lines 76-78)
  • Converts to float32 for accurate MSE computation (Line 80)
  • Uses += for loss accumulation (Line 107) which addresses the past review comment about memory efficiency
  • Properly handles per-channel reduction via convert_quantization_axis_to_reduce_axis



110-115: LGTM!

The reset method properly clears all accumulated state, allowing the calibrator to be reused.


117-189: LGTM!

The compute_amax method correctly handles both per-tensor and per-channel cases:

  • Returns None if no data collected (good defensive behavior)
  • Per-tensor case (Lines 137-153): Correctly averages losses and selects best candidate
  • Per-channel case (Lines 154-188): Uses proper tensor indexing to select per-channel optimal amax
  • Verbose output provides useful debugging information
  • Line 178 properly reshapes result to match initial amax shape
modelopt/torch/quantization/calib/__init__.py (1)

26-26: LGTM!

The import correctly exposes the new mse module via wildcard import, making MseCalibrator available in the calib namespace.

modelopt/torch/quantization/config.py (1)

984-1021: LGTM!

The MseCalibConfig class is well-designed:

  • Clear docstring explaining the MSE calibration approach
  • Appropriate field validators (num_steps ≥ 1, multipliers > 0)
  • Consistent with other calibration configs in the module
  • Good default values (num_steps=10, start_multiplier=0.25, stop_multiplier=4.0)
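
Based on the fields and validators listed above, the config shape is roughly as follows (a sketch, not the repository's code; the pydantic base class and the distributed_sync default are assumptions):

from typing import Literal

from pydantic import Field

from modelopt.torch.quantization.config import QuantizeAlgorithmConfig  # assumed base

class MseCalibConfig(QuantizeAlgorithmConfig):
    """MSE calibration settings, mirroring the defaults noted above."""

    method: Literal["mse"] = "mse"
    num_steps: int = Field(default=10, ge=1)
    start_multiplier: float = Field(default=0.25, gt=0)
    stop_multiplier: float = Field(default=4.0, gt=0)
    distributed_sync: bool = Field(default=True)  # assumed default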
modelopt/torch/quantization/mode.py (2)

41-41: LGTM!

The imports correctly add MseCalibConfig and mse_calibrate to support the new MSE calibration mode.

Also applies to: 58-58


367-377: LGTM!

The MseCalibrateModeDescriptor follows the established pattern for calibration mode descriptors:

  • Properly inherits from BaseCalibrateModeDescriptor
  • Returns correct config class
  • References the mse_calibrate function
  • Correctly registered with @CalibrateModeRegistry.register_mode
tests/unit/torch/quantization/test_mse_calibrator.py (4)

27-36: LGTM!

The _mse_at_a helper function is a useful utility for computing MSE at a given amax value for validation purposes.


40-86: LGTM!

Good test coverage for per-tensor MSE calibration with outliers:

  • Tests that calibrator can find better amax than initial max
  • Validates finite and bounded results
  • Confirms loss improvement over baseline

88-133: LGTM!

Excellent test coverage for various per-tensor scenarios:

  • Negative outliers (Lines 88-133)
  • Unsigned quantization (Lines 135-185)
  • Multiple collections (Lines 187-231)
  • Custom error function (Lines 233-280)
  • Reset functionality (Lines 282-319)

All tests validate correct behavior with appropriate assertions.

Also applies to: 135-185, 187-231, 233-280, 282-319


321-423: LGTM!

Comprehensive test coverage for per-channel MSE calibration:

  • Basic per-channel test with axis=0 (Lines 321-370)
  • Multiple collections (Lines 371-423)
  • Independent per-channel optimization with different scales (Lines 424-478)
  • Custom error function support (Lines 480-528)

All tests properly validate shape, finiteness, and positivity of per-channel results.

Also applies to: 424-478, 480-528

modelopt/torch/quantization/model_calib.py (2)

32-38: Imports for MSE calibrator and quantizer state helpers are consistent

MseCalibrator, disable_calib, enable_fake_quant, and enable_quant are all used in mse_calibrate, and no unused symbols are introduced here.


283-303: finish_stats_collection correctly distinguishes MSE/entropy vs max calibrators

The refactor to iterate only over TensorQuantizer modules and branch on method in {"mse", "entropy"} vs the default (compute_amax() with no method) is clean and keeps existing callers (which pass no method) compatible. Bias handling and the final enable_quant()/disable_calib() are preserved.

Just ensure that any calibrator used with method="mse" or method="entropy" implements compute_amax(method: str) as expected; with that contract satisfied, this helper looks good.
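
The expected dispatch, paraphrased as a fragment (illustrative only; loading the computed amax into the quantizer and bias handling are omitted):

from modelopt.torch.quantization.nn import TensorQuantizer

def _compute_all_amax(model, method=None):
    """Call each calibrator's compute_amax, passing method only for MSE/entropy."""
    for module in model.modules():
        if isinstance(module, TensorQuantizer) and module._calibrator is not None:
            amax = (
                module._calibrator.compute_amax(method)
                if method in {"mse", "entropy"}
                else module._calibrator.compute_amax()
            )
            # amax is then loaded into the quantizer (omitted here)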

@realAsma (Contributor) left a comment
Looks great!!

Can we also add an end-to-end unit test of the flow in tests/unit/torch/quantization/test_quantize_cpu.py::test_quantize?

@Fridah-nv Fridah-nv merged commit 93f5bbf into main Nov 21, 2025
27 checks passed
@Fridah-nv Fridah-nv deleted the fridah/calib branch November 21, 2025 20:44