Skip to content
39 changes: 37 additions & 2 deletions pymllm/executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,20 +600,55 @@ def load_model(self) -> None:
quant_config = self._resolve_quant_config()

device_str = f"cuda:{self.gpu_id}" if self.device == "cuda" else self.device
requires_cpu_first_loading = bool(
getattr(model_cls, "requires_cpu_first_weight_loading", False)
)
if not requires_cpu_first_loading:
import os

gemma3n_native_mode = (
architecture in {"Gemma3nForCausalLM", "Gemma3nForConditionalGeneration"}
and os.environ.get("MLLM_GEMMA3N_NATIVE", "0") == "1"
)
requires_cpu_first_loading = gemma3n_native_mode

instantiate_device_str = "cpu" if requires_cpu_first_loading else device_str
if instantiate_device_str != device_str:
logger.info(
"Using CPU-first model instantiation for %s before weight loading. "
"runtime_device=%s",
model_cls.__name__,
device_str,
)

# Use set_default_dtype so parameters created without explicit dtype
# get the target dtype, while parameters with explicit dtype=torch.float32
# (e.g. A_log, dt_bias in GDN layers) stay in float32.
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.dtype)
try:
with torch.device(device_str):
with torch.device(instantiate_device_str):
if quant_config is not None:
self.model = model_cls(hf_config, quant_config=quant_config)
else:
self.model = model_cls(hf_config)
finally:
torch.set_default_dtype(old_dtype)
self.model.load_weights(self._iter_weights(model_path))
use_model_path_loader = hasattr(self.model, "load_weights_from_model_path")
if use_model_path_loader:
loader_flag = getattr(self.model, "use_model_path_weight_loader", True)
if callable(loader_flag):
loader_flag = loader_flag()
use_model_path_loader = bool(loader_flag)

if use_model_path_loader:
logger.info(
"Using model-specific weight loader: %s.load_weights_from_model_path",
type(self.model).__name__,
)
self.model.load_weights_from_model_path(model_path)
else:
self.model.load_weights(self._iter_weights(model_path))

# Post-load processing: let each quantization method repack/transform
# weights from checkpoint format to runtime format (e.g. AWQ → Marlin,
Expand Down
8 changes: 8 additions & 0 deletions pymllm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
"pymllm.models.qwen3_5",
"Qwen3_5ForConditionalGeneration",
),
"Gemma3nForCausalLM": (
"pymllm.models.gemma3n",
"Gemma3nForCausalLM",
),
"Gemma3nForConditionalGeneration": (
"pymllm.models.gemma3n",
"Gemma3nForConditionalGeneration",
),
}


Expand Down
Loading