Skip to content

Commit

Permalink
Move to class-id as that's actually unique
Browse files Browse the repository at this point in the history
  • Loading branch information
Vibhu Jawa committed Nov 6, 2024
1 parent db36fc3 commit cbb9a9a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
12 changes: 6 additions & 6 deletions crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __init__(
)

# Load the model and its config and attach both to `worker` as attributes.
# NOTE(review): this is a rendered diff hunk whose +/- markers were lost. The
# first setattr pair (keyed by self.path_or_name) appears to be the removed
# version and the second pair (keyed by id(self)) the added one — confirm
# against the commit before treating all four lines as live code.
def load_on_worker(self, worker, device="cuda"):
setattr(worker, f"torch_model_{self.path_or_name}", self.load_model(device))
setattr(worker, f"cfg_{self.path_or_name}", self.load_cfg())
# Keyed by id(self); per the Python docs id() is only guaranteed unique among
# objects alive at the same time.
setattr(worker, f"torch_model_{id(self)}", self.load_model(device))
setattr(worker, f"cfg_{id(self)}", self.load_cfg())

# Remove the model and config attributes from `worker` if they are present,
# then call cleanup_torch_cache().
# NOTE(review): rendered diff hunk without +/- markers. The path_or_name
# checks appear to be the removed lines and the id(self) checks the added
# ones — confirm against the commit.
def unload_from_worker(self, worker):
if hasattr(worker, f"torch_model_{self.path_or_name}"):
delattr(worker, f"torch_model_{self.path_or_name}")
if hasattr(worker, f"cfg_{self.path_or_name}"):
delattr(worker, f"cfg_{self.path_or_name}")
# The attribute names below must match the ones written by load_on_worker.
if hasattr(worker, f"torch_model_{id(self)}"):
delattr(worker, f"torch_model_{id(self)}")
if hasattr(worker, f"cfg_{id(self)}"):
delattr(worker, f"cfg_{id(self)}")
cleanup_torch_cache()

def load_model(self, device="cuda"):
Expand Down
6 changes: 3 additions & 3 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def unload_from_worker(self, worker):
raise NotImplementedError()

# Look up this model's attribute on `worker` and call it with the given
# positional and keyword arguments, returning the result.
# NOTE(review): rendered diff hunk without +/- markers. The first return
# (path_or_name key) appears to be the removed line and the second (id(self)
# key) its replacement — confirm against the commit.
def call_on_worker(self, worker, *args, **kwargs):
return getattr(worker, f"torch_model_{self.path_or_name}")(*args, **kwargs)
return getattr(worker, f"torch_model_{id(self)}")(*args, **kwargs)

# Return the model stored on `worker`, calling load_on_worker first when the
# attribute is not there yet.
# NOTE(review): rendered diff hunk without +/- markers. Each path_or_name line
# appears to be the removed form of the id(self) line that follows it —
# confirm against the commit.
def get_model(self, worker):
if not hasattr(worker, f"torch_model_{self.path_or_name}"):
if not hasattr(worker, f"torch_model_{id(self)}"):
self.load_on_worker(worker)
return getattr(worker, f"torch_model_{self.path_or_name}")
return getattr(worker, f"torch_model_{id(self)}")

# Placeholder that always raises NotImplementedError here; presumably
# subclasses override it to return a memory estimate for the given
# max_num_tokens and batch_size (units not visible in this view).
def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
raise NotImplementedError()
Expand Down
55 changes: 51 additions & 4 deletions tests/backend/pytorch_backend/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def mock_worker(self):
# Check that load_on_worker sets the model/config attributes on the mock
# worker and that unload_from_worker removes them again.
# NOTE(review): rendered diff hunk without +/- markers. In each pair of
# assertions the path_or_name lines appear to be the removed version and the
# id(model) lines the added one — confirm against the commit.
def test_unload_from_worker(self, model, mock_worker):
model.load_on_worker(mock_worker)

assert hasattr(mock_worker, f"torch_model_{model.path_or_name}")
assert hasattr(mock_worker, f"cfg_{model.path_or_name}")
assert hasattr(mock_worker, f"torch_model_{id(model)}")
assert hasattr(mock_worker, f"cfg_{id(model)}")

model.unload_from_worker(mock_worker)

assert not hasattr(mock_worker, f"torch_model_{model.path_or_name}")
assert not hasattr(mock_worker, f"cfg_{model.path_or_name}")
assert not hasattr(mock_worker, f"torch_model_{id(model)}")
assert not hasattr(mock_worker, f"cfg_{id(model)}")


class DummyModelWithDictOutput(torch.nn.Module):
Expand Down Expand Up @@ -144,3 +144,50 @@ def test_meta_invalid_model_output_type(self):
):
predictor = cf.op.Predictor(model=self.model_string, model_output_cols=["a", "b"])
predictor.meta()


class DummyHFModel_WithOutputValue(HFModel):
    """HFModel test double whose loaded model emits one constant per input row.

    Two instances built with the same ``model_name`` but different
    ``output_value`` produce distinguishable predictions, which lets a test
    tell which instance actually served a request.
    """

    def __init__(self, model_name, output_value):
        """Store the constant to emit, then delegate to ``HFModel.__init__``.

        Args:
            model_name: Model path or name forwarded to ``HFModel``.
            output_value: Scalar that every prediction will equal.
        """
        # Assigned before delegating to the parent constructor, as in the
        # original ordering.
        self.model_name = model_name
        self.output_value = output_value
        super().__init__(model_name)

    def load_model(self, device="cuda"):
        """Return a tiny ``torch.nn.Module`` that outputs ``output_value``.

        Args:
            device: Device the module is moved to and on which its outputs
                are created. Defaults to ``"cuda"``.

        Returns:
            A module whose ``forward(batch)`` returns a 1-D tensor with
            ``len(batch["input_ids"])`` entries, all equal to
            ``output_value``.
        """

        class DummyModel(torch.nn.Module):
            def __init__(self, output_value, device):
                super().__init__()
                self.output_value = output_value
                # Remember the requested device: the module has no
                # parameters, so ``.to(device)`` alone records nothing that
                # ``forward`` could use to place its output.
                self.device = device

            def forward(self, batch):
                output_size = len(batch["input_ids"])
                # Build the result on the requested device rather than a
                # hard-coded "cuda", so a non-default ``device`` is honoured.
                return torch.ones(output_size, device=self.device) * self.output_value

        return DummyModel(output_value=self.output_value, device=device).to(device)


def test_loading_multiple_models():
    """Two models sharing a name must not overwrite each other on a worker.

    Both models are created with the same model name but different constant
    outputs. They are run back to back on a single-GPU cluster, and the test
    checks that each prediction column carries the value of the model that
    produced it.
    """
    # Context managers shut the cluster and client down even when an
    # assertion fails; a bare ``del`` only drops the local names and is
    # skipped entirely on failure.
    with LocalCUDACluster(CUDA_VISIBLE_DEVICES="0") as cluster, Client(cluster):
        ddf = dask_cudf.from_cudf(cudf.DataFrame({"text": ["apple"] * 6}), npartitions=1)
        model_1 = DummyHFModel_WithOutputValue("microsoft/deberta-v3-base", 1)
        model_2 = DummyHFModel_WithOutputValue("microsoft/deberta-v3-base", 2)

        pipe_1 = cf.op.Sequential(
            cf.op.Tokenizer(model_1, cols=["text"], tokenizer_type="sentencepiece"),
            cf.op.Predictor(
                model_1, sorted_data_loader=False, batch_size=2, pred_output_col="pred_1"
            ),
            keep_cols=list(ddf.columns),
        )
        output_1_ddf = pipe_1(ddf)

        # The second pipeline consumes the first one's output so that both
        # prediction columns end up in the same frame.
        pipe_2 = cf.op.Sequential(
            cf.op.Tokenizer(model_2, cols=["text"], tokenizer_type="sentencepiece"),
            cf.op.Predictor(
                model_2, sorted_data_loader=False, batch_size=2, pred_output_col="pred_2"
            ),
            keep_cols=list(output_1_ddf.columns),
        )
        output_2_ddf = pipe_2(output_1_ddf)

        # Compute while the cluster is still alive.
        output_2_df = output_2_ddf.to_backend("pandas").compute()

    assert output_2_df["pred_1"].values.tolist() == [1, 1, 1, 1, 1, 1]
    assert output_2_df["pred_2"].values.tolist() == [2, 2, 2, 2, 2, 2]

0 comments on commit cbb9a9a

Please sign in to comment.