[OMNIML-3015] Add per tensor/per channel MSE calibrator #540
Conversation
Codecov Report

❌ Patch coverage is
Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #540      +/-   ##
==========================================
+ Coverage   74.45%   74.64%   +0.19%
==========================================
  Files         182      183       +1
  Lines       18250    18389     +139
==========================================
+ Hits        13588    13727     +139
  Misses       4662     4662
==========================================
```

☔ View full report in Codecov by Sentry.
@coderabbitai review

✅ Actions performed: review triggered.
Walkthrough

This pull request introduces MSE-based calibration support for per-tensor and per-channel quantization. It adds a new `MseCalibrator` class along with supporting configuration, mode registration, and utilities.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant mse_calibrate
    participant TensorQuantizer
    participant MseCalibrator
    participant ForwardLoop
    User->>mse_calibrate: mse_calibrate(model, forward_loop, ...)
    mse_calibrate->>mse_calibrate: Initial max calibration
    mse_calibrate->>TensorQuantizer: Create MseCalibrator per quantizer
    mse_calibrate->>TensorQuantizer: Replace existing calibrator
    alt With forward_loop
        mse_calibrate->>ForwardLoop: Iterate over batches
        ForwardLoop->>TensorQuantizer: Run forward passes
    else Without forward_loop
        mse_calibrate->>MseCalibrator: Manual collect() calls
    end
    TensorQuantizer->>MseCalibrator: collect(x) - test candidates
    MseCalibrator->>MseCalibrator: Compute losses for amax variants
    MseCalibrator->>MseCalibrator: Accumulate statistics
    mse_calibrate->>MseCalibrator: compute_amax()
    MseCalibrator->>MseCalibrator: Select optimal amax (min loss)
    MseCalibrator-->>mse_calibrate: Best amax value
    mse_calibrate->>TensorQuantizer: Load optimal amax
    mse_calibrate-->>User: Calibrated model
```
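To make the search concrete, here is a standalone sketch of the per-tensor MSE amax search that the diagram describes: sweep candidate amax values between `start_multiplier` and `stop_multiplier` times the max-calibrated amax, fake-quantize the input at each candidate, and keep the candidate with the lowest mean squared error. This is illustrative only; the function name, signature, and the simple symmetric INT8 fake-quant below are assumptions, not the PR's actual `MseCalibrator` API.

```python
import torch


def mse_amax_search(x, initial_amax, num_steps=10, start_multiplier=0.25, stop_multiplier=4.0, num_bits=8):
    """Illustrative per-tensor MSE amax search (not the PR's exact implementation)."""
    x = x.float()
    # Candidate amax values are multiples of the max-calibrated amax.
    candidates = torch.linspace(start_multiplier, stop_multiplier, num_steps) * initial_amax
    losses = []
    for amax in candidates:
        # Simple symmetric fake-quantization at this candidate amax.
        scale = amax / (2 ** (num_bits - 1) - 1)
        xq = torch.clamp(torch.round(x / scale), -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1) * scale
        losses.append(torch.mean((x - xq) ** 2))
    best = torch.argmin(torch.stack(losses))
    return candidates[best]


# Example: an outlier inflates the plain max, so the MSE-optimal amax is usually smaller.
x = torch.randn(4096)
x[0] = 20.0
best_amax = mse_amax_search(x, initial_amax=x.abs().max())
```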
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Areas requiring extra attention:
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/unit/torch/quantization/test_mse_calibrator.py (1)

26-26: Consider extracting the `quant_func` wrapper to reduce duplication.

The TODO comment on Line 26 correctly identifies code duplication. The `quant_func` wrapper pattern (Lines 51-67, 98-114, and repeated in all subsequent tests) could be extracted as a test fixture or helper method to improve maintainability.

Example refactoring:

```python
def create_quant_func(tq):
    """Create a quant_func wrapper for the given TensorQuantizer."""

    def quant_func(x, amax):
        original_amax = tq._amax.clone() if hasattr(tq, "_amax") else None
        was_quant_enabled = tq._if_quant
        was_calib_enabled = tq._if_calib
        tq._amax = amax
        tq._if_quant = True
        tq._if_calib = False
        with enable_fake_quant(tq):
            xq = tq(x)
        if original_amax is not None:
            tq._amax = original_amax
        tq._if_quant = was_quant_enabled
        tq._if_calib = was_calib_enabled
        return xq

    return quant_func
```

Also applies to: 51-67, 98-114
modelopt/torch/quantization/model_calib.py (1)

188-267: Clarify distributed sync semantics and harden `_amax` handling in `mse_calibrate`.

The overall two-stage flow (initial `max_calibrate` → swap in `MseCalibrator` → second stats pass → `finish_stats_collection(method="mse")`) looks sound, and the per-quantizer `quant_func` wrapper is a nice way to reuse existing quantization logic. Two follow-ups are worth considering:

1. `distributed_sync` only affects the max phase right now
- `distributed_sync` is passed to `max_calibrate` (Step 1), but the final MSE-refined amax values are never synchronized; there is only a `# TODO: Sync amax across distributed processes` comment at Line 267.
- As written, different DP/TP ranks can end up with different final amax values even when `distributed_sync=True`, which diverges from the docstring expectation.

Consider either:
- Wiring the same DP/TP sync logic used in `max_calibrate` to run after `finish_stats_collection(model, method="mse")`, or
- Explicitly documenting that `distributed_sync` currently applies only to the initial max pass until full sync is implemented.

2. Make `quant_func` more robust if `_amax` is missing or `None`
- `initial_amax = module._amax.clone().detach()` and `original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None` assume `_amax` is always a tensor whenever the attribute exists.
- If a quantizer somehow reaches this code path without a valid tensor amax (e.g., not hit during the initial `forward_loop`), this will raise.

A small defensive tweak keeps the behavior the same in the common case while avoiding a hard failure in edge cases:

```diff
-        initial_amax = module._amax.clone().detach()
+        initial_amax = getattr(module, "_amax", None)
+        if initial_amax is None:
+            # Skip MSE search for quantizers without an initial amax
+            continue
+        initial_amax = initial_amax.clone().detach()
@@
-        def quant_func(x, amax, quantizer=module):
-            original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
+        def quant_func(x, amax, quantizer=module):
+            original_amax = getattr(quantizer, "_amax", None)
+            if original_amax is not None:
+                original_amax = original_amax.clone()
             quantizer._amax = amax
@@
-            if original_amax is not None:
-                quantizer._amax = original_amax
+            if original_amax is not None:
+                quantizer._amax = original_amax
```

This keeps `mse_calibrate` resilient if some quantizers never obtain an initial amax in the first pass.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- `CHANGELOG.rst` (1 hunks)
- `modelopt/torch/quantization/calib/__init__.py` (1 hunks)
- `modelopt/torch/quantization/calib/mse.py` (1 hunks)
- `modelopt/torch/quantization/config.py` (1 hunks)
- `modelopt/torch/quantization/mode.py` (3 hunks)
- `modelopt/torch/quantization/model_calib.py` (3 hunks)
- `modelopt/torch/quantization/nn/modules/tensor_quantizer.py` (4 hunks)
- `modelopt/torch/quantization/utils.py` (3 hunks)
- `tests/unit/torch/quantization/test_mse_calibrator.py` (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
Repo: NVIDIA/TensorRT-Model-Optimizer PR: 332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.
Applied to files:
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
🔇 Additional comments (20)
CHANGELOG.rst (1)
16-16: LGTM! The changelog entry clearly documents the new MSE calibrator feature.
modelopt/torch/quantization/utils.py (3)
45-45: LGTM! The addition of `reduce_sum` to `__all__` correctly exposes the new utility function.
191-218: LGTM! The `reduce_sum` function is well-implemented:
- Properly decorated with `@torch.no_grad()` for calibration use
- Consistent API with `reduce_amax`
- Handles all axis configurations correctly
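For reference, a minimal sketch of what a `reduce_sum` mirroring `reduce_amax`'s call shape could look like (the signature is inferred from the review's description, not copied from the PR):

```python
import torch


@torch.no_grad()
def reduce_sum(input, axis=None, keepdims=True):
    """Sum `input` over the given axes; reduce over all axes when `axis` is None."""
    if axis is None:
        out = torch.sum(input)
        return out.view([1] * input.dim()) if keepdims else out
    axis = (axis,) if isinstance(axis, int) else tuple(axis)
    return torch.sum(input, dim=axis, keepdim=keepdims)
```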
633-661: LGTM! The context managers `enable_quant` and `disable_calib` are well-implemented:
- Follow standard context manager patterns
- Properly restore original state in finally blocks
- Clear docstrings explaining their purpose
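As an illustration of that pattern, a hedged sketch of a state-restoring context manager (the `_if_quant` flag name is taken from the test snippet quoted above; this is not the actual `utils.py` implementation):

```python
from contextlib import contextmanager


@contextmanager
def enable_quant_sketch(quantizer):
    """Temporarily enable quantization on a TensorQuantizer-like object, restoring its state afterwards."""
    prev = quantizer._if_quant  # assumed flag name, as used in the tests above
    quantizer._if_quant = True
    try:
        yield quantizer
    finally:
        # Restore the original state even if the body raises.
        quantizer._if_quant = prev
```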
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
428-435: LGTM! The `is_static_block_quant` property is a good refactoring that:
- Centralizes the logic for detecting static block quantization
- Improves code readability and maintainability
- Has a clear implementation checking the necessary conditions
638-638: LGTM! The usage of the `is_static_block_quant` property correctly replaces inline checks and improves code clarity. The comment update on Line 638 also helps distinguish between dynamic and static block quantization paths.

Also applies to: 908-912, 944-945
modelopt/torch/quantization/calib/mse.py (4)
32-66: LGTM! The `__init__` method is well-structured:
- Parameters are clearly documented
- Uses lists to track per-step statistics, which is appropriate for typical num_steps values
- Properly inherits from the `_Calibrator` base class
68-108: LGTM! The `collect` method is well-implemented:
- Good defensive programming with the `quant_func` validation (Lines 76-78)
- Converts to float32 for accurate MSE computation (Line 80)
- Uses `+=` for loss accumulation (Line 107), which addresses the past review comment about memory efficiency
- Properly handles per-channel reduction via `convert_quantization_axis_to_reduce_axis`

Based on learnings
110-115: LGTM! The `reset` method properly clears all accumulated state, allowing the calibrator to be reused.
117-189: LGTM! The `compute_amax` method correctly handles both per-tensor and per-channel cases:
- Returns None if no data was collected (good defensive behavior)
- Per-tensor case (Lines 137-153): correctly averages losses and selects the best candidate
- Per-channel case (Lines 154-188): uses proper tensor indexing to select the per-channel optimal amax
- Verbose output provides useful debugging information
- Line 178 properly reshapes the result to match the initial amax shape
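To illustrate the per-channel selection step, a small standalone sketch (tensor layouts here are assumptions; the actual `compute_amax` differs in detail):

```python
import torch


def select_per_channel_amax(losses, candidates):
    """Pick, for each channel, the candidate amax with the lowest accumulated loss.

    Both `losses` and `candidates` are [num_steps, num_channels] tensors."""
    best_idx = torch.argmin(losses, dim=0)  # best candidate index per channel
    return candidates.gather(0, best_idx.unsqueeze(0)).squeeze(0)


# Example: 3 candidates, 4 channels.
losses = torch.tensor([[1.0, 0.2, 3.0, 0.5],
                       [0.5, 0.4, 1.0, 0.7],
                       [0.9, 0.1, 2.0, 0.6]])
candidates = torch.arange(1.0, 13.0).reshape(3, 4)
print(select_per_channel_amax(losses, candidates))  # tensor([ 5., 11.,  7.,  4.]) -> one amax per channel
```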
modelopt/torch/quantization/calib/__init__.py (1)
26-26: LGTM! The import correctly exposes the new `mse` module via wildcard import, making `MseCalibrator` available in the `calib` namespace.

modelopt/torch/quantization/config.py (1)
984-1021: LGTM! The `MseCalibConfig` class is well-designed:
- Clear docstring explaining the MSE calibration approach
- Appropriate field validators (num_steps ≥ 1, multipliers > 0)
- Consistent with other calibration configs in the module
- Good default values (num_steps=10, start_multiplier=0.25, stop_multiplier=4.0)
modelopt/torch/quantization/mode.py (2)
41-41: LGTM! The imports correctly add `MseCalibConfig` and `mse_calibrate` to support the new MSE calibration mode.

Also applies to: 58-58
367-377: LGTM! The `MseCalibrateModeDescriptor` follows the established pattern for calibration mode descriptors:
- Properly inherits from `BaseCalibrateModeDescriptor`
- Returns the correct config class
- References the `mse_calibrate` function
- Correctly registered with `@CalibrateModeRegistry.register_mode`

tests/unit/torch/quantization/test_mse_calibrator.py (4)
27-36: LGTM! The `_mse_at_a` helper function is a useful utility for computing MSE at a given amax value for validation purposes.
40-86: LGTM! Good test coverage for per-tensor MSE calibration with outliers:
- Tests that calibrator can find better amax than initial max
- Validates finite and bounded results
- Confirms loss improvement over baseline
88-133: LGTM! Excellent test coverage for various per-tensor scenarios:
- Negative outliers (Lines 88-133)
- Unsigned quantization (Lines 135-185)
- Multiple collections (Lines 187-231)
- Custom error function (Lines 233-280)
- Reset functionality (Lines 282-319)
All tests validate correct behavior with appropriate assertions.
Also applies to: 135-185, 187-231, 233-280, 282-319
321-423: LGTM! Comprehensive test coverage for per-channel MSE calibration:
- Basic per-channel test with axis=0 (Lines 321-370)
- Multiple collections (Lines 371-423)
- Independent per-channel optimization with different scales (Lines 424-478)
- Custom error function support (Lines 480-528)
All tests properly validate shape, finiteness, and positivity of per-channel results.
Also applies to: 424-478, 480-528
modelopt/torch/quantization/model_calib.py (2)
32-38: Imports for MSE calibrator and quantizer state helpers are consistent.

`MseCalibrator`, `disable_calib`, `enable_fake_quant`, and `enable_quant` are all used in `mse_calibrate`, and no unused symbols are introduced here.
283-303: `finish_stats_collection` correctly distinguishes MSE/entropy vs max calibrators.

The refactor to iterate only over `TensorQuantizer` modules and branch on `method in {"mse", "entropy"}` vs the default (`compute_amax()` with no method) is clean and keeps existing callers (which pass no `method`) compatible. Bias handling and the final `enable_quant()`/`disable_calib()` are preserved.

Just ensure that any calibrator used with `method="mse"` or `method="entropy"` implements `compute_amax(method: str)` as expected; with that contract satisfied, this helper looks good.
realAsma left a comment:
Looks great!!
Can we also have an end-to-end unit test of the flow in tests/unit/torch/quantization/test_quantize_cpu.py::test_quantize?
What does this PR do?
Type of change: new feature
Overview: Add per tensor/per channel MSE calibrator.
Usage
Can be enabled with the "algorithm" field in quantization configs.
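A hedged usage sketch follows; the `"algorithm"` payload mirrors `MseCalibConfig`'s fields as described in the review, but the method name and key spellings are assumptions, not confirmed by this PR page:

```python
import torch
import modelopt.torch.quantization as mtq

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))


def forward_loop(m):
    # Feed a handful of calibration batches through the model.
    for _ in range(8):
        m(torch.randn(32, 16))


# Start from a stock INT8 config and switch the calibration algorithm to MSE.
config = dict(mtq.INT8_DEFAULT_CFG)
config["algorithm"] = {
    "method": "mse",          # assumed method name for MseCalibConfig
    "num_steps": 20,
    "start_multiplier": 0.25,
    "stop_multiplier": 4.0,
}

model = mtq.quantize(model, config, forward_loop)
```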
Testing
- Unit test for the MseCalibrator
- E2E test with NVFP4 and INT8

**Results** (start_multiplier=0.25, stop_multiplier=4.0, num_steps=20):

Qwen3-8B MMLU:
- BF16 baseline: 72.94
Additional Information
TODO for the follow-up PR:
Summary by CodeRabbit
Release Notes
New Features
Tests
Documentation