From 14508e6d9d997967f2f0f8e3cc507194b64b04d6 Mon Sep 17 00:00:00 2001 From: Yesheng Liang Date: Fri, 13 Mar 2026 17:57:57 +0800 Subject: [PATCH] Add support for ParoQuant and custom quantization method loading --- omlx/engine/batched.py | 5 +- omlx/engine/vlm.py | 6 +- omlx/models/llm.py | 5 +- omlx/models/reranker.py | 7 +- omlx/utils/__init__.py | 3 + omlx/utils/model_loading.py | 157 ++++++++++++++++++++++++++++++++++++ 6 files changed, 169 insertions(+), 14 deletions(-) create mode 100644 omlx/utils/model_loading.py diff --git a/omlx/engine/batched.py b/omlx/engine/batched.py index 8bad1197..4869b66f 100644 --- a/omlx/engine/batched.py +++ b/omlx/engine/batched.py @@ -13,6 +13,7 @@ from ..api.tool_calling import convert_tools_for_template from ..api.utils import clean_special_tokens +from ..utils.model_loading import load_text_model from ..utils.tokenizer import get_tokenizer_config from .base import BaseEngine, GenerationOutput @@ -124,8 +125,6 @@ async def start(self) -> None: import asyncio - from mlx_lm import load - from ..engine_core import AsyncEngineCore, EngineConfig from ..scheduler import SchedulerConfig @@ -140,7 +139,7 @@ async def start(self) -> None: from ..engine_core import get_mlx_executor def _load_model_sync(): - return load( + return load_text_model( self._model_name, tokenizer_config=tokenizer_config, ) diff --git a/omlx/engine/vlm.py b/omlx/engine/vlm.py index c6ae4588..e55d6f9a 100644 --- a/omlx/engine/vlm.py +++ b/omlx/engine/vlm.py @@ -38,7 +38,7 @@ compute_image_hash, extract_images_from_messages, ) -from ..utils.tokenizer import get_tokenizer_config +from ..utils.model_loading import load_vlm_model from .base import BaseEngine, GenerationOutput logger = logging.getLogger(__name__) @@ -189,8 +189,6 @@ async def start(self) -> None: if self._loaded: return - from mlx_vlm.utils import load as vlm_load - from ..engine_core import AsyncEngineCore, EngineConfig from ..scheduler import SchedulerConfig @@ -203,7 +201,7 @@ def _load_vlm_sync(): # when torchvision is not available (extractors is None, `in` fails). # oMLX does not support video input, so we skip video processing. _patch_video_processor_bug() - return vlm_load(self._model_name) + return load_vlm_model(self._model_name) loop = asyncio.get_running_loop() self._vlm_model, self._processor = await loop.run_in_executor( diff --git a/omlx/models/llm.py b/omlx/models/llm.py index 69b5142c..fa861ccc 100644 --- a/omlx/models/llm.py +++ b/omlx/models/llm.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import Iterator +from ..utils.model_loading import load_text_model from ..utils.tokenizer import get_tokenizer_config logger = logging.getLogger(__name__) @@ -73,8 +74,6 @@ def load(self) -> None: return try: - from mlx_lm import load - logger.info(f"Loading model: {self.model_name}") # Build tokenizer config with model-specific fixes @@ -83,7 +82,7 @@ def load(self) -> None: trust_remote_code=self.trust_remote_code, ) - self.model, self.tokenizer = load( + self.model, self.tokenizer = load_text_model( self.model_name, tokenizer_config=tokenizer_config, ) diff --git a/omlx/models/reranker.py b/omlx/models/reranker.py index 1902cfe8..4ca90294 100644 --- a/omlx/models/reranker.py +++ b/omlx/models/reranker.py @@ -21,6 +21,7 @@ CAUSAL_LM_RERANKER_ARCHITECTURES, SUPPORTED_RERANKER_ARCHITECTURES, ) +from ..utils.model_loading import load_text_model logger = logging.getLogger(__name__) @@ -147,11 +148,9 @@ def _load_xlm_roberta(self) -> Tuple[Any, Any]: return model, tokenizer def _load_causal_lm(self) -> Tuple[Any, Any]: - """Load a CausalLM-based reranker model using mlx-lm.""" - from mlx_lm import load as mlx_lm_load - + """Load a CausalLM-based reranker model.""" model_path = str(self.model_name) - model, tokenizer_wrapper = mlx_lm_load(model_path) + model, tokenizer_wrapper = load_text_model(model_path) # mlx-lm returns a TokenizerWrapper; unwrap to get the underlying # transformers tokenizer which supports __call__ for batch encoding. diff --git a/omlx/utils/__init__.py b/omlx/utils/__init__.py index 9e02e0c1..3ce0b009 100644 --- a/omlx/utils/__init__.py +++ b/omlx/utils/__init__.py @@ -7,6 +7,7 @@ """ from .tokenizer import get_tokenizer_config, apply_qwen3_fix +from .model_loading import load_text_model, load_vlm_model from .formatting import format_bytes as format_bytes_util from .hardware import ( HardwareInfo, @@ -29,6 +30,8 @@ # Tokenizer utilities "get_tokenizer_config", "apply_qwen3_fix", + "load_text_model", + "load_vlm_model", # Hardware utilities "HardwareInfo", "detect_hardware", diff --git a/omlx/utils/model_loading.py b/omlx/utils/model_loading.py new file mode 100644 index 00000000..75cf3fe7 --- /dev/null +++ b/omlx/utils/model_loading.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Model loading helpers with pluggable custom quantization loading.""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Callable +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +_CustomQuantizationLoader = Callable[[str, bool], tuple[Any, Any, bool]] + + +def _read_config(model_name: str) -> dict[str, Any] | None: + """Read local config.json for a model path, if available.""" + config_path = Path(model_name) / "config.json" + if not config_path.is_file(): + return None + + try: + with open(config_path) as f: + data = json.load(f) + return data if isinstance(data, dict) else None + except (json.JSONDecodeError, OSError, TypeError): + return None + + +def _detect_custom_quantization_format(model_name: str) -> str | None: + """Return lowercased custom quantization format from local config.json.""" + config = _read_config(model_name) + if config is None: + return None + + qconfig = config.get("quantization_config") + if not isinstance(qconfig, dict): + return None + + quant_method = qconfig.get("quant_method") + if not isinstance(quant_method, str): + return None + + normalized = quant_method.strip().lower() + return normalized or None + + +def _load_mlx_lm_text_model( + model_name: str, + tokenizer_config: dict[str, Any] | None = None, +): + from mlx_lm import load + + return load(model_name, tokenizer_config=tokenizer_config) + + +def _load_mlx_vlm_model(model_name: str): + from mlx_vlm.utils import load as vlm_load + + return vlm_load(model_name) + + +def _load_paroquant_mlx_model( + model_name: str, + force_text: bool, +): + try: + from paroquant.inference.backends.mlx.load import load as paroquant_load + except ImportError as e: + raise ImportError( + "ParoQuant model detected, but paroquant is not installed. " + "Install with: pip install 'paroquant[mlx]'" + ) from e + + return paroquant_load(model_name, force_text=force_text) + + +_CUSTOM_QUANTIZATION_LOADERS: dict[str, _CustomQuantizationLoader] = { + "paroquant": _load_paroquant_mlx_model, +} + + +def _load_via_custom_quantization_loader( + model_name: str, + quantization_format: str, + force_text: bool, +) -> tuple[Any, Any, bool] | None: + loader = _CUSTOM_QUANTIZATION_LOADERS.get(quantization_format) + if loader is None: + return None + + logger.info( + f"Detected custom quantization format " + f"'{quantization_format}': {model_name}" + ) + return loader(model_name, force_text) + + +def load_text_model( + model_name: str, + tokenizer_config: dict[str, Any] | None = None, +): + """Load an LLM model/tokenizer pair, with custom quantization loading.""" + quantization_format = _detect_custom_quantization_format(model_name) + if quantization_format is None: + return _load_mlx_lm_text_model( + model_name=model_name, + tokenizer_config=tokenizer_config, + ) + + loaded = _load_via_custom_quantization_loader( + model_name=model_name, + quantization_format=quantization_format, + force_text=True, + ) + if loaded is None: + logger.warning( + f"Custom quantization format '{quantization_format}' " + "is not registered; " + "falling back to mlx-lm loader." + ) + return _load_mlx_lm_text_model( + model_name=model_name, + tokenizer_config=tokenizer_config, + ) + + model, processor, _ = loaded + return model, getattr(processor, "tokenizer", processor) + + +def load_vlm_model(model_name: str): + """Load a VLM model/processor pair, with custom quantization loading.""" + quantization_format = _detect_custom_quantization_format(model_name) + if quantization_format is None: + return _load_mlx_vlm_model(model_name) + + loaded = _load_via_custom_quantization_loader( + model_name=model_name, + quantization_format=quantization_format, + force_text=False, + ) + if loaded is None: + logger.warning( + f"Custom quantization format '{quantization_format}' " + "is not registered; " + "falling back to mlx-vlm loader." + ) + return _load_mlx_vlm_model(model_name) + + model, processor, is_vlm = loaded + if not is_vlm: + raise RuntimeError( + f"Model '{model_name}' is marked as custom quantization format " + f"'{quantization_format}' but is not a VLM model." + ) + return model, processor