diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ce3fb0853..082b9aab1 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -20,12 +20,21 @@ import sys import warnings from pathlib import Path +from typing import Any, cast import torch import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import get_max_memory -from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer +from omegaconf import OmegaConf +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizerBase, + ProcessorMixin, +) try: from huggingface_hub import snapshot_download @@ -33,7 +42,8 @@ snapshot_download = None import modelopt.torch.quantization as mtq -from modelopt.torch.utils.image_processor import MllamaImageProcessor +from modelopt.torch.quantization.config import load_quant_config +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] @@ -127,52 +137,49 @@ def build_quant_cfg( qformat, kv_cache_qformat, awq_block_size, - auto_quantize, model_type, quant_cfg_choices, kv_quant_cfg_choices, -): - quant_cfg = {} - if not auto_quantize: - assert qformat in quant_cfg_choices, ( - f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" - ) - +) -> dict[str, Any]: + if qformat in quant_cfg_choices: quant_cfg = quant_cfg_choices[qformat] + else: + quant_cfg = OmegaConf.to_container(load_quant_config(qformat)) + + quant_cfg = cast("dict[str, Any]", quant_cfg) + if "awq" in quant_cfg.get("algorithm"): + quant_cfg = copy.deepcopy(quant_cfg) + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + # If awq_block_size argument is provided, update weight_quantizer + if awq_block_size: + weight_quantizer["block_sizes"][-1] = awq_block_size + + # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models + if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: + quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} + + enable_quant_kv_cache = kv_cache_qformat != "none" + print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") + + # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. + if enable_quant_kv_cache: + quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( + quant_cfg, + getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], + ) - if "awq" in qformat: - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - # If awq_block_size argument is provided, update weight_quantizer - if awq_block_size: - weight_quantizer["block_sizes"][-1] = awq_block_size - - # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models - if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: - quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} - - enable_quant_kv_cache = kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - - # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. 
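# A condensed sketch of the lookup order the reworked build_quant_cfg follows: a known
# short name resolves through quant_cfg_choices, anything else is treated as a built-in
# config name (a file under modelopt/config/quantization/) or a path to a custom YAML
# and loaded through OmegaConf. load_quant_config is the helper added to
# modelopt/torch/quantization/config.py later in this change; the function below is
# only an illustration of that flow, not part of the example itself.
from typing import Any, cast

from omegaconf import OmegaConf

from modelopt.torch.quantization.config import load_quant_config


def resolve_quant_cfg(qformat: str, quant_cfg_choices: dict[str, Any]) -> dict[str, Any]:
    if qformat in quant_cfg_choices:
        return quant_cfg_choices[qformat]  # e.g. "fp8" -> FP8_DEFAULT_CFG
    # e.g. "nvfp4_awq_lite" -> modelopt/config/quantization/nvfp4_awq_lite.yml,
    # or a user-supplied YAML path
    return cast("dict[str, Any]", OmegaConf.to_container(load_quant_config(qformat)))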
- if enable_quant_kv_cache: - quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( - quant_cfg, - getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], - ) - - # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. - if model_type == "gemma" and "int8_sq" in qformat: - quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} + # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. + if model_type == "gemma" and "int8_sq" in qformat: + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} - if model_type == "phi4mm": - # Only quantize the language model - quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} - quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + if model_type == "phi4mm": + # Only quantize the language model + quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} + quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} return quant_cfg @@ -184,7 +191,7 @@ def is_speculative(hf_config): ) -def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs): +def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase: print(f"Initializing tokenizer from {ckpt_path}") if "vila" in ckpt_path.lower(): @@ -205,8 +212,12 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs): def get_processor( - ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None -): + ckpt_path, + model_type, + device: torch.device = "auto", + trust_remote_code=False, + attn_implementation=None, +) -> BaseImageProcessor | ProcessorMixin | None: """ Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object. """ @@ -241,6 +252,8 @@ def get_processor( return MllamaImageProcessor(processor, device) + return None + def get_dtype(dtype): if dtype == "bf16": diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 57f0b5a89..97d833605 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -32,12 +32,15 @@ is_nemotron_vl, run_nemotron_vl_preview, ) +from torch.utils.data import DataLoader from transformers import ( AutoConfig, AutoModelForCausalLM, AutoProcessor, PreTrainedTokenizer, + PreTrainedTokenizerBase, PreTrainedTokenizerFast, + ProcessorMixin, WhisperProcessor, ) @@ -59,7 +62,7 @@ get_max_batch_size, get_supported_datasets, ) -from modelopt.torch.utils.image_processor import MllamaImageProcessor +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -94,22 +97,82 @@ mto.enable_huggingface_checkpointing() +def make_calib_dataloader( + args: argparse.Namespace, + language_model: torch.nn.Module, + processor: BaseImageProcessor | ProcessorMixin | None, + tokenizer: PreTrainedTokenizerBase | None, + device: torch.device, + model_type: str | None, +) -> tuple[DataLoader, str | None]: + calib_dataloader = None + first_text_speech_dataset = None + if model_type == "mllama": + assert processor is not None and isinstance(processor, MllamaImageProcessor), ( + "The MllamaImageProcessor must be set." 
+ ) + assert len(args.calib_size) == 1, ( + "mllama only supports one dataset for calibration, can extend this in the future" + ) + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=args.dataset[0] if args.dataset else "scienceqa", + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) + elif model_type == "whisper": + assert processor is not None and isinstance(processor, WhisperProcessor), ( + "The AutoProcessor must be set." + ) + assert len(args.calib_size) == 1, ( + "whisper only supports one dataset for calibration, can extend this in the future" + ) + calib_dataloader, first_text_speech_dataset = get_speech_dataset_dataloader( + dataset_name=args.dataset[0] if args.dataset else "peoples_speech", + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + device=device, + dtype=language_model.dtype, + ) + else: + assert tokenizer is not None and isinstance( + tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) + ), "The PreTrainedTokenizer must be set" + # Labels are only needed for gradient-based auto_quantize + include_labels = ( + args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" + ) + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=include_labels, + ) + return calib_dataloader, first_text_speech_dataset + + def auto_quantize( - model, - qformat, - calib_dataloader, - calibrate_loop, - auto_quantize_bits, - batch_size=1, + args: argparse.Namespace, + language_model: torch.nn.Module, + calib_dataloader: DataLoader, auto_quantize_method="gradient", auto_quantize_score_size=128, auto_quantize_checkpoint=None, ): - qformat_list = qformat.split(",") + """Auto search quantization of multiple formats.""" + + assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( + "Auto Quantization is not supported for pipeline parallel size > 1" + ) + + qformat_list = args.qformat.split(",") assert qformat_list, "No quantization formats provided" # Check if all provided quantization formats are supported assert all( - qformat + args.qformat in [ "fp8", "int8_sq", @@ -122,7 +185,7 @@ def auto_quantize( "w4a8_mxfp4_fp8", "nvfp4_mlp_only", ] - for qformat in qformat_list + for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" def loss_func(output, data): @@ -143,9 +206,9 @@ def forward_step(model, batch): f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" ) - model, _ = mtq.auto_quantize( - model, - constraints={"effective_bits": auto_quantize_bits}, + language_model, _ = mtq.auto_quantize( + language_model, + constraints={"effective_bits": args.auto_quantize_bits}, data_loader=calib_dataloader, forward_step=forward_step, loss_func=loss_func, # Only used for gradient-based method @@ -153,7 +216,9 @@ def forward_step(model, batch): quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], num_calib_steps=len(calib_dataloader), # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. - num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)), + num_score_steps=min( + len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1) + ), verbose=True, # Disable all default disabled layers such as lm_head, mlp.gate, router etc. 
disabled_layers=list(_default_disabled_quantizer_cfg.keys()), @@ -161,6 +226,7 @@ def forward_step(model, batch): checkpoint=auto_quantize_checkpoint, ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) # We need to explicitly calibrate for kv cache quantization enable_quant_kv_cache = args.kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") @@ -169,110 +235,22 @@ def forward_step(model, batch): kv_cache_quant_cfg.pop("default") # keep other quantizers from auto_quantize mtq.set_quantizer_by_cfg( - model, + language_model, quant_cfg=kv_cache_quant_cfg, ) # Lets calibrate only the quantizers for kv cache quantization this time. Let's disable all others. with mtq.set_quantizer_by_cfg_context( - model, {"*": {"enable": False}, **kv_cache_quant_cfg} + language_model, {"*": {"enable": False}, **kv_cache_quant_cfg} ): - mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop) - return model - - -def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_only=False): - # The calibration loop for the model can be setup using the modelopt API. - # - # Example usage: - # from modelopt.torch.utils.dataset_utils import create_forward_loop - # model = ... # Initialize the model - # tokenizer = ... # Initialize the tokenizer - # quant_cfg = ... # Setup quantization configuration - # forward_loop = create_forward_loop(model=model, dataset_name="cnn_dailymail", tokenizer=tokenizer) - # mtq.quantize(model, quant_cfg, forward_loop=forward_loop) - - # The calibrate_loop is a custom defined method to run the model with the input data. - # The basic version looks like: - # - # def calibrate_loop(model, dataloader): - # for data in dataloader: - # model(**data) - # - # We also provided a util method to generate the forward_loop with additional error handlings. - - use_calibration = args.auto_quantize_bits or need_calibration(quant_cfg) - - if not use_calibration: - warnings.warn("Dynamic quantization. Calibration skipped.") - calibrate_loop = create_forward_loop(dataloader=calib_dataloader) if use_calibration else None - - assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( - "Auto Quantization is not supported for pipeline parallel size > 1" - ) - - print("Starting quantization...") - start_time = time.time() - if args.auto_quantize_bits: - model = auto_quantize( - model, - args.qformat, - calib_dataloader, - calibrate_loop, - args.auto_quantize_bits, - args.batch_size, - args.auto_quantize_method, - args.auto_quantize_score_size, - args.auto_quantize_checkpoint, - ) - elif calibration_only: - model = mtq.calibrate(model, quant_cfg["algorithm"], forward_loop=calibrate_loop) - else: - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - end_time = time.time() - print(f"Quantization done. Total time used: {end_time - start_time}s") - return model - - -def main(args): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - # launch a memory monitor to read the currently used GPU memory. - launch_memory_monitor() + mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) + return language_model - # Force eager execution for all model types. 
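# A minimal sketch of the calibration pattern this refactor keeps relying on (the same
# calls appear in auto_quantize() above and in the removed quantize_model() docstring):
# wrap the calibration dataloader in a forward loop, then either quantize directly or
# calibrate only the KV-cache quantizers. model, quant_cfg and calib_dataloader are
# placeholders.
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import create_forward_loop


def quantize_with_calibration(model, quant_cfg, calib_dataloader):
    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
    return mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)


def calibrate_kv_cache_only(model, kv_cache_quant_cfg, calib_dataloader):
    # Mirror of the tail of auto_quantize(): enable only the KV-cache (bmm) quantizers
    # and collect their amax values with max calibration.
    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
    mtq.set_quantizer_by_cfg(model, quant_cfg=kv_cache_quant_cfg)
    with mtq.set_quantizer_by_cfg_context(
        model, {"*": {"enable": False}, **kv_cache_quant_cfg}
    ):
        mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop)
    return model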
- torch.compiler.set_stance("force_eager") - - # Check that only one quantization format is provided for non auto_quant case - if not args.auto_quantize_bits: - assert len(args.qformat.split(",")) == 1, ( - "Quantization supports only one quantization format." - ) - - if not args.auto_quantize_bits: - assert ( - args.qformat - in [ - "int8_wo", - "int4_awq", - "fp8", - "nvfp4", - "nvfp4_awq", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - ] - or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES - ), f"Quantization format {args.qformat} not supported for HF export path" +def load_model(args: argparse.Namespace): # If low memory mode is enabled, we compress the model while loading the HF checkpoint. calibration_only = False if not args.low_memory_mode: - model = get_model( + full_model = get_model( args.pyt_ckpt_path, args.device, gpu_mem_percentage=args.gpu_max_mem_percentage, @@ -287,7 +265,8 @@ def main(args): quant_cfg = QUANT_CFG_CHOICES[args.qformat] if args.kv_cache_qformat != "none": quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] + quant_cfg, + getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], ) # Do not use real quant GEMM so the calibration can be more accurate. @@ -297,25 +276,21 @@ def main(args): model_kwargs = {"trust_remote_code": args.trust_remote_code} if args.attn_implementation is not None: model_kwargs["attn_implementation"] = args.attn_implementation - model = AutoModelForCausalLM.from_pretrained( + full_model = AutoModelForCausalLM.from_pretrained( args.pyt_ckpt_path, **model_kwargs, ) calibration_only = True - model_is_already_quantized = is_quantized(model) - model_type = get_model_type(model) + model_type = get_model_type(full_model) - device = model.device - if hasattr(model, "model"): - device = model.model.device + device = full_model.device + if hasattr(full_model, "model"): + device = full_model.model.device processor = None tokenizer = None - - full_model = model - - # Detect if this is a Nemotron VL model using architecture-based detection - is_nemotron_vl_model = is_nemotron_vl(full_model) + language_model = full_model + default_padding_side = None if model_type == "mllama": processor = get_processor( @@ -327,7 +302,10 @@ def main(args): ) elif model_type == "whisper": processor = get_processor( - args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code + args.pyt_ckpt_path, + model_type, + device, + trust_remote_code=args.trust_remote_code, ) else: if args.dataset is None: @@ -364,258 +342,127 @@ def main(args): mtq.quantize(module, disabled_quant_cfg, forward_loop=None) memo.add(module) - model = language_model - model_type = get_model_type(model) + model_type = get_model_type(language_model) if model_type == "phi4mm": warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.") - if args.sparsity_fmt != "dense": - if args.batch_size == 0: - # Sparse algorithm takes more GPU memory so we reduce the batch_size by 4. 
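# A small sketch of how the low-memory path above composes a weight/activation format
# with a KV-cache format: KV_QUANT_CFG_CHOICES maps the CLI name to the attribute name
# of the corresponding KV config on mtq (presumably e.g. "FP8_KV_CFG"), whose
# "quant_cfg" section is merged into the base config. The wrapper name below is
# hypothetical; the two mtq calls are the ones used in this file.
import modelopt.torch.quantization as mtq


def base_cfg_with_kv_cache(quant_cfg, kv_cache_qformat, kv_quant_cfg_choices):
    if kv_cache_qformat == "none":
        return quant_cfg
    kv_cfg = getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"]
    return mtq.utils.update_quant_cfg_with_kv_cache_quant(quant_cfg, kv_cfg)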
- args.batch_size = max(get_max_batch_size(model) // 4, 1) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + return ( + full_model, + language_model, + model_type, + calibration_only, + processor, + tokenizer, + default_padding_side, + device, + ) - print(f"Use calib batch_size {args.batch_size}") - # Different calibration datasets are also available, e.g., "pile" and "wikipedia" - # Please also check the docstring for the datasets available - assert tokenizer is not None and isinstance( - tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) - ), "The PreTrainedTokenizer must be set" - calib_dataloader = get_dataset_dataloader( - dataset_name=args.dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_samples=args.calib_size, - max_sample_length=args.calib_seq, - device=device, - ) - model = mts.sparsify( - model, - args.sparsity_fmt, - config={"data_loader": calib_dataloader, "collect_func": lambda x: x}, - ) - mts.export(model) +def sparsity_main( + args: argparse.Namespace, + full_model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase | None, + device: torch.device, +): + if args.batch_size == 0: + # Sparse algorithm takes more GPU memory so we reduce the batch_size by 4. + args.batch_size = max(get_max_batch_size(full_model) // 4, 1) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + + # Different calibration datasets are also available, e.g., "pile" and "wikipedia" + # Please also check the docstring for the datasets available + assert tokenizer is not None and isinstance( + tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) + ), "The PreTrainedTokenizer must be set" + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + max_sample_length=args.calib_seq, + device=device, + ) + full_model = mts.sparsify( + full_model, + args.sparsity_fmt, + config={"data_loader": calib_dataloader, "collect_func": lambda x: x}, + ) + mts.export(full_model) + + +def plain_quantize( + args: argparse.Namespace, + quant_cfg: dict[str, Any], + full_model: torch.nn.Module, + language_model: torch.nn.Module, + model_type: str | None, + calibration_only: bool, + calib_dataloader: DataLoader, + is_nemotron_vl_model: bool, +): + """Plain quantization of the given language model to a single quantization configuration.""" - if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES: - if "awq" in args.qformat: - print( - "\n####\nAWQ calibration could take longer than other calibration methods. " - "Consider reducing calib_size to reduce calibration time.\n####\n" - ) + model_is_already_quantized = is_quantized(language_model) - if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. 
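# The sparsity path factored out into sparsity_main() above boils down to this pattern:
# hand the calibration dataloader to mts.sparsify via its config dict, then export the
# sparsified weights in place. The import path for mts is assumed to be
# modelopt.torch.sparsity, matching the alias already used in this file; model,
# dataloader and format are placeholders.
import modelopt.torch.sparsity as mts


def sparsify_and_export(model, calib_dataloader, sparsity_fmt):
    model = mts.sparsify(
        model,
        sparsity_fmt,  # e.g. the value of --sparsity_fmt when it is not "dense"
        config={"data_loader": calib_dataloader, "collect_func": lambda x: x},
    )
    mts.export(model)
    return model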
- sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=model.dtype).to( - model.device - ) - * 100 - ) - else: - sample_input_single_batch = None + if "awq" in args.qformat: + print( + "\n####\nAWQ calibration could take longer than other calibration methods. " + "Consider reducing calib_size to reduce calibration time.\n####\n" + ) - run_auto_quant = args.auto_quantize_bits is not None + # For Nemotron VL models, disable quantization of vision components + if is_nemotron_vl_model: + print("Disabling quantization for vision components in Nemotron VL model") + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + # Also disable radio model components specifically + quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} - args.batch_size = get_max_batch_size( - model, - max_sample_length=args.calib_seq, - sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, - sample_input_single_batch=sample_input_single_batch, - enable_grad=run_auto_quant, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + if not model_is_already_quantized or calibration_only: + # quantize the model - print(f"Use calib batch_size {args.batch_size}") + use_calibration = need_calibration(quant_cfg) - calib_dataloader = None - if model_type == "mllama": - assert processor is not None and isinstance(processor, MllamaImageProcessor), ( - "The MllamaImageProcessor must be set." - ) - assert len(args.calib_size) == 1, ( - "mllama only supports one dataset for calibration, can extend this in the future" - ) - calib_dataloader = get_vlm_dataset_dataloader( - dataset_name=args.dataset[0] if args.dataset else "scienceqa", - processor=processor, - batch_size=args.batch_size, - num_samples=args.calib_size[0], - ) - elif model_type == "whisper": - assert processor is not None and isinstance(processor, WhisperProcessor), ( - "The AutoProcessor must be set." - ) - assert len(args.calib_size) == 1, ( - "whisper only supports one dataset for calibration, can extend this in the future" - ) - calib_dataloader, first_text = get_speech_dataset_dataloader( - dataset_name=args.dataset[0] if args.dataset else "peoples_speech", - processor=processor, - batch_size=args.batch_size, - num_samples=args.calib_size[0], - device=device, - dtype=model.dtype, + if not use_calibration: + warnings.warn("Dynamic quantization. 
Calibration skipped.") + calibrate_loop = ( + create_forward_loop(dataloader=calib_dataloader) if use_calibration else None + ) + + if calibration_only: + language_model = mtq.calibrate( + language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop ) else: - assert tokenizer is not None and isinstance( - tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) - ), "The PreTrainedTokenizer must be set" - # Labels are only needed for gradient-based auto_quantize - include_labels = ( - args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" - ) - calib_dataloader = get_dataset_dataloader( - dataset_name=args.dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_samples=args.calib_size, - device=device, - include_labels=include_labels, - ) - - quant_cfg = build_quant_cfg( - args.qformat, - args.kv_cache_qformat, - args.awq_block_size, - args.auto_quantize_bits, - model_type, - QUANT_CFG_CHOICES, - KV_QUANT_CFG_CHOICES, - ) + language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) - # For Nemotron VL models, disable quantization of vision components + # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: - print("Disabling quantization for vision components in Nemotron VL model") - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - # Also disable radio model components specifically - quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} - - if not model_is_already_quantized or calibration_only: - # Only run single sample for preview - input_ids = next(iter(calib_dataloader))[ - "input_features" if model_type == "whisper" else "input_ids" - ][0:1] - - # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: - generated_ids_before_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "before quantization", - allow_fallback=True, - ) - else: - # Standard generation for non-Nemotron VL models - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": - print("Applying nvfp4 quantization (MoE only) for gpt-oss") - - # quantize the model - model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only) - - # For VL models, update full_model to use the quantized language model - if is_nemotron_vl_model: - language_model_lineage = get_language_model_from_vl(full_model) - if language_model_lineage is not None: - print("Updating full_model with quantized language_model...") - language_model_lineage[-2].language_model = model - - if args.verbose: - mtq.print_quant_summary(full_model) - - # Run some samples - torch.cuda.empty_cache() - generated_ids_after_ptq = None - if model_type != "llama4" and not is_nemotron_vl_model: - # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) - elif is_nemotron_vl_model and tokenizer is not None: - generated_ids_after_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "after quantization", - allow_fallback=False, - ) - else: - warnings.warn( - "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." 
- ) - - def input_decode(input_ids): - if processor is not None and isinstance(processor, MllamaImageProcessor): - return processor.tokenizer.batch_decode(input_ids) - elif processor is not None and isinstance(processor, WhisperProcessor): - return first_text - elif tokenizer is not None: - return tokenizer.batch_decode(input_ids) - else: - raise ValueError("The processor or tokenizer must be set") - - def output_decode(generated_ids, input_shape): - if is_enc_dec(model_type): - if processor is not None and isinstance(processor, WhisperProcessor): - return processor.tokenizer.batch_decode( - generated_ids, skip_special_tokens=True - )[0] - elif tokenizer is not None: - return tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - elif processor is not None and isinstance(processor, MllamaImageProcessor): - return processor.tokenizer.batch_decode(generated_ids[:, input_shape:]) - elif tokenizer is not None: - return tokenizer.batch_decode(generated_ids[:, input_shape:]) - else: - raise ValueError("The processor or tokenizer must be set") - - if generated_ids_after_ptq is not None: - print("--------") - if is_nemotron_vl_model: - # For Nemotron VL models, generated_ids are text strings from model.chat() - print("Nemotron VL model text-only generation results:") - print(f"Text response before quantization: {generated_ids_before_ptq}") - print("--------") - print(f"Text response after quantization: {generated_ids_after_ptq}") - print("--------") - print("Note: Additional VL tests with images were run separately above") - else: - # For regular LLMs, generated_ids are token tensors that need decoding - print(f"example test input: {input_decode(input_ids)}") - print("--------") - print( - f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}" - ) - print("--------") - print( - f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}" - ) - else: - warnings.warn("Skipping quantization: model is already quantized.") + language_model_lineage = get_language_model_from_vl(full_model) + if language_model_lineage is not None: + print("Updating full_model with quantized language_model...") + language_model_lineage[-2].language_model = language_model else: - assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" - print(f"qformat: {args.qformat}. No quantization applied, export {device} model") + warnings.warn("Skipping quantization: model is already quantized.") + +def export_quantized( + args: argparse.Namespace, + full_model: torch.nn.Module, + language_model: torch.nn.Module, + model_type: str | None, + tokenizer: PreTrainedTokenizerBase | None, + default_padding_side, +): with torch.inference_mode(): if model_type is None: - print(f"Unknown model type {type(model).__name__}. Continue exporting...") - model_type = f"unknown:{type(model).__name__}" + print(f"Unknown model type {type(language_model).__name__}. Continue exporting...") + model_type = f"unknown:{type(language_model).__name__}" export_path = args.export_path @@ -644,12 +491,11 @@ def output_decode(generated_ids, input_shape): print("This is normal for some VLM architectures that don't use AutoProcessor") if model_type == "mllama": - full_model_config = model.config - model = model.language_model + full_model_config = full_model.config # TRT-LLM expects both the vision_config and text_config to be set for export. 
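# Why output_decode() above slices generated_ids at the prompt length for decoder-only
# models: generate() returns the prompt tokens followed by the newly generated tokens,
# so decoding from input_shape onward isolates the completion. Encoder-decoder models
# such as Whisper return only decoder tokens, hence the separate branch. A minimal
# sketch; tokenizer, input_ids and generated_ids are placeholders.
def decode_completion(tokenizer, input_ids, generated_ids):
    prompt_len = input_ids.shape[1]
    return tokenizer.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)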
- setattr(model.config, "vision_config", full_model_config.vision_config) - setattr(model.config, "text_config", full_model_config.text_config) - setattr(model.config, "architectures", full_model_config.architectures) + setattr(full_model.config, "vision_config", full_model_config.vision_config) + setattr(full_model.config, "text_config", full_model_config.text_config) + setattr(full_model.config, "architectures", full_model_config.architectures) start_time = time.time() if ( @@ -662,10 +508,10 @@ def output_decode(generated_ids, input_shape): ) # Move meta tensor back to device before exporting. - remove_hook_from_module(model, recurse=True) + remove_hook_from_module(language_model, recurse=True) export_tensorrt_llm_checkpoint( - model, + language_model, model_type, export_dir=export_path, inference_tensor_parallel=args.inference_tensor_parallel, @@ -701,11 +547,228 @@ def output_decode(generated_ids, input_shape): end_time = time.time() print( - f"Quantized model exported to :{export_path}. Total time used {end_time - start_time}s" + f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" ) -if __name__ == "__main__": +def pre_quantize( + args: argparse.Namespace, + full_model: torch.nn.Module, + model_type: str | None, + tokenizer: PreTrainedTokenizerBase | None, + calib_dataloader: DataLoader, + is_nemotron_vl_model: bool, +): + # Only run single sample for preview + preview_input_ids = next(iter(calib_dataloader))[ + "input_features" if model_type == "whisper" else "input_ids" + ][0:1] + + # Generate preview before quantization + if is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + preview_input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, + ) + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + + return preview_input_ids, generated_ids_before_ptq + + +def post_quantize( + args: argparse.Namespace, + full_model: torch.nn.Module, + model_type: str | None, + tokenizer: PreTrainedTokenizerBase | None, + processor: BaseImageProcessor | ProcessorMixin | None, + preview_input_ids, + generated_ids_before_ptq, + is_nemotron_vl_model, + first_text_speech_dataset, +): + if args.verbose: + mtq.print_quant_summary(full_model) + + # Run some samples + torch.cuda.empty_cache() + generated_ids_after_ptq = None + if model_type != "llama4" and not is_nemotron_vl_model: + # Our fake quantizer may not be fully compatible with torch.compile. + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + elif is_nemotron_vl_model and tokenizer is not None: + generated_ids_after_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + preview_input_ids, + args.pyt_ckpt_path, + "after quantization", + allow_fallback=False, + ) + else: + warnings.warn( + "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." 
+ ) + + def input_decode(input_ids): + if processor is not None and isinstance(processor, MllamaImageProcessor): + return processor.tokenizer.batch_decode(input_ids) + elif processor is not None and isinstance(processor, WhisperProcessor): + return first_text_speech_dataset + elif tokenizer is not None: + return tokenizer.batch_decode(input_ids) + else: + raise ValueError("The processor or tokenizer must be set") + + def output_decode(generated_ids, input_shape): + if is_enc_dec(model_type): + if processor is not None and isinstance(processor, WhisperProcessor): + return processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + elif tokenizer is not None: + return tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + elif processor is not None and isinstance(processor, MllamaImageProcessor): + return processor.tokenizer.batch_decode(generated_ids[:, input_shape:]) + elif tokenizer is not None: + return tokenizer.batch_decode(generated_ids[:, input_shape:]) + else: + raise ValueError("The processor or tokenizer must be set") + + if generated_ids_after_ptq is not None: + print("--------") + if is_nemotron_vl_model: + # For Nemotron VL models, generated_ids are text strings from model.chat() + print("Nemotron VL model text-only generation results:") + print(f"Text response before quantization: {generated_ids_before_ptq}") + print("--------") + print(f"Text response after quantization: {generated_ids_after_ptq}") + print("--------") + print("Note: Additional VL tests with images were run separately above") + else: + # For regular LLMs, generated_ids are token tensors that need decoding + print(f"example test input: {input_decode(preview_input_ids)}") + print("--------") + print( + f"example outputs before ptq: {output_decode(generated_ids_before_ptq, preview_input_ids.shape[1])}" + ) + print("--------") + print( + f"example outputs after ptq: {output_decode(generated_ids_after_ptq, preview_input_ids.shape[1])}" + ) + + +def quantize_main( + args: argparse.Namespace, + full_model: torch.nn.Module, + language_model: torch.nn.Module, + model_type: str | None, + calibration_only: bool, + processor: BaseImageProcessor | ProcessorMixin | None, + tokenizer: PreTrainedTokenizerBase | None, + default_padding_side, + device: torch.device, +): + if args.batch_size == 0: + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. 
+ sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 + # Whisper model expects mel-spectrogram input features of length 3000 + # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) + # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float + # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() + if model_type == "whisper": + max_sample_length = 3000 + num_mel_bins = language_model.config.num_mel_bins + sample_input_single_batch = ( + torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( + language_model.device + ) + * 100 + ) + else: + sample_input_single_batch = None + + run_auto_quant = args.auto_quantize_bits is not None + + args.batch_size = get_max_batch_size( + language_model, + max_sample_length=args.calib_seq, + sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, + sample_input_single_batch=sample_input_single_batch, + enable_grad=run_auto_quant, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + + calib_dataloader, first_text_speech_dataset = make_calib_dataloader( + args, language_model, processor, tokenizer, device, model_type + ) + + # Detect if this is a Nemotron VL model using architecture-based detection + is_nemotron_vl_model = is_nemotron_vl(full_model) + + preview_input_ids, generated_ids_before_ptq = pre_quantize( + args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model + ) + + if args.auto_quantize_bits: + assert len(args.qformat.split(",")) > 1, ( + "Auto quantization needs multiple quantization format." + ) + + auto_quantize( + args, + language_model, + calib_dataloader, + ) + + else: + # plain quantization + assert len(args.qformat.split(",")) == 1, ( + "Plain quantization supports only one quantization format." + ) + + quant_cfg = build_quant_cfg( + args.qformat, + args.kv_cache_qformat, + args.awq_block_size, + model_type, + QUANT_CFG_CHOICES, + KV_QUANT_CFG_CHOICES, + ) + + plain_quantize( + args, + quant_cfg, + full_model, + language_model, + model_type, + calibration_only, + calib_dataloader, + is_nemotron_vl_model, + ) + + post_quantize( + args, + full_model, + model_type, + tokenizer, + processor, + preview_input_ids, + generated_ids_before_ptq, + is_nemotron_vl_model, + first_text_speech_dataset, + ) + export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side) + + +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--pyt_ckpt_path", @@ -716,8 +779,13 @@ def output_decode(generated_ids, input_shape): parser.add_argument( "--qformat", help=( - "Quantization format. If --auto_quantize_bits is set, this argument specifies the quantization " - "format for optimal per-layer auto_quantize search." + "Quantization format. Multiple possible choices: " + f"In plain quantization case: 1. it can be a format name, valid names are: {QUANT_CFG_CHOICES.keys()}. " + "2. it can be a built-in quantization configuration name, they are equivalent to file names without suffix " + "under modelopt/config/quantization/. " + "3. It can be a path to a quantization configuration yaml file for custom quantization formats. " + "In auto-quantize case, i.e. --auto_quantize_bits is set: " + "it is a list of format names, separated by semicolon. 
" ), default="fp8", ) @@ -866,7 +934,53 @@ def output_decode(generated_ids, input_shape): ), ) - args = parser.parse_args() + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. + torch.compiler.set_stance("force_eager") + + ( + full_model, + language_model, + model_type, + calibration_only, + processor, + tokenizer, + default_padding_side, + device, + ) = load_model(args) + + if args.sparsity_fmt != "dense": + # Sparse + sparsity_main(args, full_model, tokenizer, device) + else: + # Quantize + quantize_main( + args, + full_model, + language_model, + model_type, + calibration_only, + processor, + tokenizer, + default_padding_side, + device, + ) + + +if __name__ == "__main__": + args = parse_args() if args.export_fmt != "hf": warnings.warn("Deprecated. --export_fmt forced to hf.") diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index dc3c5e4b0..2ae7dde4a 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -332,7 +332,6 @@ def main(args): args.qformat, args.kv_cache_qformat, args.awq_block_size, - None, model_type, QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES, diff --git a/modelopt/config/quantization/default_disabled.yml b/modelopt/config/quantization/default_disabled.yml new file mode 100644 index 000000000..a97ca6816 --- /dev/null +++ b/modelopt/config/quantization/default_disabled.yml @@ -0,0 +1,44 @@ +quant_cfg: + default: + enable: false + + # path: + '*block_sparse_moe.gate*': + # Skip the MOE router + enable: false + '*linear_attn.conv1d*': + enable: false + '*lm_head*': + enable: false + '*mixer.conv1d*': + enable: false + '*mlp.gate.*': + # Skip the MOE router + enable: false + '*mlp.shared_expert_gate.*': + # Skip the MOE router + enable: false + '*output_layer*': + enable: false + '*proj_out.*': + # In Whisper model, lm_head has key name proj_out + enable: false + '*router*': + # Skip the MOE router + enable: false + 'output.*': + enable: false + + # module: + torch.nn.BatchNorm1d: + '*': + enable: false + torch.nn.BatchNorm2d: + '*': + enable: false + torch.nn.BatchNorm3d: + '*': + enable: false + torch.nn.LeakyReLU: + '*': + enable: false diff --git a/modelopt/config/quantization/fp8_2d_blockwise_weight_only.yml b/modelopt/config/quantization/fp8_2d_blockwise_weight_only.yml new file mode 100644 index 000000000..15bdc095a --- /dev/null +++ b/modelopt/config/quantization/fp8_2d_blockwise_weight_only.yml @@ -0,0 +1,13 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*weight_quantizer': + num_bits: [4, 3] + block_sizes: + -1: 128 + -2: 128 + enable: true + '*input_quantizer': + enable: false diff --git a/modelopt/config/quantization/fp8_affine_kv.yml b/modelopt/config/quantization/fp8_affine_kv.yml new file mode 100644 index 000000000..d6a84d1c2 --- /dev/null +++ b/modelopt/config/quantization/fp8_affine_kv.yml @@ -0,0 +1,11 @@ +algorithm: max +quant_cfg: + '*[kv]_bmm_quantizer': + bias: + -2: null + -4: null + type: static + num_bits: [4, 3] + axis: null + default: + enable: false diff --git a/modelopt/config/quantization/fp8_default.yml b/modelopt/config/quantization/fp8_default.yml new file mode 100644 index 000000000..3d15e222c --- /dev/null +++ 
b/modelopt/config/quantization/fp8_default.yml @@ -0,0 +1,11 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*input_quantizer': + num_bits: [4, 3] + axis: null + '*weight_quantizer': + num_bits: [4, 3] + axis: null diff --git a/modelopt/config/quantization/fp8_kv.yml b/modelopt/config/quantization/fp8_kv.yml new file mode 100644 index 000000000..70078b2cb --- /dev/null +++ b/modelopt/config/quantization/fp8_kv.yml @@ -0,0 +1,8 @@ +algorithm: max +quant_cfg: + '*[kv]_bmm_quantizer': + num_bits: [4, 3] + axis: null + enable: true + default: + enable: false diff --git a/modelopt/config/quantization/fp8_per_channel_per_token.yml b/modelopt/config/quantization/fp8_per_channel_per_token.yml new file mode 100644 index 000000000..954d48bea --- /dev/null +++ b/modelopt/config/quantization/fp8_per_channel_per_token.yml @@ -0,0 +1,13 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*weight_quantizer': + num_bits: [4, 3] + axis: 0 + '*input_quantizer': + num_bits: [4, 3] + type: dynamic + block_sizes: + -1: null diff --git a/modelopt/config/quantization/int4_awq.yml b/modelopt/config/quantization/int4_awq.yml new file mode 100644 index 000000000..3e75ebf33 --- /dev/null +++ b/modelopt/config/quantization/int4_awq.yml @@ -0,0 +1,15 @@ +defaults: + - quantization/default_disabled + +algorithm: + method: awq_lite + alpha_step: 0.1 +quant_cfg: + '*weight_quantizer': + num_bits: 4 + block_sizes: + -1: 128 + type: static + enable: true + '*input_quantizer': + enable: false diff --git a/modelopt/config/quantization/int4_blockwise_weight_only.yml b/modelopt/config/quantization/int4_blockwise_weight_only.yml new file mode 100644 index 000000000..062ffc29f --- /dev/null +++ b/modelopt/config/quantization/int4_blockwise_weight_only.yml @@ -0,0 +1,12 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*weight_quantizer': + num_bits: 4 + block_sizes: + -1: 128 + enable: true + '*input_quantizer': + enable: false diff --git a/modelopt/config/quantization/int8_default.yml b/modelopt/config/quantization/int8_default.yml new file mode 100644 index 000000000..72e9ba3c5 --- /dev/null +++ b/modelopt/config/quantization/int8_default.yml @@ -0,0 +1,11 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + "*weight_quantizer": + num_bits: 8 + axis: 0 + "*input_quantizer": + num_bits: 8 + axis: null \ No newline at end of file diff --git a/modelopt/config/quantization/int8_smoothquant.yml b/modelopt/config/quantization/int8_smoothquant.yml new file mode 100644 index 000000000..82fc9cd68 --- /dev/null +++ b/modelopt/config/quantization/int8_smoothquant.yml @@ -0,0 +1,11 @@ +defaults: + - quantization/default_disabled + +algorithm: smoothquant +quant_cfg: + "*weight_quantizer": + num_bits: 8 + axis: 0 + "*input_quantizer": + num_bits: 8 + axis: null \ No newline at end of file diff --git a/modelopt/config/quantization/int8_weight_only.yml b/modelopt/config/quantization/int8_weight_only.yml new file mode 100644 index 000000000..443959ccd --- /dev/null +++ b/modelopt/config/quantization/int8_weight_only.yml @@ -0,0 +1,10 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + "*weight_quantizer": + num_bits: 8 + axis: 0 + "*input_quantizer": + enable: false diff --git a/modelopt/config/quantization/mxfp4_default.yml b/modelopt/config/quantization/mxfp4_default.yml new file mode 100644 index 000000000..3d44b1978 --- /dev/null +++ b/modelopt/config/quantization/mxfp4_default.yml @@ -0,0 
+1,19 @@ +defaults: + - quantization/default_disabled + +algorithm: null +quant_cfg: + "*weight_quantizer": + num_bits: [2, 1] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true + "*input_quantizer": + num_bits: [2, 1] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true diff --git a/modelopt/config/quantization/mxfp4_mlp_weight_only.yml b/modelopt/config/quantization/mxfp4_mlp_weight_only.yml new file mode 100644 index 000000000..4e5ce13fe --- /dev/null +++ b/modelopt/config/quantization/mxfp4_mlp_weight_only.yml @@ -0,0 +1,13 @@ +defaults: + - quantization/default_disabled + +algorithm: null +quant_cfg: + "*mlp*weight_quantizer": + num_bits: [2, 1] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true + pass_through_bwd: true diff --git a/modelopt/config/quantization/mxfp6_default.yml b/modelopt/config/quantization/mxfp6_default.yml new file mode 100644 index 000000000..77d7a333e --- /dev/null +++ b/modelopt/config/quantization/mxfp6_default.yml @@ -0,0 +1,19 @@ +defaults: + - quantization/default_disabled + +algorithm: null +quant_cfg: + "*weight_quantizer": + num_bits: [3, 2] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true + "*input_quantizer": + num_bits: [3, 2] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true diff --git a/modelopt/config/quantization/mxfp8_default.yml b/modelopt/config/quantization/mxfp8_default.yml new file mode 100644 index 000000000..068b4e273 --- /dev/null +++ b/modelopt/config/quantization/mxfp8_default.yml @@ -0,0 +1,19 @@ +defaults: + - quantization/default_disabled + +algorithm: null +quant_cfg: + "*weight_quantizer": + num_bits: [4, 3] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true + "*input_quantizer": + num_bits: [4, 3] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true diff --git a/modelopt/config/quantization/mxint8_default.yml b/modelopt/config/quantization/mxint8_default.yml new file mode 100644 index 000000000..c36314d5d --- /dev/null +++ b/modelopt/config/quantization/mxint8_default.yml @@ -0,0 +1,19 @@ +defaults: + - quantization/default_disabled + +algorithm: null +quant_cfg: + "*weight_quantizer": + num_bits: 8 + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true + "*input_quantizer": + num_bits: 8 + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true diff --git a/modelopt/config/quantization/nvfp4_affine_kv.yml b/modelopt/config/quantization/nvfp4_affine_kv.yml new file mode 100644 index 000000000..3d9cf51d9 --- /dev/null +++ b/modelopt/config/quantization/nvfp4_affine_kv.yml @@ -0,0 +1,16 @@ +algorithm: max +quant_cfg: + '*[kv]_bmm_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + axis: null + bias: + -2: null + -4: null + type: static + enable: true + default: + enable: false diff --git a/modelopt/config/quantization/nvfp4_awq_clip.yml b/modelopt/config/quantization/nvfp4_awq_clip.yml new file mode 100644 index 000000000..42ebca341 --- /dev/null +++ b/modelopt/config/quantization/nvfp4_awq_clip.yml @@ -0,0 +1,22 @@ +defaults: + - quantization/default_disabled + +algorithm: + method: awq_clip +quant_cfg: + '*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true + '*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true diff --git 
a/modelopt/config/quantization/nvfp4_awq_full.yml b/modelopt/config/quantization/nvfp4_awq_full.yml new file mode 100644 index 000000000..ab91c56b8 --- /dev/null +++ b/modelopt/config/quantization/nvfp4_awq_full.yml @@ -0,0 +1,23 @@ +defaults: + - quantization/default_disabled + +algorithm: + method: awq_full + alpha_step: 0.1 +quant_cfg: + '*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true + '*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true diff --git a/modelopt/config/quantization/nvfp4_awq_lite.yml b/modelopt/config/quantization/nvfp4_awq_lite.yml new file mode 100644 index 000000000..9434693ce --- /dev/null +++ b/modelopt/config/quantization/nvfp4_awq_lite.yml @@ -0,0 +1,21 @@ +defaults: + - quantization/default_disabled + +algorithm: awq_lite +quant_cfg: + '*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true + '*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true diff --git a/modelopt/config/quantization/nvfp4_default.yml b/modelopt/config/quantization/nvfp4_default.yml new file mode 100644 index 000000000..ecc78f55e --- /dev/null +++ b/modelopt/config/quantization/nvfp4_default.yml @@ -0,0 +1,21 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true + '*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + axis: null + enable: true diff --git a/modelopt/config/quantization/nvfp4_fp8_mha.yml b/modelopt/config/quantization/nvfp4_fp8_mha.yml new file mode 100644 index 000000000..9d811b29d --- /dev/null +++ b/modelopt/config/quantization/nvfp4_fp8_mha.yml @@ -0,0 +1,37 @@ +algorithm: max +quant_cfg: + default: + enable: false + '*weight_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + axis: null + enable: true + '*input_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + axis: null + enable: true + '*output_quantizer': + enable: false + '*q_bmm_quantizer': + num_bits: [4, 3] + axis: null + '*k_bmm_quantizer': + num_bits: [4, 3] + axis: null + '*v_bmm_quantizer': + num_bits: [4, 3] + axis: null + '*softmax_quantizer': + num_bits: [4, 3] + axis: null + 'transformer_blocks*bmm2_output_quantizer': + num_bits: [4, 3] + axis: null diff --git a/modelopt/config/quantization/nvfp4_kv.yml b/modelopt/config/quantization/nvfp4_kv.yml new file mode 100644 index 000000000..fc679fdcc --- /dev/null +++ b/modelopt/config/quantization/nvfp4_kv.yml @@ -0,0 +1,12 @@ +algorithm: max +quant_cfg: + '*[kv]_bmm_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + axis: null + enable: true + default: + enable: false diff --git a/modelopt/config/quantization/nvfp4_kv_rotate.yml b/modelopt/config/quantization/nvfp4_kv_rotate.yml new file mode 100644 index 000000000..30e691d3a --- /dev/null +++ b/modelopt/config/quantization/nvfp4_kv_rotate.yml @@ -0,0 +1,22 @@ +algorithm: max +quant_cfg: + '*q_bmm_quantizer': + enable: false + rotate: true + '*k_bmm_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + axis: null + enable: true + rotate: true + 
'*v_bmm_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + axis: null + enable: true diff --git a/modelopt/config/quantization/nvfp4_mlp_only.yml b/modelopt/config/quantization/nvfp4_mlp_only.yml new file mode 100644 index 000000000..96a9885ec --- /dev/null +++ b/modelopt/config/quantization/nvfp4_mlp_only.yml @@ -0,0 +1,21 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*mlp*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + enable: true + pass_through_bwd: true + '*mlp*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + enable: true + pass_through_bwd: true diff --git a/modelopt/config/quantization/nvfp4_mlp_weight_only.yml b/modelopt/config/quantization/nvfp4_mlp_weight_only.yml new file mode 100644 index 000000000..ba0cfc335 --- /dev/null +++ b/modelopt/config/quantization/nvfp4_mlp_weight_only.yml @@ -0,0 +1,13 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*mlp*weight_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [4, 3] + enable: true + pass_through_bwd: true diff --git a/modelopt/config/quantization/nvfp4_svdquant_default.yml b/modelopt/config/quantization/nvfp4_svdquant_default.yml new file mode 100644 index 000000000..f555c598f --- /dev/null +++ b/modelopt/config/quantization/nvfp4_svdquant_default.yml @@ -0,0 +1,23 @@ +defaults: + - quantization/default_disabled + +algorithm: + method: svdquant + lowrank: 32 +quant_cfg: + '*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + enable: true + axis: null + '*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: [4, 3] + num_bits: [2, 1] + enable: true + axis: null diff --git a/modelopt/config/quantization/w4a8_awq_beta.yml b/modelopt/config/quantization/w4a8_awq_beta.yml new file mode 100644 index 000000000..e0623fa4f --- /dev/null +++ b/modelopt/config/quantization/w4a8_awq_beta.yml @@ -0,0 +1,20 @@ +defaults: + - quantization/default_disabled + +algorithm: awq_lite +quant_cfg: + '*weight_quantizer': + - num_bits: 4 + block_sizes: + -1: 128 + type: static + enable: true + + - num_bits: [4, 3] + axis: null + enable: true + + '*input_quantizer': + num_bits: [4, 3] + enable: true + axis: null diff --git a/modelopt/config/quantization/w4a8_mxfp4_fp8.yml b/modelopt/config/quantization/w4a8_mxfp4_fp8.yml new file mode 100644 index 000000000..24ab4fc9d --- /dev/null +++ b/modelopt/config/quantization/w4a8_mxfp4_fp8.yml @@ -0,0 +1,15 @@ +defaults: + - quantization/default_disabled + +algorithm: null +quant_cfg: + '*weight_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [8, 0] + enable: true + '*input_quantizer': + num_bits: [4, 3] + axis: null diff --git a/modelopt/config/quantization/w4a8_nvfp4_fp8.yml b/modelopt/config/quantization/w4a8_nvfp4_fp8.yml new file mode 100644 index 000000000..f8d7b6ff4 --- /dev/null +++ b/modelopt/config/quantization/w4a8_nvfp4_fp8.yml @@ -0,0 +1,17 @@ +defaults: + - quantization/default_disabled + +algorithm: max +quant_cfg: + '*weight_quantizer': + num_bits: [2, 1] + block_sizes: + -1: 32 + type: dynamic + scale_bits: [4, 3] + enable: true + axis: null + '*input_quantizer': + num_bits: [4, 3] + enable: true + axis: null diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 11286dd95..7d216952d 100644 --- a/modelopt/torch/opt/config.py +++ 
b/modelopt/torch/opt/config.py @@ -18,9 +18,13 @@ import fnmatch import json from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView +from importlib.resources import files +from pathlib import Path from typing import Any, TypeAlias import pydantic +import yaml +from omegaconf import OmegaConf from pydantic import ( BaseModel, Field, @@ -381,3 +385,45 @@ def get_kwargs_for_create_model_with_rules( "__cls_kwargs__": {"registry": registry}, **field_specs, } + + +BUITIN_CONFIG_PATH = files("modelopt.config") + + +def load_config(config_name_path: str | Path): + """Load a config yaml. + + config_name_path: + + Could be an internal predefined config name or a OS file path to a user specified config yaml file. + + A config name is simply the sub path to an built-in config yaml file without the yml/yaml suffix. + + For example, for the built-in FP8 default quantization config, the config name is "quantization/fp8_default". + It simply points to the config file modelopt/config/quantization/fp8_default.yml + + """ + if isinstance(config_name_path, str): + if not config_name_path.endswith(".yml") or config_name_path.endswith(".yaml"): + # config name case + config_name = config_name_path.lower() + config_file = BUITIN_CONFIG_PATH.joinpath(f"{config_name}.yml") + if not config_file.is_file(): + config_file = BUITIN_CONFIG_PATH.joinpath(f"{config_name}.yaml") + else: + config_file = Path(config_name_path) + else: + config_file = config_name_path + + if not config_file.is_file(): + raise ValueError(f"Cannot find config file of {config_name_path}") + + content = config_file.read_text(encoding="utf-8") + + config_data = yaml.safe_load(content) + + defaults = [load_config(default_cfg) for default_cfg in config_data.pop("defaults", [])] + + defaults.append(config_data) + + return OmegaConf.merge(*defaults) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 3a43113df..0cacd8bd7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -89,8 +89,8 @@ from modelopt.torch.quantization import QuantModuleRegistry - # Get the class name for nn.Conv2d - class_name = QuantModuleRegistry.get_key(nn.Conv2d) + # Get the class name for torch.nn.Conv2d + class_name = QuantModuleRegistry.get_key(torch.nn.Conv2d) Here is an example of a quantization config: @@ -103,7 +103,7 @@ "*input_quantizer": {"num_bits": 8, "axis": None}, # Module class names mapping to quantizer configurations - "nn.LeakyReLU": {"*input_quantizer": {"enable": False}}, + "torch.nn.LeakyReLU": {"*input_quantizer": {"enable": False}}, } } @@ -137,18 +137,40 @@ """ from collections.abc import Callable -from typing import Literal +from pathlib import Path +from typing import Literal, cast +from omegaconf import DictConfig, OmegaConf from pydantic import ValidationInfo, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.opt.config import load_config as load_config from modelopt.torch.utils.network import ConstructorLike + +def load_quant_config(quant_config_name_path: str | Path) -> DictConfig: + """Load a quantization config yaml. + + quant_config_name_path: + + Could be an internal predefined config name or a user manually specified config yaml file. 
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 3a43113df..0cacd8bd7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -89,8 +89,8 @@ from modelopt.torch.quantization import QuantModuleRegistry - # Get the class name for nn.Conv2d - class_name = QuantModuleRegistry.get_key(nn.Conv2d) + # Get the class name for torch.nn.Conv2d + class_name = QuantModuleRegistry.get_key(torch.nn.Conv2d) Here is an example of a quantization config: @@ -103,7 +103,7 @@ "*input_quantizer": {"num_bits": 8, "axis": None}, # Module class names mapping to quantizer configurations - "nn.LeakyReLU": {"*input_quantizer": {"enable": False}}, + "torch.nn.LeakyReLU": {"*input_quantizer": {"enable": False}}, } } @@ -137,18 +137,40 @@ """ from collections.abc import Callable -from typing import Literal +from pathlib import Path +from typing import Literal, cast +from omegaconf import DictConfig, OmegaConf from pydantic import ValidationInfo, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.opt.config import load_config as load_config from modelopt.torch.utils.network import ConstructorLike + +def load_quant_config(quant_config_name_path: str | Path) -> DictConfig: + """Load a quantization config yaml. + + quant_config_name_path: + + Can be either an internal predefined quantization config name or a path to a user-specified config yaml file. + """ + if isinstance(quant_config_name_path, str): + if not quant_config_name_path.endswith((".yml", ".yaml")): + # config name case + if not quant_config_name_path.startswith("quantization/"): + quant_config_name_path = f"quantization/{quant_config_name_path}" + + return load_config(quant_config_name_path) + + _default_disabled_quantizer_cfg = { - "nn.BatchNorm1d": {"*": {"enable": False}}, - "nn.BatchNorm2d": {"*": {"enable": False}}, - "nn.BatchNorm3d": {"*": {"enable": False}}, - "nn.LeakyReLU": {"*": {"enable": False}}, + "torch.nn.BatchNorm1d": {"*": {"enable": False}}, + "torch.nn.BatchNorm2d": {"*": {"enable": False}}, + "torch.nn.BatchNorm3d": {"*": {"enable": False}}, + "torch.nn.LeakyReLU": {"*": {"enable": False}}, "*lm_head*": {"enable": False}, "*proj_out.*": {"enable": False}, # In Whisper model, lm_head has key name proj_out "*block_sparse_moe.gate*": {"enable": False}, # Skip the MOE router @@ -161,6 +183,9 @@ "output.*": {"enable": False}, "default": {"enable": False}, } +assert OmegaConf.to_container( + OmegaConf.create({"quant_cfg": _default_disabled_quantizer_cfg}), resolve=True +) == OmegaConf.to_container(load_quant_config("DEFAULT_DISABLED"), resolve=True) INT8_DEFAULT_CFG = { "quant_cfg": { @@ -170,6 +195,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(INT8_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("INT8_DEFAULT"), resolve=True) INT8_SMOOTHQUANT_CFG = { "quant_cfg": { @@ -179,6 +207,9 @@ }, "algorithm": "smoothquant", } +assert OmegaConf.to_container( + OmegaConf.create(INT8_SMOOTHQUANT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("INT8_SMOOTHQUANT"), resolve=True) INT8_WEIGHT_ONLY_CFG = { "quant_cfg": { @@ -188,6 +219,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(INT8_WEIGHT_ONLY_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("INT8_WEIGHT_ONLY"), resolve=True) + FP8_DEFAULT_CFG = { "quant_cfg": { @@ -197,6 +232,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(FP8_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("FP8_DEFAULT"), resolve=True) + FP8_PER_CHANNEL_PER_TOKEN_CFG = { "quant_cfg": { @@ -210,6 +249,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(FP8_PER_CHANNEL_PER_TOKEN_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("FP8_PER_CHANNEL_PER_TOKEN"), resolve=True) + # FP8 2D blockwise fake quantization config for deepseek models FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = { "quant_cfg": { @@ -224,6 +267,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("FP8_2D_BLOCKWISE_WEIGHT_ONLY"), resolve=True) + INT4_BLOCKWISE_WEIGHT_ONLY_CFG = { "quant_cfg": { @@ -233,6 +280,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(INT4_BLOCKWISE_WEIGHT_ONLY_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("INT4_BLOCKWISE_WEIGHT_ONLY"), resolve=True) + INT4_AWQ_CFG = { "quant_cfg": { @@ -248,6 +299,10 @@ # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, } +assert OmegaConf.to_container( + OmegaConf.create(INT4_AWQ_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("INT4_AWQ"), resolve=True) + # W4A8 currently uses INT4
blockwise quantization (block size = 128) followed by FP8 quantization # for weights. This could change in the future @@ -262,6 +317,9 @@ }, "algorithm": "awq_lite", } +assert OmegaConf.to_container( + OmegaConf.create(W4A8_AWQ_BETA_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("W4A8_AWQ_BETA"), resolve=True) MXFP8_DEFAULT_CFG = { "quant_cfg": { @@ -279,6 +337,9 @@ }, "algorithm": None, } +assert OmegaConf.to_container( + OmegaConf.create(MXFP8_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("MXFP8_DEFAULT"), resolve=True) MXFP6_DEFAULT_CFG = { "quant_cfg": { @@ -296,6 +357,10 @@ }, "algorithm": None, } +assert OmegaConf.to_container( + OmegaConf.create(MXFP6_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("MXFP6_DEFAULT"), resolve=True) + MXFP4_DEFAULT_CFG = { "quant_cfg": { @@ -313,6 +378,9 @@ }, "algorithm": None, } +assert OmegaConf.to_container( + OmegaConf.create(MXFP4_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("MXFP4_DEFAULT"), resolve=True) W4A8_MXFP4_FP8_CFG = { "quant_cfg": { @@ -326,6 +394,10 @@ }, "algorithm": None, } +assert OmegaConf.to_container( + OmegaConf.create(W4A8_MXFP4_FP8_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("W4A8_MXFP4_FP8"), resolve=True) + MXINT8_DEFAULT_CFG = { "quant_cfg": { @@ -343,6 +415,9 @@ }, "algorithm": None, } +assert OmegaConf.to_container( + OmegaConf.create(MXINT8_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("MXINT8_DEFAULT"), resolve=True) FP8_KV_CFG = { "quant_cfg": { @@ -355,6 +430,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container(OmegaConf.create(FP8_KV_CFG), resolve=True) == OmegaConf.to_container( + load_quant_config("FP8_KV"), resolve=True +) + FP8_AFFINE_KV_CFG = { "quant_cfg": { @@ -367,6 +446,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(FP8_AFFINE_KV_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("FP8_AFFINE_KV"), resolve=True) + NVFP4_DEFAULT_CFG = { "quant_cfg": { @@ -386,6 +469,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_DEFAULT"), resolve=True) NVFP4_AWQ_LITE_CFG = { @@ -406,6 +492,9 @@ }, "algorithm": "awq_lite", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_AWQ_LITE_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_AWQ_LITE"), resolve=True) NVFP4_AWQ_CLIP_CFG = { "quant_cfg": { @@ -425,6 +514,9 @@ }, "algorithm": {"method": "awq_clip"}, } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_AWQ_CLIP_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_AWQ_CLIP"), resolve=True) NVFP4_AWQ_FULL_CFG = { "quant_cfg": { @@ -444,7 +536,9 @@ }, "algorithm": {"method": "awq_full", "alpha_step": 0.1}, } - +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_AWQ_FULL_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_AWQ_FULL"), resolve=True) NVFP4_AFFINE_KV_CFG = { "quant_cfg": { @@ -459,6 +553,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_AFFINE_KV_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_AFFINE_KV"), resolve=True) NVFP4_KV_CFG = { "quant_cfg": { @@ -472,6 +569,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_KV_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_KV"), resolve=True) 
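Editor's note: a hedged sketch of how load_quant_config is meant to be consumed (not part of the diff; the model and forward_loop names below are placeholders). The module-level asserts in this hunk check that each built-in YAML resolves to exactly the same container as its Python *_CFG counterpart, so either form can be fed to quantization.

import modelopt.torch.quantization as mtq
from omegaconf import OmegaConf
from modelopt.torch.quantization.config import load_quant_config

# By name: "NVFP4_DEFAULT" is prefixed with "quantization/", lower-cased, and resolved to
# quantization/nvfp4_default.yml; the assert above verifies it equals NVFP4_DEFAULT_CFG.
quant_cfg = OmegaConf.to_container(load_quant_config("NVFP4_DEFAULT"), resolve=True)

# By path: a user-maintained YAML works the same way (hypothetical path).
# quant_cfg = OmegaConf.to_container(load_quant_config("/path/to/my_variant.yml"), resolve=True)

# model = mtq.quantize(model, quant_cfg, forward_loop)  # placeholders, for context only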
+ # Moved from examples/diffusers/quantization/config.py to here NVFP4_FP8_MHA_CONFIG = { @@ -513,6 +614,10 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_FP8_MHA_CONFIG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_FP8_MHA"), resolve=True) + NVFP4_KV_ROTATE_CFG = { "quant_cfg": { @@ -536,6 +641,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_KV_ROTATE_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_KV_ROTATE"), resolve=True) NVFP4_SVDQUANT_DEFAULT_CFG = { "quant_cfg": { @@ -555,6 +663,9 @@ }, "algorithm": {"method": "svdquant", "lowrank": 32}, } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_SVDQUANT_DEFAULT_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_SVDQUANT_DEFAULT"), resolve=True) W4A8_NVFP4_FP8_CFG = { "quant_cfg": { @@ -573,6 +684,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(W4A8_NVFP4_FP8_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("W4A8_NVFP4_FP8"), resolve=True) MXFP4_MLP_WEIGHT_ONLY_CFG = { "quant_cfg": { @@ -586,6 +700,9 @@ }, "algorithm": None, } +assert OmegaConf.to_container( + OmegaConf.create(MXFP4_MLP_WEIGHT_ONLY_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("MXFP4_MLP_WEIGHT_ONLY"), resolve=True) NVFP4_MLP_WEIGHT_ONLY_CFG = { "quant_cfg": { @@ -603,6 +720,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_MLP_WEIGHT_ONLY_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_MLP_WEIGHT_ONLY"), resolve=True) NVFP4_MLP_ONLY_CFG = { "quant_cfg": { @@ -622,6 +742,9 @@ }, "algorithm": "max", } +assert OmegaConf.to_container( + OmegaConf.create(NVFP4_MLP_ONLY_CFG), resolve=True +) == OmegaConf.to_container(load_quant_config("NVFP4_MLP_ONLY"), resolve=True) choices: set[str] = { "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG", @@ -667,7 +790,7 @@ class QuantizerAttributeConfig(ModeloptBaseConfig): description="""If True, enables the quantizer. 
If False, by-pass the quantizer and returns the input tensor.""", ) - num_bits: int | tuple[int, int] = ModeloptField( + num_bits: int | tuple[int, int] | list[int] = ModeloptField( default=8, title="An integer or a tuple of two integers specifying the number of quantization bits.", description="""`num_bits` can be: @@ -706,6 +829,16 @@ def _validate_recursive(value): _validate_recursive(values) return values + @field_validator("num_bits", mode="before") + @classmethod + def tuple_num_bits(cls, num_bits: int | list[int] | tuple[int, int]) -> int | tuple[int, int]: + """Convert num_bits to tuple if list.""" + if isinstance(num_bits, list): + assert len(num_bits) == 2 + return cast("tuple[int, int]", tuple(num_bits)) + + return num_bits + @model_validator(mode="after") def validate_num_bits(self): """Validate `num_bits`.""" @@ -861,6 +994,16 @@ def _get_block_quant_axes_and_sizes(block_sizes): if k not in ["type", "scale_bits", "scale_block_sizes"] } + @field_validator("block_sizes", mode="before") + @classmethod + def tuple_block_sizes_scale_bits(cls, v) -> dict | None: + """Convert scale_bits inside block_sizes to a tuple if given as a list.""" + if v and v.get("scale_bits"): + scale_bits = v.get("scale_bits") + if isinstance(scale_bits, list): + v["scale_bits"] = tuple(scale_bits) + return v + @field_validator("block_sizes") @classmethod def validate_block_sizes(cls, v, info: ValidationInfo): diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index d4cf5049d..7908ec514 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -173,7 +173,7 @@ def get_dataset_dataloader( batch_size: int = 1, num_samples: int | list[int] = 512, max_sample_length: int = 512, - device: str | None = None, + device: torch.device | None = None, include_labels: bool = False, ) -> DataLoader: """Get a dataloader with the dataset name and toknizer of the target model. @@ -264,7 +264,7 @@ def get_max_batch_size( model: torch.nn.Module, max_sample_length: int = 512, sample_memory_usage_ratio: float = 1.0, - sample_input_single_batch: torch.Tensor = None, + sample_input_single_batch: torch.Tensor | None = None, enable_grad: bool = False, ): """Get the maximum batch size that can be used for the model.""" diff --git a/modelopt/torch/utils/image_processor.py b/modelopt/torch/utils/image_processor.py index 87960d54d..6374642e3 100644 --- a/modelopt/torch/utils/image_processor.py +++ b/modelopt/torch/utils/image_processor.py @@ -22,7 +22,7 @@ class BaseImageProcessor: """Base class for image processors.""" - def __init__(self, tokenizer, device="auto"): + def __init__(self, tokenizer, device="cuda"): """Constructor.""" self.tokenizer = tokenizer self.device = device diff --git a/modelopt/torch/utils/speech_dataset_utils.py b/modelopt/torch/utils/speech_dataset_utils.py index 0d414f7ec..a71d73773 100644 --- a/modelopt/torch/utils/speech_dataset_utils.py +++ b/modelopt/torch/utils/speech_dataset_utils.py @@ -79,12 +79,12 @@ def get_supported_speech_datasets() -> list[str]: def get_speech_dataset_dataloader( dataset_name: str = "peoples_speech", - processor: WhisperProcessor = None, + processor: WhisperProcessor | None = None, batch_size: int = 1, num_samples: int = 512, - device: str | None = None, + device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> DataLoader: +) -> tuple[DataLoader, str]: """Get a dataloader with the dataset name and processor of the target model.
Args: diff --git a/setup.py b/setup.py index dd124e10e..a0015fce4 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,8 @@ "safetensors", "torch>=2.6", "torchprofile>=0.0.4", + "PyYAML>=6.0", + "omegaconf>=2.3.0", ] optional_deps = { @@ -132,5 +134,7 @@ extras_require=optional_deps, packages=setuptools.find_namespace_packages(include=["modelopt*"]), package_dir={"": "."}, - package_data={"modelopt": ["**/*.h", "**/*.cpp", "**/*.cu"]}, + package_data={ + "modelopt": ["**/*.h", "**/*.cpp", "**/*.cu", "config/**/*.yml", "config/**/*.yaml"], + }, )
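Editor's note: a small illustration of why the new "before" validators on QuantizerAttributeConfig exist (a sketch, not part of the diff). YAML has no tuple type, so values such as num_bits: [4, 3] or scale_bits: [4, 3] arrive from the new config files as lists; the validators coerce them back into the tuples the existing schema and downstream code expect.

from modelopt.torch.quantization.config import QuantizerAttributeConfig

# A list coming from a YAML-loaded config is accepted and normalized to a tuple.
attr_cfg = QuantizerAttributeConfig(num_bits=[4, 3])
assert attr_cfg.num_bits == (4, 3)

# The same applies to scale_bits nested inside block_sizes (values mirror the NVFP4 configs).
attr_cfg = QuantizerAttributeConfig(
    num_bits=[2, 1], block_sizes={-1: 16, "type": "dynamic", "scale_bits": [4, 3]}
)
assert attr_cfg.block_sizes["scale_bits"] == (4, 3)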