From 6108925ab96eeca6f7b96b30c4781486a2f68af6 Mon Sep 17 00:00:00 2001
From: SumanthRH
Date: Mon, 18 May 2026 05:30:18 +0000
Subject: [PATCH] switch to reload api

Signed-off-by: SumanthRH
---
 skyrl/backends/skyrl_train/inference_servers/vllm_worker.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py b/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
index 8249b30a7f..5bc92cb24f 100644
--- a/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
+++ b/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
@@ -82,6 +82,8 @@ def load_weights(self, request: bytes) -> None:
         """
         import pickle
 
+        from vllm.config import set_current_vllm_config
+
         # Unpickle request to restore the original object type
         assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}"
         request = pickle.loads(request)
@@ -90,7 +92,8 @@ def load_weights(self, request: bytes) -> None:
         for name, tensor in self._weight_receiver.receive_weights(request):
             weight_list.append((name, tensor))
 
-        self.model_runner.model.load_weights(weights=weight_list)
+        with torch.device(self.device), set_current_vllm_config(self.vllm_config):
+            self.model_runner.reload_weights(weights_iterator=iter(weight_list))
 
         for weight in weight_list:
             del weight