Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions omlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ..api.tool_calling import convert_tools_for_template
from ..api.utils import clean_special_tokens
from ..utils.model_loading import load_text_model
from ..utils.tokenizer import get_tokenizer_config
from .base import BaseEngine, GenerationOutput

Expand Down Expand Up @@ -124,8 +125,6 @@ async def start(self) -> None:

import asyncio

from mlx_lm import load

from ..engine_core import AsyncEngineCore, EngineConfig
from ..scheduler import SchedulerConfig

Expand All @@ -140,7 +139,7 @@ async def start(self) -> None:
from ..engine_core import get_mlx_executor

def _load_model_sync():
return load(
return load_text_model(
self._model_name,
tokenizer_config=tokenizer_config,
)
Expand Down
6 changes: 2 additions & 4 deletions omlx/engine/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
compute_image_hash,
extract_images_from_messages,
)
from ..utils.tokenizer import get_tokenizer_config
from ..utils.model_loading import load_vlm_model
from .base import BaseEngine, GenerationOutput

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -189,8 +189,6 @@ async def start(self) -> None:
if self._loaded:
return

from mlx_vlm.utils import load as vlm_load

from ..engine_core import AsyncEngineCore, EngineConfig
from ..scheduler import SchedulerConfig

Expand All @@ -203,7 +201,7 @@ def _load_vlm_sync():
# when torchvision is not available (extractors is None, `in` fails).
# oMLX does not support video input, so we skip video processing.
_patch_video_processor_bug()
return vlm_load(self._model_name)
return load_vlm_model(self._model_name)

loop = asyncio.get_running_loop()
self._vlm_model, self._processor = await loop.run_in_executor(
Expand Down
5 changes: 2 additions & 3 deletions omlx/models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import dataclass
from typing import Iterator

from ..utils.model_loading import load_text_model
from ..utils.tokenizer import get_tokenizer_config

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,8 +74,6 @@ def load(self) -> None:
return

try:
from mlx_lm import load

logger.info(f"Loading model: {self.model_name}")

# Build tokenizer config with model-specific fixes
Expand All @@ -83,7 +82,7 @@ def load(self) -> None:
trust_remote_code=self.trust_remote_code,
)

self.model, self.tokenizer = load(
self.model, self.tokenizer = load_text_model(
self.model_name,
tokenizer_config=tokenizer_config,
)
Expand Down
7 changes: 3 additions & 4 deletions omlx/models/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CAUSAL_LM_RERANKER_ARCHITECTURES,
SUPPORTED_RERANKER_ARCHITECTURES,
)
from ..utils.model_loading import load_text_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -147,11 +148,9 @@ def _load_xlm_roberta(self) -> Tuple[Any, Any]:
return model, tokenizer

def _load_causal_lm(self) -> Tuple[Any, Any]:
"""Load a CausalLM-based reranker model using mlx-lm."""
from mlx_lm import load as mlx_lm_load

"""Load a CausalLM-based reranker model."""
model_path = str(self.model_name)
model, tokenizer_wrapper = mlx_lm_load(model_path)
model, tokenizer_wrapper = load_text_model(model_path)

# mlx-lm returns a TokenizerWrapper; unwrap to get the underlying
# transformers tokenizer which supports __call__ for batch encoding.
Expand Down
3 changes: 3 additions & 0 deletions omlx/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from .tokenizer import get_tokenizer_config, apply_qwen3_fix
from .model_loading import load_text_model, load_vlm_model
from .formatting import format_bytes as format_bytes_util
from .hardware import (
HardwareInfo,
Expand All @@ -29,6 +30,8 @@
# Tokenizer utilities
"get_tokenizer_config",
"apply_qwen3_fix",
"load_text_model",
"load_vlm_model",
# Hardware utilities
"HardwareInfo",
"detect_hardware",
Expand Down
157 changes: 157 additions & 0 deletions omlx/utils/model_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
"""Model loading helpers with pluggable custom quantization loading."""

from __future__ import annotations

import json
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_CustomQuantizationLoader = Callable[[str, bool], tuple[Any, Any, bool]]


def _read_config(model_name: str) -> dict[str, Any] | None:
    """Read local config.json for a model path, if available.

    Returns the parsed top-level JSON object, or None when the path has no
    config.json, the file cannot be read or decoded, or the JSON root is
    not an object.
    """
    config_path = Path(model_name) / "config.json"
    if not config_path.is_file():
        return None

    try:
        # JSON is UTF-8 by spec; don't rely on the locale-dependent default.
        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError):
        # ValueError covers json.JSONDecodeError AND UnicodeDecodeError
        # (mis-encoded file), which the narrower JSONDecodeError would miss.
        # Neither open() nor json.load() raises TypeError here.
        return None
    return data if isinstance(data, dict) else None


def _detect_custom_quantization_format(model_name: str) -> str | None:
    """Return the lowercased `quant_method` declared in a local config.json.

    Returns None when there is no readable config, no `quantization_config`
    mapping, no string `quant_method`, or the method name is blank.
    """
    config = _read_config(model_name)
    quant_config = config.get("quantization_config") if config else None
    if not isinstance(quant_config, dict):
        return None

    method = quant_config.get("quant_method")
    if not isinstance(method, str):
        return None

    normalized = method.strip().lower()
    return normalized if normalized else None


def _load_mlx_lm_text_model(
    model_name: str,
    tokenizer_config: dict[str, Any] | None = None,
):
    """Load a text model/tokenizer pair via the stock mlx-lm loader."""
    from mlx_lm import load as mlx_lm_load

    return mlx_lm_load(model_name, tokenizer_config=tokenizer_config)


def _load_mlx_vlm_model(model_name: str):
    """Load a VLM model/processor pair via the stock mlx-vlm loader."""
    from mlx_vlm.utils import load as mlx_vlm_load

    return mlx_vlm_load(model_name)


def _load_paroquant_mlx_model(
    model_name: str,
    force_text: bool,
):
    """Load a ParoQuant-quantized model via the paroquant MLX backend.

    Raises:
        ImportError: with install instructions, when paroquant is missing.
    """
    try:
        from paroquant.inference.backends.mlx.load import load as paroquant_load
    except ImportError as e:
        msg = (
            "ParoQuant model detected, but paroquant is not installed. "
            "Install with: pip install 'paroquant[mlx]'"
        )
        raise ImportError(msg) from e

    return paroquant_load(model_name, force_text=force_text)


# Registry mapping a config.json `quant_method` value (lowercased by
# _detect_custom_quantization_format) to the loader for that format.
# Each loader takes (model_name, force_text) and returns
# (model, processor, is_vlm) per _CustomQuantizationLoader.
_CUSTOM_QUANTIZATION_LOADERS: dict[str, _CustomQuantizationLoader] = {
    "paroquant": _load_paroquant_mlx_model,
}


def _load_via_custom_quantization_loader(
    model_name: str,
    quantization_format: str,
    force_text: bool,
) -> tuple[Any, Any, bool] | None:
    """Dispatch to the loader registered for *quantization_format*.

    Returns the loader's (model, processor, is_vlm) triple, or None when no
    loader is registered for that format.
    """
    try:
        loader = _CUSTOM_QUANTIZATION_LOADERS[quantization_format]
    except KeyError:
        return None

    logger.info(
        f"Detected custom quantization format "
        f"'{quantization_format}': {model_name}"
    )
    return loader(model_name, force_text)


def load_text_model(
    model_name: str,
    tokenizer_config: dict[str, Any] | None = None,
):
    """Load an LLM model/tokenizer pair, with custom quantization loading.

    Models that declare no custom quantization format — or declare one with
    no registered loader — go through the stock mlx-lm loader.

    NOTE(review): tokenizer_config is only honored on the mlx-lm path; the
    custom-loader path returns the backend's own tokenizer — confirm that is
    intended for formats that need tokenizer fixes.
    """
    fmt = _detect_custom_quantization_format(model_name)
    loaded = (
        _load_via_custom_quantization_loader(
            model_name=model_name,
            quantization_format=fmt,
            force_text=True,
        )
        if fmt is not None
        else None
    )

    if loaded is None:
        if fmt is not None:
            logger.warning(
                f"Custom quantization format '{fmt}' "
                "is not registered; "
                "falling back to mlx-lm loader."
            )
        return _load_mlx_lm_text_model(
            model_name=model_name,
            tokenizer_config=tokenizer_config,
        )

    model, processor, _ = loaded
    # Custom backends may hand back a processor wrapping the tokenizer.
    return model, getattr(processor, "tokenizer", processor)


def load_vlm_model(model_name: str):
    """Load a VLM model/processor pair, with custom quantization loading.

    Models that declare no custom quantization format — or declare one with
    no registered loader — go through the stock mlx-vlm loader.

    Raises:
        RuntimeError: if the custom loader reports the model is not a VLM.
    """
    fmt = _detect_custom_quantization_format(model_name)
    loaded = (
        _load_via_custom_quantization_loader(
            model_name=model_name,
            quantization_format=fmt,
            force_text=False,
        )
        if fmt is not None
        else None
    )

    if loaded is None:
        if fmt is not None:
            logger.warning(
                f"Custom quantization format '{fmt}' "
                "is not registered; "
                "falling back to mlx-vlm loader."
            )
        return _load_mlx_vlm_model(model_name)

    model, processor, is_vlm = loaded
    if not is_vlm:
        raise RuntimeError(
            f"Model '{model_name}' is marked as custom quantization format "
            f"'{fmt}' but is not a VLM model."
        )
    return model, processor