Skip to content

Commit

Permalink
Fix a minor bug which was causing loading twice (#101)
Browse files Browse the repository at this point in the history
Co-authored-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa and Vibhu Jawa authored Nov 14, 2024
1 parent 14d6088 commit e009be2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
13 changes: 7 additions & 6 deletions crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
self.batch_size_increment = batch_size_increment
self.start_seq_len = start_seq_len
self.seq_len_increment = seq_len_increment
self._cfg_id = f"cfg_{id(self)}"

cache_dir = os.path.join(CF_HOME, "memory", self.load_cfg()._name_or_path)
os.makedirs(cache_dir, exist_ok=True)
Expand Down Expand Up @@ -81,14 +82,14 @@ def __init__(
)

def load_on_worker(self, worker, device="cuda"):
setattr(worker, f"torch_model_{id(self)}", self.load_model(device))
setattr(worker, f"cfg_{id(self)}", self.load_cfg())
setattr(worker, self._model_id, self.load_model(device))
setattr(worker, self._cfg_id, self.load_cfg())

def unload_from_worker(self, worker):
if hasattr(worker, f"torch_model_{id(self)}"):
delattr(worker, f"torch_model_{id(self)}")
if hasattr(worker, f"cfg_{id(self)}"):
delattr(worker, f"cfg_{id(self)}")
if hasattr(worker, self._model_id):
delattr(worker, self._model_id)
if hasattr(worker, self._cfg_id):
delattr(worker, self._cfg_id)
cleanup_torch_cache()

def load_model(self, device="cuda"):
Expand Down
7 changes: 4 additions & 3 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
self.path_or_name = path_or_name
self.max_mem_gb = max_mem_gb
self.model_output_type = _validate_model_output_type(model_output_type)
self._model_id = f"torch_model_{id(self)}"

def load_model(self, device="cuda"):
raise NotImplementedError()
Expand All @@ -64,12 +65,12 @@ def unload_from_worker(self, worker):
raise NotImplementedError()

def call_on_worker(self, worker, *args, **kwargs):
return getattr(worker, f"torch_model_{id(self)}")(*args, **kwargs)
return getattr(worker, self._model_id)(*args, **kwargs)

def get_model(self, worker):
if not hasattr(worker, f"torch_model_{id(self)}"):
if not hasattr(worker, self._model_id):
self.load_on_worker(worker)
return getattr(worker, f"torch_model_{id(self)}")
return getattr(worker, self._model_id)

def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
raise NotImplementedError()
Expand Down

0 comments on commit e009be2

Please sign in to comment.