Skip to content
39 changes: 37 additions & 2 deletions pymllm/executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,20 +600,46 @@ def load_model(self) -> None:
quant_config = self._resolve_quant_config()

device_str = f"cuda:{self.gpu_id}" if self.device == "cuda" else self.device
requires_cpu_first_loading = bool(
getattr(model_cls, "requires_cpu_first_weight_loading", False)
)
instantiate_device_str = "cpu" if requires_cpu_first_loading else device_str
if instantiate_device_str != device_str:
logger.info(
"Using CPU-first model instantiation for %s before weight loading. "
"runtime_device=%s",
model_cls.__name__,
device_str,
)

# Use set_default_dtype so parameters created without explicit dtype
# get the target dtype, while parameters with explicit dtype=torch.float32
# (e.g. A_log, dt_bias in GDN layers) stay in float32.
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.dtype)
try:
with torch.device(device_str):
with torch.device(instantiate_device_str):
if quant_config is not None:
self.model = model_cls(hf_config, quant_config=quant_config)
else:
self.model = model_cls(hf_config)
finally:
torch.set_default_dtype(old_dtype)
self.model.load_weights(self._iter_weights(model_path))
use_model_path_loader = hasattr(self.model, "load_weights_from_model_path")
if use_model_path_loader:
loader_flag = getattr(self.model, "use_model_path_weight_loader", True)
if callable(loader_flag):
loader_flag = loader_flag()
use_model_path_loader = bool(loader_flag)

if use_model_path_loader:
logger.info(
"Using model-specific weight loader: %s.load_weights_from_model_path",
type(self.model).__name__,
)
self.model.load_weights_from_model_path(model_path)
else:
self.model.load_weights(self._iter_weights(model_path))

# Post-load processing: let each quantization method repack/transform
# weights from checkpoint format to runtime format (e.g. AWQ → Marlin,
Expand All @@ -623,6 +649,14 @@ def load_model(self) -> None:
if quant_method is not None and hasattr(quant_method, "process_weights_after_loading"):
quant_method.process_weights_after_loading(module)

move_compute_modules = getattr(self.model, "move_compute_modules_to_device", None)
if callable(move_compute_modules) and self.device == "cuda":
logger.info(
"Moving model compute modules to runtime device after weight loading: %s",
device_str,
)
move_compute_modules(torch.device(device_str))

self.model.eval()

after_mem = get_available_gpu_memory(self.device, self.gpu_id)
Expand Down Expand Up @@ -893,6 +927,7 @@ def init_attention_backend(self) -> None:
req_to_token=self.req_to_token_pool.req_to_token,
device=torch.device(self.device),
max_req_pool_size=self.req_to_token_pool.size,
sliding_window_size=self.sliding_window_size,
)

if self.gdn_pool is not None:
Expand Down
8 changes: 8 additions & 0 deletions pymllm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
"pymllm.models.qwen3_5",
"Qwen3_5ForConditionalGeneration",
),
"Gemma3nForCausalLM": (
"pymllm.models.gemma3n",
"Gemma3nForCausalLM",
),
"Gemma3nForConditionalGeneration": (
"pymllm.models.gemma3n",
"Gemma3nForConditionalGeneration",
),
}


Expand Down
Loading