Skip to content

Commit

Permalink
fix loading multiple models
Browse files Browse the repository at this point in the history
Signed-off-by: Vibhu Jawa <[email protected]>
  • Loading branch information
Vibhu Jawa committed Nov 6, 2024
1 parent 82f232f commit db36fc3
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
12 changes: 6 additions & 6 deletions crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __init__(
)

     def load_on_worker(self, worker, device="cuda"):
-        worker.torch_model = self.load_model(device)
-        worker.cfg = self.load_cfg()
+        setattr(worker, f"torch_model_{self.path_or_name}", self.load_model(device))
+        setattr(worker, f"cfg_{self.path_or_name}", self.load_cfg())

     def unload_from_worker(self, worker):
-        if hasattr(worker, "torch_model"):
-            delattr(worker, "torch_model")
-        if hasattr(worker, "cfg"):
-            delattr(worker, "cfg")
+        if hasattr(worker, f"torch_model_{self.path_or_name}"):
+            delattr(worker, f"torch_model_{self.path_or_name}")
+        if hasattr(worker, f"cfg_{self.path_or_name}"):
+            delattr(worker, f"cfg_{self.path_or_name}")
         cleanup_torch_cache()

def load_model(self, device="cuda"):
Expand Down
8 changes: 3 additions & 5 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,12 @@ def unload_from_worker(self, worker):
raise NotImplementedError()

     def call_on_worker(self, worker, *args, **kwargs):
-        return worker.torch_model(*args, **kwargs)
+        return getattr(worker, f"torch_model_{self.path_or_name}")(*args, **kwargs)

     def get_model(self, worker):
-        # TODO: We should not hard code the attribute name
-        # to torch_model. We should use the path_or_name_model
-        if not hasattr(worker, "torch_model"):
+        if not hasattr(worker, f"torch_model_{self.path_or_name}"):
             self.load_on_worker(worker)
-        return worker.torch_model
+        return getattr(worker, f"torch_model_{self.path_or_name}")

def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
raise NotImplementedError()
Expand Down
8 changes: 4 additions & 4 deletions tests/backend/pytorch_backend/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def mock_worker(self):
     def test_unload_from_worker(self, model, mock_worker):
         model.load_on_worker(mock_worker)

-        assert hasattr(mock_worker, "torch_model")
-        assert hasattr(mock_worker, "cfg")
+        assert hasattr(mock_worker, f"torch_model_{model.path_or_name}")
+        assert hasattr(mock_worker, f"cfg_{model.path_or_name}")

         model.unload_from_worker(mock_worker)

-        assert not hasattr(mock_worker, "torch_model")
-        assert not hasattr(mock_worker, "cfg")
+        assert not hasattr(mock_worker, f"torch_model_{model.path_or_name}")
+        assert not hasattr(mock_worker, f"cfg_{model.path_or_name}")


class DummyModelWithDictOutput(torch.nn.Module):
Expand Down

0 comments on commit db36fc3

Please sign in to comment.