103 changes: 58 additions & 45 deletions examples/llm_ptq/example_utils.py
@@ -20,20 +20,30 @@
import sys
import warnings
from pathlib import Path
from typing import Any, cast

import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from omegaconf import OmegaConf
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    PreTrainedTokenizerBase,
    ProcessorMixin,
)

try:
    from huggingface_hub import snapshot_download
except ImportError:
    snapshot_download = None

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.image_processor import MllamaImageProcessor
from modelopt.torch.quantization.config import load_quant_config
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor

SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]

@@ -127,52 +137,49 @@ def build_quant_cfg(
    qformat,
    kv_cache_qformat,
    awq_block_size,
    auto_quantize,
    model_type,
    quant_cfg_choices,
    kv_quant_cfg_choices,
):
    quant_cfg = {}
    if not auto_quantize:
        assert qformat in quant_cfg_choices, (
            f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
        )

) -> dict[str, Any]:
    if qformat in quant_cfg_choices:
        quant_cfg = quant_cfg_choices[qformat]
    else:
        quant_cfg = OmegaConf.to_container(load_quant_config(qformat))

    quant_cfg = cast("dict[str, Any]", quant_cfg)
if "awq" in quant_cfg.get("algorithm"):
        quant_cfg = copy.deepcopy(quant_cfg)
        weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
        if isinstance(weight_quantizer, list):
            weight_quantizer = weight_quantizer[0]
        # If awq_block_size argument is provided, update weight_quantizer
        if awq_block_size:
            weight_quantizer["block_sizes"][-1] = awq_block_size

    # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
    if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
        quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

    enable_quant_kv_cache = kv_cache_qformat != "none"
    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

    # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
    if enable_quant_kv_cache:
        quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
            quant_cfg,
            getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
        )

if "awq" in qformat:
quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
# If awq_block_size argument is provided, update weight_quantizer
if awq_block_size:
weight_quantizer["block_sizes"][-1] = awq_block_size

# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

enable_quant_kv_cache = kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
)

# Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
if model_type == "gemma" and "int8_sq" in qformat:
quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
    # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
    if model_type == "gemma" and "int8_sq" in qformat:
        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

if model_type == "phi4mm":
# Only quantize the language model
quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
if model_type == "phi4mm":
# Only quantize the language model
quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}

    return quant_cfg

@@ -184,7 +191,7 @@ def is_speculative(hf_config):
    )


def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase:
    print(f"Initializing tokenizer from {ckpt_path}")

    if "vila" in ckpt_path.lower():
@@ -205,8 +212,12 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):


def get_processor(
    ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None
):
    ckpt_path,
    model_type,
    device: torch.device | str = "auto",
    trust_remote_code=False,
    attn_implementation=None,
) -> BaseImageProcessor | ProcessorMixin | None:
"""
Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object.
"""
@@ -241,6 +252,8 @@ def get_processor(

        return MllamaImageProcessor(processor, device)

    return None


def get_dtype(dtype):
if dtype == "bf16":
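Note: a minimal sketch of how the refactored build_quant_cfg could be called after this change. The QUANT_CFG_CHOICES / KV_QUANT_CFG_CHOICES mappings and the import path below are assumptions for illustration, not taken from this diff; the mappings in the real example script may differ.

# Hypothetical usage sketch; the choice mappings below are assumed, not from this PR.
import modelopt.torch.quantization as mtq

from example_utils import build_quant_cfg  # assumed import path within examples/llm_ptq

# Assumed mappings: format name -> mtq config dict, and KV format name -> mtq attribute name.
QUANT_CFG_CHOICES = {"int4_awq": mtq.INT4_AWQ_CFG, "fp8": mtq.FP8_DEFAULT_CFG}
KV_QUANT_CFG_CHOICES = {"fp8": "FP8_KV_CFG"}

quant_cfg = build_quant_cfg(
    qformat="int4_awq",      # a key of QUANT_CFG_CHOICES, or a name resolvable by load_quant_config
    kv_cache_qformat="fp8",  # "none" disables KV cache quantization
    awq_block_size=64,       # overrides block_sizes[-1] of *weight_quantizer for AWQ formats
    model_type="llama",
    quant_cfg_choices=QUANT_CFG_CHOICES,
    kv_quant_cfg_choices=KV_QUANT_CFG_CHOICES,
)
# quant_cfg can then be passed to mtq.quantize(model, quant_cfg, forward_loop).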