Skip to content

Commit

Permalink
Fix loading multiple models (#99)
Browse files Browse the repository at this point in the history
* fix loading multiple models

Signed-off-by: Vibhu Jawa <[email protected]>

* Move to class-id as thats actually unique

---------

Signed-off-by: Vibhu Jawa <[email protected]>
Co-authored-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa and Vibhu Jawa authored Nov 10, 2024
1 parent 82f232f commit 14d6088
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
12 changes: 6 additions & 6 deletions crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __init__(
)

def load_on_worker(self, worker, device="cuda"):
worker.torch_model = self.load_model(device)
worker.cfg = self.load_cfg()
setattr(worker, f"torch_model_{id(self)}", self.load_model(device))
setattr(worker, f"cfg_{id(self)}", self.load_cfg())

def unload_from_worker(self, worker):
if hasattr(worker, "torch_model"):
delattr(worker, "torch_model")
if hasattr(worker, "cfg"):
delattr(worker, "cfg")
if hasattr(worker, f"torch_model_{id(self)}"):
delattr(worker, f"torch_model_{id(self)}")
if hasattr(worker, f"cfg_{id(self)}"):
delattr(worker, f"cfg_{id(self)}")
cleanup_torch_cache()

def load_model(self, device="cuda"):
Expand Down
8 changes: 3 additions & 5 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,12 @@ def unload_from_worker(self, worker):
raise NotImplementedError()

def call_on_worker(self, worker, *args, **kwargs):
return worker.torch_model(*args, **kwargs)
return getattr(worker, f"torch_model_{id(self)}")(*args, **kwargs)

def get_model(self, worker):
# TODO: We should not hard code the attribute name
# to torch_model. We should use the path_or_name_model
if not hasattr(worker, "torch_model"):
if not hasattr(worker, f"torch_model_{id(self)}"):
self.load_on_worker(worker)
return worker.torch_model
return getattr(worker, f"torch_model_{id(self)}")

def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
raise NotImplementedError()
Expand Down
55 changes: 51 additions & 4 deletions tests/backend/pytorch_backend/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def mock_worker(self):
def test_unload_from_worker(self, model, mock_worker):
model.load_on_worker(mock_worker)

assert hasattr(mock_worker, "torch_model")
assert hasattr(mock_worker, "cfg")
assert hasattr(mock_worker, f"torch_model_{id(model)}")
assert hasattr(mock_worker, f"cfg_{id(model)}")

model.unload_from_worker(mock_worker)

assert not hasattr(mock_worker, "torch_model")
assert not hasattr(mock_worker, "cfg")
assert not hasattr(mock_worker, f"torch_model_{id(model)}")
assert not hasattr(mock_worker, f"cfg_{id(model)}")


class DummyModelWithDictOutput(torch.nn.Module):
Expand Down Expand Up @@ -144,3 +144,50 @@ def test_meta_invalid_model_output_type(self):
):
predictor = cf.op.Predictor(model=self.model_string, model_output_cols=["a", "b"])
predictor.meta()


class DummyHFModel_WithOutputValue(HFModel):
def __init__(self, model_name, output_value):
self.model_name = model_name
self.output_value = output_value
super().__init__(model_name)

def load_model(self, device="cuda"):
class DummyModel(torch.nn.Module):
def __init__(self, output_value):
super().__init__()
self.output_value = output_value

def forward(self, batch):
output_size = len(batch["input_ids"])
return torch.ones(output_size, device="cuda") * self.output_value

return DummyModel(output_value=self.output_value).to(device)


def test_loading_multiple_models():
cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0")
client = Client(cluster)

ddf = dask_cudf.from_cudf(cudf.DataFrame({"text": ["apple"] * 6}), npartitions=1)
model_1 = DummyHFModel_WithOutputValue("microsoft/deberta-v3-base", 1)
model_2 = DummyHFModel_WithOutputValue("microsoft/deberta-v3-base", 2)

pipe_1 = cf.op.Sequential(
cf.op.Tokenizer(model_1, cols=["text"], tokenizer_type="sentencepiece"),
cf.op.Predictor(model_1, sorted_data_loader=False, batch_size=2, pred_output_col="pred_1"),
keep_cols=list(ddf.columns),
)
output_1_ddf = pipe_1(ddf)
pipe_2 = cf.op.Sequential(
cf.op.Tokenizer(model_2, cols=["text"], tokenizer_type="sentencepiece"),
cf.op.Predictor(model_2, sorted_data_loader=False, batch_size=2, pred_output_col="pred_2"),
keep_cols=list(output_1_ddf.columns),
)
output_2_ddf = pipe_2(output_1_ddf)
output_2_df = output_2_ddf.to_backend("pandas").compute()
assert output_2_df["pred_1"].values.tolist() == [1, 1, 1, 1, 1, 1]
assert output_2_df["pred_2"].values.tolist() == [2, 2, 2, 2, 2, 2]

del client
del cluster

0 comments on commit 14d6088

Please sign in to comment.