From cbb9a9a78b25ffce65f71bd5479cd97c9de25424 Mon Sep 17 00:00:00 2001
From: Vibhu Jawa
Date: Wed, 6 Nov 2024 15:09:59 -0800
Subject: [PATCH] Move to instance id as that's actually unique

---
 crossfit/backend/torch/hf/model.py            | 12 ++--
 crossfit/backend/torch/model.py               |  6 +-
 .../backend/pytorch_backend/test_torch_ops.py | 55 +++++++++++++++++++++++---
 3 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/crossfit/backend/torch/hf/model.py b/crossfit/backend/torch/hf/model.py
index 174d62d..a930fad 100644
--- a/crossfit/backend/torch/hf/model.py
+++ b/crossfit/backend/torch/hf/model.py
@@ -81,14 +81,14 @@ def __init__(
         )
 
     def load_on_worker(self, worker, device="cuda"):
-        setattr(worker, f"torch_model_{self.path_or_name}", self.load_model(device))
-        setattr(worker, f"cfg_{self.path_or_name}", self.load_cfg())
+        setattr(worker, f"torch_model_{id(self)}", self.load_model(device))
+        setattr(worker, f"cfg_{id(self)}", self.load_cfg())
 
     def unload_from_worker(self, worker):
-        if hasattr(worker, f"torch_model_{self.path_or_name}"):
-            delattr(worker, f"torch_model_{self.path_or_name}")
-        if hasattr(worker, f"cfg_{self.path_or_name}"):
-            delattr(worker, f"cfg_{self.path_or_name}")
+        if hasattr(worker, f"torch_model_{id(self)}"):
+            delattr(worker, f"torch_model_{id(self)}")
+        if hasattr(worker, f"cfg_{id(self)}"):
+            delattr(worker, f"cfg_{id(self)}")
         cleanup_torch_cache()
 
     def load_model(self, device="cuda"):
diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py
index eaeaebb..fafe7eb 100644
--- a/crossfit/backend/torch/model.py
+++ b/crossfit/backend/torch/model.py
@@ -64,12 +64,12 @@ def unload_from_worker(self, worker):
         raise NotImplementedError()
 
     def call_on_worker(self, worker, *args, **kwargs):
-        return getattr(worker, f"torch_model_{self.path_or_name}")(*args, **kwargs)
+        return getattr(worker, f"torch_model_{id(self)}")(*args, **kwargs)
 
     def get_model(self, worker):
-        if not hasattr(worker, f"torch_model_{self.path_or_name}"):
+        if not hasattr(worker, f"torch_model_{id(self)}"):
             self.load_on_worker(worker)
-        return getattr(worker, f"torch_model_{self.path_or_name}")
+        return getattr(worker, f"torch_model_{id(self)}")
 
     def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
         raise NotImplementedError()
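For context on what the rename buys: keying worker attributes by path_or_name means two wrapper
instances of the same checkpoint collide, while id(self) gives each instance its own slot. A
minimal sketch, not part of the patch, with SimpleNamespace standing in for the Dask worker and
ModelStub as a hypothetical stand-in for HFModel:

    from types import SimpleNamespace

    # Stand-in namespace for a Dask worker; real workers get attributes
    # attached to them the same way.
    worker = SimpleNamespace()

    class ModelStub:
        # Hypothetical stand-in for HFModel; only the naming scheme matters.
        def __init__(self, path_or_name):
            self.path_or_name = path_or_name

    m1 = ModelStub("microsoft/deberta-v3-base")
    m2 = ModelStub("microsoft/deberta-v3-base")  # same checkpoint, second instance

    # Old scheme: both instances resolve to the same worker attribute, so the
    # second load_on_worker would clobber the first model.
    assert f"torch_model_{m1.path_or_name}" == f"torch_model_{m2.path_or_name}"

    # New scheme: id() is unique among simultaneously live objects, so each
    # instance gets its own slot on the worker.
    setattr(worker, f"torch_model_{id(m1)}", "weights-1")
    setattr(worker, f"torch_model_{id(m2)}", "weights-2")
    assert getattr(worker, f"torch_model_{id(m1)}") != getattr(worker, f"torch_model_{id(m2)}")

The test added below exercises exactly this scenario through the real pipeline: two models built
from the same checkpoint name, loaded on the same worker, producing distinct predictions.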
diff --git a/tests/backend/pytorch_backend/test_torch_ops.py b/tests/backend/pytorch_backend/test_torch_ops.py
index 81a4cf5..3d756c7 100644
--- a/tests/backend/pytorch_backend/test_torch_ops.py
+++ b/tests/backend/pytorch_backend/test_torch_ops.py
@@ -39,13 +39,13 @@ def mock_worker(self):
     def test_unload_from_worker(self, model, mock_worker):
         model.load_on_worker(mock_worker)
 
-        assert hasattr(mock_worker, f"torch_model_{model.path_or_name}")
-        assert hasattr(mock_worker, f"cfg_{model.path_or_name}")
+        assert hasattr(mock_worker, f"torch_model_{id(model)}")
+        assert hasattr(mock_worker, f"cfg_{id(model)}")
 
         model.unload_from_worker(mock_worker)
 
-        assert not hasattr(mock_worker, f"torch_model_{model.path_or_name}")
-        assert not hasattr(mock_worker, f"cfg_{model.path_or_name}")
+        assert not hasattr(mock_worker, f"torch_model_{id(model)}")
+        assert not hasattr(mock_worker, f"cfg_{id(model)}")
 
 
 class DummyModelWithDictOutput(torch.nn.Module):
@@ -144,3 +144,50 @@ def test_meta_invalid_model_output_type(self):
         ):
             predictor = cf.op.Predictor(model=self.model_string, model_output_cols=["a", "b"])
             predictor.meta()
+
+
+class DummyHFModel_WithOutputValue(HFModel):
+    def __init__(self, model_name, output_value):
+        self.model_name = model_name
+        self.output_value = output_value
+        super().__init__(model_name)
+
+    def load_model(self, device="cuda"):
+        class DummyModel(torch.nn.Module):
+            def __init__(self, output_value):
+                super().__init__()
+                self.output_value = output_value
+
+            def forward(self, batch):
+                output_size = len(batch["input_ids"])
+                return torch.ones(output_size, device=batch["input_ids"].device) * self.output_value
+
+        return DummyModel(output_value=self.output_value).to(device)
+
+
+def test_loading_multiple_models():
+    cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0")
+    client = Client(cluster)
+
+    ddf = dask_cudf.from_cudf(cudf.DataFrame({"text": ["apple"] * 6}), npartitions=1)
+    model_1 = DummyHFModel_WithOutputValue("microsoft/deberta-v3-base", 1)
+    model_2 = DummyHFModel_WithOutputValue("microsoft/deberta-v3-base", 2)
+
+    pipe_1 = cf.op.Sequential(
+        cf.op.Tokenizer(model_1, cols=["text"], tokenizer_type="sentencepiece"),
+        cf.op.Predictor(model_1, sorted_data_loader=False, batch_size=2, pred_output_col="pred_1"),
+        keep_cols=list(ddf.columns),
+    )
+    output_1_ddf = pipe_1(ddf)
+    pipe_2 = cf.op.Sequential(
+        cf.op.Tokenizer(model_2, cols=["text"], tokenizer_type="sentencepiece"),
+        cf.op.Predictor(model_2, sorted_data_loader=False, batch_size=2, pred_output_col="pred_2"),
+        keep_cols=list(output_1_ddf.columns),
+    )
+    output_2_ddf = pipe_2(output_1_ddf)
+    output_2_df = output_2_ddf.to_backend("pandas").compute()
+    assert output_2_df["pred_1"].values.tolist() == [1, 1, 1, 1, 1, 1]
+    assert output_2_df["pred_2"].values.tolist() == [2, 2, 2, 2, 2, 2]
+
+    client.close()
+    cluster.close()
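One caveat worth keeping in mind with the new scheme (not raised in the patch itself): CPython
only guarantees that id() values are unique among objects that are alive at the same time, so a
slot name derived from a dead wrapper could in principle be handed to a later instance. A minimal
sketch of the reuse case, with Stub as a hypothetical placeholder class:

    class Stub:
        pass

    a = Stub()
    stale_key = f"torch_model_{id(a)}"
    del a  # once 'a' is collected, CPython may hand the same id to a new object

    b = Stub()
    # May print True or False: id() uniqueness is only guaranteed while both
    # objects are alive, which is why unload_from_worker deletes the worker
    # attribute while the model instance still exists.
    print(stale_key == f"torch_model_{id(b)}")

This is harmless in the code above as long as unload_from_worker runs while the model wrapper is
still alive, so no stale attribute outlives the instance whose id named it.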