Commit d25b50c

chore: Detect pre-quantized hf model

1 parent 25c17ef commit d25b50c

File tree

- tools/llm/README.md
- tools/llm/quantize_utils.py
- tools/llm/run_llm.py

3 files changed: +71 −69 lines changed


tools/llm/README.md

Lines changed: 3 additions & 4 deletions
@@ -55,8 +55,7 @@ python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --c
 - `--prompt`: Input prompt for generation.
 - `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image.
 - `--precision`: Precision mode (`FP16`, `FP32`).
-- `--qformat`: Quantization format (`fp8`, `nvfp4`) to apply.
-- `--pre_quantized`: Flag to use pre-quantized models from HuggingFace.
+- `--quant_format`: Quantization format (`fp8`, `nvfp4`) to apply.
 - `--num_tokens`: Number of output tokens to generate.
 - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
 - `--benchmark`: Enable benchmarking mode.

@@ -71,15 +70,15 @@ Torch-TensorRT supports quantization to reduce model memory footprint and improv
 To use pre-quantized models from HuggingFace:
 
 ```bash
-python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --pre_quantized --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
+python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
 ```
 
 #### Applying quantization by ModelOpt
 
 Apply fp8 quantization from HuggingFace:
 
 ```bash
-python run_llm.py --model meta-llama/Llama-3.1-8B --qformat fp8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
+python run_llm.py --model meta-llama/Llama-3.1-8B --quant_format fp8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
 ```
 
 #### Quantization Requirements

tools/llm/quantize_utils.py

Lines changed: 54 additions & 44 deletions
@@ -27,22 +27,22 @@
 
 def quantize_model(model, args, tokenizer):
     """
-    Quantize a PyTorch model using ModelOpt quantization.
+    Quantize a PyTorch model using ModelOpt post-training quantization (PTQ).
 
-    This function performs post-training quantization (PTQ) on the model using
-    calibration data from the provided tokenizer. It supports both FP8 and NVFP4
-    quantization formats.
+    This function applies quantization to reduce model precision for faster inference
+    while maintaining acceptable accuracy. It uses calibration data generated from
+    the provided tokenizer to determine optimal quantization parameters.
 
+    Supported quantization formats:
+    - fp8: 8-bit floating point quantization
+    - nvfp4: 4-bit NVIDIA floating point quantization
     Args:
-        model: PyTorch model to quantize
-        args: Arguments containing quantization format and debug settings
-        tokenizer: Tokenizer for creating calibration dataloader
+        model: PyTorch model to quantize. Must be in evaluation mode.
+        args: Command line arguments containing quant_format and debug
+        tokenizer: Hugging Face tokenizer for creating calibration data
 
     Returns:
-        Quantized model with reduced precision weights and activations
-
-    Raises:
-        RuntimeError: If unsupported quantization format is specified
+        Quantized model
     """
     # Create calibration dataloader for quantization
     calib_dataloader = get_dataset_dataloader(

@@ -51,9 +51,9 @@ def quantize_model(model, args, tokenizer):
         num_samples=512,
         device="cuda:0",
     )
-    if args.qformat == "fp8":
+    if args.quant_format == "fp8":
         quant_cfg = mtq.FP8_DEFAULT_CFG
-    elif args.qformat == "nvfp4":
+    elif args.quant_format == "nvfp4":
         quant_cfg = mtq.NVFP4_DEFAULT_CFG
     else:
         raise RuntimeError("Unsupported quantization format")

@@ -108,7 +108,38 @@ def forward(self, input):
         return torch.nn.functional.linear(input, weight, self.bias)
 
 
-def convert_linear_to_tensorrt_quantized(model, model_name):
+def load_quantization_config(model_name):
+    """
+    Load quantization configuration from a Hugging Face model.
+    Args:
+        model_name (str): Local directory path or model identifier
+    Returns:
+        dict or None: Quantization configuration. None if no config found.
+    """
+    # Determine if model_name is a local directory or needs to be downloaded
+    if os.path.isdir(model_name):
+        model_path = model_name
+    else:
+        # Download model from Hugging Face Hub
+        model_path = snapshot_download(
+            model_name,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_patterns=["original/**/*"],
+            revision=None,
+        )
+    hf_quant_config = None
+    # Load and parse quantization configuration
+    hf_quant_config_path = f"{model_path}/hf_quant_config.json"
+    if os.path.exists(hf_quant_config_path):
+        with open(hf_quant_config_path, "r") as f:
+            hf_quant_config = json.load(f)
+            hf_quant_config = hf_quant_config["quantization"]
+            hf_quant_config["model_path"] = model_path
+
+    return hf_quant_config
+
+
+def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
     """
     Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.
 

@@ -119,58 +150,37 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
 
     The function:
     1. Loads quantization scales from Hugging Face model files (SafeTensors)
-    2. Parses quantization configuration from hf_quant_config.json
-    3. Replaces standard linear layers with TensorRTQuantizedLinear layers
-    4. Applies appropriate quantization based on the model's quantization format
+    2. Replaces standard linear layers with TensorRTQuantizedLinear layers
+    3. Applies appropriate quantization based on the model's quantization format
 
     Note: This function only quantizes linear operations and is intended for use
     with pre-quantized Hugging Face models that have been quantized using ModelOpt.
 
     Args:
         model: PyTorch model to quantize
-        model_name: Path to Hugging Face model directory or model identifier
+        hf_quant_config: Quantization configuration
 
     Returns:
         Model with quantized linear layers
 
     Raises:
         RuntimeError: If quantization config is not found or unsupported format
     """
-    # Determine if model_name is a local directory or needs to be downloaded
-    if os.path.isdir(model_name):
-        hf_folder = model_name
-    else:
-        # Download model from Hugging Face Hub
-        hf_folder = snapshot_download(
-            model_name,
-            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-            ignore_patterns=["original/**/*"],
-            revision=None,
-        )
-
+    model_path = hf_quant_config["model_path"]
     # Load all tensors from SafeTensors files
     tensors = {}
-    for file in os.listdir(hf_folder):
+    for file in os.listdir(model_path):
         if file.endswith(".safetensors"):
             with safe_open(
-                os.path.join(hf_folder, file), framework="pt", device="cpu"
+                os.path.join(model_path, file), framework="pt", device="cpu"
            ) as f:
                 tensor_names = f.keys()
                 for name in tensor_names:
                     tensors[name] = f.get_tensor(name)
 
-    # Load and parse quantization configuration
-    hf_quant_config_path = f"{hf_folder}/hf_quant_config.json"
-    if os.path.exists(hf_quant_config_path):
-        with open(hf_quant_config_path, "r") as f:
-            hf_quant_config = json.load(f)
-            hf_quant_config = hf_quant_config["quantization"]
-
-        hf_quant_algo = hf_quant_config.pop("quant_algo", None)
-        if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
-            raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
-    else:
-        raise RuntimeError("No quantization config found")
+    hf_quant_algo = hf_quant_config.get("quant_algo", None)
+    if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
+        raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
 
     # Iterate through all modules in the model
     for name, module in model.named_modules():
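For context, here is a minimal sketch of what the new `load_quantization_config` helper returns when it finds an `hf_quant_config.json` next to a checkpoint. The exact keys written by ModelOpt exporters can vary; only the `"quantization"` section with a `"quant_algo"` field is assumed here, and the file contents below are illustrative rather than taken from a real checkpoint.

```python
# Illustrative sketch only: a made-up hf_quant_config.json payload and the dict
# that load_quantization_config() would derive from it. Keys beyond
# "quantization"/"quant_algo" are assumptions about the exporter format.
import json

sample_hf_quant_config_json = """
{
    "quantization": {
        "quant_algo": "FP8"
    }
}
"""

hf_quant_config = json.loads(sample_hf_quant_config_json)["quantization"]
hf_quant_config["model_path"] = "/path/to/downloaded/checkpoint"  # placeholder path

# The run script only needs a truthy config with a supported quant_algo:
if hf_quant_config and hf_quant_config.get("quant_algo") in ("FP8", "NVFP4"):
    print("Pre-quantized checkpoint detected:", hf_quant_config["quant_algo"])
```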

tools/llm/run_llm.py

Lines changed: 14 additions & 21 deletions
@@ -19,6 +19,12 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 import torch
 import torch_tensorrt
+from modelopt.torch.quantization.utils import export_torch_mode
+from quantize_utils import (
+    convert_linear_to_tensorrt_quantized,
+    load_quantization_config,
+    quantize_model,
+)
 from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import (

@@ -62,8 +68,11 @@ def get_model(args):
     )
     # register SDPA variant for the model
     register_sdpa.enable_sdpa_converter(args.model, model.config)
-    if args.pre_quantized:
-        model = convert_linear_to_tensorrt_quantized(model, args.model).cuda()
+
+    hf_quant_config = load_quantization_config(args.model)
+    if hf_quant_config:
+        model = convert_linear_to_tensorrt_quantized(model, hf_quant_config).cuda()
+        print(f"Model converted to TensorRT quantized")
 
     if args.precision == "FP16":
         model = model.to(torch.float16)

@@ -97,7 +106,7 @@ def compile_torchtrt(model, input_ids, args):
         for optimized inference
     """
     max_seq_len = input_ids.shape[1] + args.num_tokens
-    with export_torch_mode() if args.qformat or args.pre_quantized else nullcontext():
+    with export_torch_mode():
         ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
         position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
         # Set precision specific flags

@@ -242,28 +251,12 @@ def measure_perf(trt_model, input_signature, backend_name):
         "--benchmark", action="store_true", help="Enable benchmark (default: False)"
     )
     arg_parser.add_argument(
-        "--qformat",
+        "--quant_format",
         help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"),
         default=None,
     )
-    arg_parser.add_argument(
-        "--pre_quantized",
-        action="store_true",
-        help="Use pre-quantized hf model weights (default: False)",
-    )
     args = arg_parser.parse_args()
 
-    if args.qformat and args.pre_quantized:
-        print("Error: --qformat and --pre_quantized cannot be used together")
-        exit()
-
-    if args.qformat or args.pre_quantized:
-        from modelopt.torch.quantization.utils import export_torch_mode
-        from quantize_utils import (
-            convert_linear_to_tensorrt_quantized,
-            quantize_model,
-        )
-
     with torch.inference_mode():
         model = get_model(args)
 

@@ -288,7 +281,7 @@ def measure_perf(trt_model, input_signature, backend_name):
         pyt_timings = None
         pyt_stats = None
 
-        if args.qformat != None:
+        if args.quant_format != None:
             model = quantize_model(model, args, tokenizer)
         if args.enable_pytorch_run:
             pyt_gen_tokens = generate(
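Taken together, the run_llm.py changes replace the explicit `--pre_quantized` flag with automatic detection from the checkpoint itself. Below is a rough end-to-end sketch of the new control flow; the model id is borrowed from the README example, error handling is omitted, and this is not a drop-in copy of `get_model`.

```python
# Sketch of the auto-detection path introduced by this commit (assumptions:
# quantize_utils is importable and the checkpoint ships hf_quant_config.json).
import torch
from transformers import AutoModelForCausalLM
from quantize_utils import convert_linear_to_tensorrt_quantized, load_quantization_config

model_name = "nvidia/Llama-3.1-8B-Instruct-FP8"  # pre-quantized example from the README
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

hf_quant_config = load_quantization_config(model_name)
if hf_quant_config:
    # hf_quant_config.json found: rewrite linear layers to TensorRT-quantized ones.
    model = convert_linear_to_tensorrt_quantized(model, hf_quant_config).cuda()
else:
    # Plain checkpoint: optionally apply PTQ later via --quant_format / quantize_model().
    model = model.cuda()

model = model.to(torch.float16)  # the --precision FP16 path
```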
