
Commit bd66d44

feat: Expose quantization API in torch_tensorrt.dynamo
1 parent d25b50c

3 files changed: +41 −10 lines


py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
     save_cross_compiled_exported_program,
 )
 from ._exporter import export
+from ._quantization import quantize
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
 from ._SourceIR import SourceIR
py/torch_tensorrt/dynamo/_quantization.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import logging
+from typing import Any, Callable
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def quantize(
+    model: torch.nn.Module,
+    quant_format: str,
+    calibrate_loop: Callable[[], Any],
+    debug: bool = False,
+) -> torch.nn.Module:
+    try:
+        import modelopt.torch.quantization as mtq
+
+        assert torch.ops.tensorrt.quantize_op.default
+    except Exception:
+        logger.warning(
+            "Unable to import quantization op. Please install modelopt library"
+        )
+
+    if quant_format == "fp8":
+        quant_cfg = mtq.FP8_DEFAULT_CFG
+    elif quant_format == "nvfp4":
+        quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    else:
+        raise RuntimeError("Unsupported quantization format")
+
+    quantized_model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    if debug:
+        mtq.print_quant_summary(quantized_model)
+
+    return quantized_model
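
For reference, a minimal usage sketch of the newly exposed API. The toy model, the calibration batches, and the calibrate_loop wiring below are illustrative assumptions, not part of the commit; running it requires the modelopt library and a CUDA device.

import torch
import torch_tensorrt

# Hypothetical toy model and calibration batches, for illustration only.
model = torch.nn.Linear(64, 64).cuda().eval()
calib_inputs = [torch.randn(8, 64, device="cuda") for _ in range(4)]

def calibrate_loop() -> None:
    # Feed representative inputs through the model so modelopt can
    # collect the activation statistics it needs for calibration.
    with torch.no_grad():
        for x in calib_inputs:
            model(x)

# Quantize to FP8 ("nvfp4" is the other accepted format); debug=True
# prints modelopt's per-layer quantization summary.
quantized = torch_tensorrt.dynamo.quantize(model, "fp8", calibrate_loop, debug=True)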

tools/llm/quantize_utils.py

Lines changed: 5 additions & 10 deletions

@@ -4,6 +4,7 @@
 
 import huggingface_hub
 import torch
+import torch_tensorrt
 from huggingface_hub import snapshot_download
 
 logger = logging.getLogger(__name__)
@@ -51,17 +52,11 @@ def quantize_model(model, args, tokenizer):
         num_samples=512,
         device="cuda:0",
     )
-    if args.quant_format == "fp8":
-        quant_cfg = mtq.FP8_DEFAULT_CFG
-    elif args.quant_format == "nvfp4":
-        quant_cfg = mtq.NVFP4_DEFAULT_CFG
-    else:
-        raise RuntimeError("Unsupported quantization format")
-    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
 
-    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
-    if args.debug:
-        mtq.print_quant_summary(model)
+    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+    model = torch_tensorrt.dynamo.quantize(
+        model, args.quant_format, calibrate_loop, debug=args.debug
+    )
 
     return model
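
The call site now forwards args.quant_format and args.debug straight through to the shared helper. A minimal sketch of the argparse flags this implies follows; the flag names are assumed from the call site above, and the real tool may define them elsewhere.

import argparse

parser = argparse.ArgumentParser()
# Flags inferred from the call site; the choices mirror the formats
# accepted by torch_tensorrt.dynamo.quantize.
parser.add_argument("--quant_format", choices=["fp8", "nvfp4"], default="fp8")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()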
