
Commit efa63db

chore: address reviews
1 parent bd66d44 commit efa63db

File tree: 6 files changed (+63, −71 lines)


py/torch_tensorrt/dynamo/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
     save_cross_compiled_exported_program,
 )
 from ._exporter import export
-from ._quantization import quantize
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
 from ._SourceIR import SourceIR

py/torch_tensorrt/dynamo/_quantization.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

tools/llm/README.md

Lines changed: 15 additions & 5 deletions
@@ -39,6 +39,7 @@ We have officially verified support for the following models:
 #### Text-only LLMs: `run_llm.py`
 
 ```bash
+python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 128 --cache static_v2 --benchmark
 python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
 ```
 
@@ -54,8 +55,8 @@ python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --c
 - `--tokenizer`: (Optional) Tokenizer name; defaults to model.
 - `--prompt`: Input prompt for generation.
 - `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image.
-- `--precision`: Precision mode (`FP16`, `FP32`).
-- `--quant_format`: Quantization format (`fp8`, `nvfp4`) to apply.
+- `--model_precision`: Precision of model weight/buffer (`FP16`, `BF16`, `FP32`).
+- `--quant_format`: (Optional) Quantization format (`fp8`, `nvfp4`) to apply.
 - `--num_tokens`: Number of output tokens to generate.
 - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
 - `--benchmark`: Enable benchmarking mode.
@@ -68,17 +69,26 @@ Torch-TensorRT supports quantization to reduce model memory footprint and improv
 #### Using Pre-quantized Models
 
 To use pre-quantized models from HuggingFace:
+If a model contains quantization configuration (detected automatically), the model's linear layers are converted to TensorRT quantized versions using the specified quantization algorithm (e.g., FP8, NVFP4). The quantization algorithm type is displayed during conversion.
+
+**Note:** The `--quant_format` option will raise an error if it's used with pre-quantized models, as quantization cannot be applied to models that are already quantized.
 
 ```bash
-python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
+python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 128
+```
+
+**Expected output:**
+```
+Model is FP8 pre-quantized hf model. Quantized linear layers are applied
 ```
 
 #### Applying quantization by ModelOpt
 
-Apply fp8 quantization from HuggingFace:
+To apply quantization to non-quantized models using ModelOpt:
+The `--quant_format` option calls `mtq.quantize()` to apply ModelOpt post-training quantization to the model.
 
 ```bash
-python run_llm.py --model meta-llama/Llama-3.1-8B --quant_format fp8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
+python run_llm.py --model meta-llama/Llama-3.1-8B --quant_format fp8 --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 128
 ```
 
 #### Quantization Requirements
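For readers of the README change above, here is a minimal sketch of the ModelOpt post-training quantization path that `--quant_format` triggers, assuming `nvidia-modelopt` is installed. The helper name `apply_ptq` and the bare `calibrate_loop` argument are illustrative; the actual script builds its calibration loop from ModelOpt's dataset utilities, as shown in `quantize_utils.py` below.

```python
import modelopt.torch.quantization as mtq
import torch


def apply_ptq(model: torch.nn.Module, quant_format: str, calibrate_loop):
    """Sketch of the --quant_format path: pick a ModelOpt config and calibrate."""
    if quant_format == "fp8":
        quant_cfg = mtq.FP8_DEFAULT_CFG
    elif quant_format == "nvfp4":
        quant_cfg = mtq.NVFP4_DEFAULT_CFG
    else:
        raise RuntimeError("Unsupported quantization format")

    # calibrate_loop(model) runs calibration batches through the model;
    # mtq.quantize inserts quantizers and derives amax values from that loop.
    return mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
```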

tools/llm/quantize_utils.py

Lines changed: 30 additions & 20 deletions
@@ -4,7 +4,6 @@
 
 import huggingface_hub
 import torch
-import torch_tensorrt
 from huggingface_hub import snapshot_download
 
 logger = logging.getLogger(__name__)
@@ -25,6 +24,11 @@
 )
 from safetensors import safe_open
 
+# FP8 E4M3 format has a maximum representable value of 448.0
+MAX_BOUND_FP8 = 448.0
+# Additional scaling factor for NVFP4
+MAX_BOUND_NVFP4 = 6.0
+
 
 def quantize_model(model, args, tokenizer):
     """
@@ -52,11 +56,17 @@ def quantize_model(model, args, tokenizer):
         num_samples=512,
         device="cuda:0",
     )
-
+    if args.quant_format == "fp8":
+        quant_cfg = mtq.FP8_DEFAULT_CFG
+    elif args.quant_format == "nvfp4":
+        quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    else:
+        raise RuntimeError("Unsupported quantization format")
     calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
-    model = torch_tensorrt.dynamo.quantize(
-        model, args.quant_format, calibrate_loop, debug=args.debug
-    )
+
+    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    if args.debug:
+        mtq.print_quant_summary(model)
 
     return model
 
@@ -83,12 +93,6 @@ def __init__(
         # Store reference to original linear layer for weight access
         self.original_linear = original_linear
 
-        # Copy bias from original layer if it exists
-        if original_linear.bias is not None:
-            self.bias = torch.nn.Parameter(original_linear.bias.clone()).cuda()
-        else:
-            self.bias = None
-
         # Create quantizers for input and weight tensors
         self.input_quantizer = TensorQuantizer(
             quant_attribute_cfg=quant_cfg, amax=input_amax
@@ -100,7 +104,7 @@ def __init__(
     def forward(self, input):
         input = self.input_quantizer(input)
         weight = self.weight_quantizer(self.original_linear.weight)
-        return torch.nn.functional.linear(input, weight, self.bias)
+        return torch.nn.functional.linear(input, weight, self.original_linear.bias)
 
 
 def load_quantization_config(model_name):
@@ -134,7 +138,7 @@ def load_quantization_config(model_name):
     return hf_quant_config
 
 
-def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
+def convert_linear_to_tensorrt_quantized(model, model_precision, hf_quant_config):
     """
     Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.
 
@@ -177,6 +181,13 @@ def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
     if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
        raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
 
+    if model_precision == "FP16":
+        weight_dtype = torch.float16
+    elif model_precision == "BF16":
+        weight_dtype = torch.bfloat16
+    else:
+        weight_dtype = torch.float32
+
     # Iterate through all modules in the model
     for name, module in model.named_modules():
         # Check if the module is a linear layer
@@ -195,14 +206,13 @@ def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
                 continue
 
             if hf_quant_algo == "FP8":
-                # FP8 E4M3 format has a maximum representable value of 448.0
                 # Scale the quantization parameters accordingly
                 weight_scale = tensors.pop(weight_scale_name)
-                weight_amax = weight_scale * 448.0
-                input_amax = tensors.pop(input_scale_name) * 448.0
+                weight_amax = weight_scale * MAX_BOUND_FP8
+                input_amax = tensors.pop(input_scale_name) * MAX_BOUND_FP8
 
                 # Dequantize the weight using the scale factor
-                dequantized_weight_data = module.weight.to(torch.float32) * weight_scale
+                dequantized_weight_data = module.weight.to(weight_dtype) * weight_scale
 
                 # Configure quantizer for FP8 format (4 exponent bits, 3 mantissa bits)
                 quantizer_attribute_config = QuantizerAttributeConfig(
@@ -218,15 +228,15 @@ def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
                 weight_scale2 = tensors.pop(weight_scale2_name)
 
                 # Calculate amax values with additional scaling factor for NVFP4
-                input_amax = input_scale * 448.0 * 6.0
-                weight_amax = weight_scale2 * 448.0 * 6.0
+                input_amax = input_scale * MAX_BOUND_FP8 * MAX_BOUND_NVFP4
+                weight_amax = weight_scale2 * MAX_BOUND_FP8 * MAX_BOUND_NVFP4
 
                 # Handle NVFP4 tensor format
                 weight_data = tensors.pop(weight_name)
                 original_shape = list(weight_data.shape)
                 original_shape[-1] *= 2  # NVFP4 packs 2 values per element
                 nvfp4_tensor = NVFP4QTensor(
-                    torch.Size(original_shape), torch.float32, weight_data
+                    torch.Size(original_shape), weight_dtype, weight_data
                 )
 
                 # Dequantize using both scales and block size configuration
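A note on the new `MAX_BOUND_FP8` constant used above: per-tensor FP8 checkpoints store a scale, and the amax expected by `TensorQuantizer` is recovered as `scale * 448.0`, 448 being the largest finite FP8 E4M3 value. Below is a self-contained sketch of that relationship in plain PyTorch; the tensor names and the round-trip demo are illustrative, not part of the commit.

```python
import torch

FP8_E4M3_MAX = 448.0  # same value as MAX_BOUND_FP8 in quantize_utils.py


def fp8_scale_to_amax(scale: torch.Tensor) -> torch.Tensor:
    # Checkpoints store scale = amax / 448.0, so the inverse recovers amax.
    return scale * FP8_E4M3_MAX


def dequantize_fp8_weight(
    weight_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype = torch.float16
) -> torch.Tensor:
    # Pre-quantized weights are stored as float8_e4m3fn; upcast, then rescale.
    return weight_fp8.to(dtype) * scale


if __name__ == "__main__":
    w = torch.randn(128, 256)
    scale = w.abs().max() / FP8_E4M3_MAX           # per-tensor scale
    w_fp8 = (w / scale).to(torch.float8_e4m3fn)    # quantize to FP8
    w_deq = dequantize_fp8_weight(w_fp8, scale)    # dequantize back
    print(fp8_scale_to_amax(scale), (w - w_deq.float()).abs().max())
```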

tools/llm/run_llm.py

Lines changed: 17 additions & 9 deletions
@@ -71,12 +71,20 @@ def get_model(args):
 
     hf_quant_config = load_quantization_config(args.model)
     if hf_quant_config:
-        model = convert_linear_to_tensorrt_quantized(model, hf_quant_config).cuda()
-        print(f"Model converted to TensorRT quantized")
+        model = convert_linear_to_tensorrt_quantized(
+            model, args.model_precision, hf_quant_config
+        ).cuda()
+        print(
+            f"Model is {hf_quant_config['quant_algo']} pre-quantized hf model. Quantized linear layers are applied"
+        )
+        if args.quant_format:
+            raise RuntimeError(
+                f"Quantization cannot be applied for pre-quantized hf model"
+            )
 
-    if args.precision == "FP16":
+    if args.model_precision == "FP16":
         model = model.to(torch.float16)
-    elif args.precision == "BF16":
+    elif args.model_precision == "BF16":
         model = model.to(torch.bfloat16)
     else:
         model = model.to(torch.float32)
@@ -112,11 +120,11 @@ def compile_torchtrt(model, input_ids, args):
     # Set precision specific flags
     use_fp32_acc = False
     use_explicit_typing = False
-    if args.precision == "FP16":
+    if args.model_precision == "FP16":
         enabled_precisions = {torch.float32}
         use_fp32_acc = True
         use_explicit_typing = True
-    elif args.precision == "BF16":
+    elif args.model_precision == "BF16":
         enabled_precisions = {torch.bfloat16}
         use_fp32_acc = False
     else:
@@ -204,7 +212,7 @@ def measure_perf(trt_model, input_signature, backend_name):
         "--prompt", type=str, default="What is parallel programming ?", help="Prompt"
     )
     arg_parser.add_argument(
-        "--precision",
+        "--model_precision",
         type=str,
         default="FP16",
         help="Precision to use in the model. Options: FP16, BF16, FP32",
@@ -299,7 +307,7 @@ def measure_perf(trt_model, input_signature, backend_name):
     pyt_stats = record_stats(
         "PyTorch",
         pyt_timings,
-        args.precision,
+        args.model_precision,
         batch_size=args.batch_size,
         compile_time_s=None,
     )
@@ -357,7 +365,7 @@ def measure_perf(trt_model, input_signature, backend_name):
     trt_stats = record_stats(
         "TensorRT",
         trt_timings,
-        args.precision,
+        args.model_precision,
         batch_size=args.batch_size,
         compile_time_s=None,
     )
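As a summary of the precision handling changed above: `--model_precision` now selects both the eager-model dtype and the Torch-TensorRT compile flags. A condensed sketch of the mapping in `compile_torchtrt()` follows; the helper name is ours, and the FP32 branch body is an assumption since the hunk does not show it.

```python
import torch


def precision_to_flags(model_precision: str):
    """Mirror of the flag selection in compile_torchtrt() for the three precisions."""
    use_fp32_acc = False
    use_explicit_typing = False
    if model_precision == "FP16":
        # FP16 weights with FP32 accumulation and explicit typing enabled.
        enabled_precisions = {torch.float32}
        use_fp32_acc = True
        use_explicit_typing = True
    elif model_precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        # FP32 branch is not shown in the hunk; assuming plain FP32 here.
        enabled_precisions = {torch.float32}
    return enabled_precisions, use_fp32_acc, use_explicit_typing


print(precision_to_flags("BF16"))
```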

tools/llm/utils.py

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None)
 
     stats = {
         "Backend": backend,
-        "Precision": precision,
+        "Model Precision": precision,
         "Batch size": batch_size,
         "Median(FPS)": speed_med,
         "Mean(FPS)": speed_mean,
