diff --git a/src/compressed_tensors/offload/cache/base.py b/src/compressed_tensors/offload/cache/base.py index 6d1431b4a..060139e2b 100644 --- a/src/compressed_tensors/offload/cache/base.py +++ b/src/compressed_tensors/offload/cache/base.py @@ -42,8 +42,8 @@ class OffloadCache(MutableMapping, ABC): # names -> offloaded tensors (populated from _parameters or _buffers) offloaded_values: dict[Hashable, torch.Tensor] - # offloaded tensors -> onloaded tensors (only when offloading is disabled) - keep_onloaded_values: ClassVar[dict[torch.Tensor, torch.Tensor]] = dict() + # id(offloaded tensor) -> onloaded tensor (only when offloading is disabled) + keep_onloaded_values: ClassVar[dict[int, torch.Tensor]] = dict() @classmethod def cls_from_device( @@ -180,15 +180,15 @@ def __getitem__(self, key: Hashable) -> torch.Tensor: return offloaded # check for cache hit - if offloaded in self.keep_onloaded_values: - return self.keep_onloaded_values[offloaded] + if id(offloaded) in self.keep_onloaded_values: + return self.keep_onloaded_values[id(offloaded)] # onload value onloaded = self.onload(offloaded) # when offloading is disabled, populate cache if self.offloading_disabled: - self.keep_onloaded_values[offloaded] = onloaded + self.keep_onloaded_values[id(offloaded)] = onloaded return onloaded @@ -212,7 +212,7 @@ def __setitem__(self, key: Hashable, value: torch.Tensor | None): if offloaded is not None and torch.is_same_size(offloaded, value): self.update_offload(offloaded, value) - onloaded = self.keep_onloaded_values.get(offloaded, None) + onloaded = self.keep_onloaded_values.get(id(offloaded), None) if onloaded is not None and onloaded is not offloaded: onloaded.copy_(value) @@ -231,8 +231,8 @@ def __delitem__(self, key: Hashable): del self.offloaded_values[key] # remove strong ref - if offloaded in self.keep_onloaded_values: - del self.keep_onloaded_values[offloaded] + if id(offloaded) in self.keep_onloaded_values: + del self.keep_onloaded_values[id(offloaded)] def __contains__(self, key) -> bool: return key in self.offloaded_values diff --git a/src/compressed_tensors/offload/cache/disk.py b/src/compressed_tensors/offload/cache/disk.py index f623febd8..f3e00710f 100644 --- a/src/compressed_tensors/offload/cache/disk.py +++ b/src/compressed_tensors/offload/cache/disk.py @@ -32,8 +32,8 @@ class DiskCache(OffloadCache): offload_device = "disk" - # offloaded tensors -> weight info - index: dict[torch.Tensor, dict[str, str]] = dict() + # id(offloaded tensor) -> weight info + index: dict[int, dict[str, str]] = dict() # directory where new tensors are written to offload_dir: str @@ -67,7 +67,7 @@ def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor | None: if offloaded is None: return None - weight_info = self.index[offloaded] + weight_info = self.index[id(offloaded)] device = _get_safe_open_device(self.onload_device) with safe_open( @@ -92,14 +92,14 @@ def offload( return None if tensor.device.type == "meta": - assert tensor in self.index + assert id(tensor) in self.index return tensor if offloaded is None: offloaded = send_tensors(tensor, device="meta") file_path = self._get_ct_file_path(self.offload_dir, offloaded) - self.index[offloaded] = { + self.index[id(offloaded)] = { "safetensors_file": file_path, "weight_name": "weight", "dtype": str(tensor.dtype).removeprefix("torch."), @@ -119,10 +119,10 @@ def __delitem__(self, key: str): :param key: name of tensor to invalidate """ offloaded = self.offloaded_values[key] - file_path = self.index[offloaded]["safetensors_file"] + file_path = self.index[id(offloaded)]["safetensors_file"] if self._is_ct_file_path(file_path): os.remove(file_path) - del self.index[offloaded] + del self.index[id(offloaded)] super().__delitem__(key) def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None): @@ -133,8 +133,8 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None): :param data: new data """ # get weight info from index - assert offloaded in self.index, "Cannot find offload to update" - weight_info = self.index[offloaded] + assert id(offloaded) in self.index, "Cannot find offload to update" + weight_info = self.index[id(offloaded)] file_path = weight_info["safetensors_file"] weight_name = weight_info["weight_name"] dtype = getattr(torch, weight_info["dtype"]) @@ -169,7 +169,7 @@ def create_checkpoint_symlink( file_path = cls._get_ct_file_path(offload_dir, offloaded) os.symlink(source_path, file_path) - cls.index[offloaded] = { + cls.index[id(offloaded)] = { "safetensors_file": file_path, "weight_name": weight_info["weight_name"], "dtype": weight_info["dtype"], diff --git a/src/compressed_tensors/offload/cache/dist_disk.py b/src/compressed_tensors/offload/cache/dist_disk.py index d8db1d460..fc58a699f 100644 --- a/src/compressed_tensors/offload/cache/dist_disk.py +++ b/src/compressed_tensors/offload/cache/dist_disk.py @@ -27,10 +27,11 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None: if is_source_process(): # write to disk offloaded = super().offload(tensor) + offloaded_id = id(offloaded) broadcast_obj = [ - self.index[offloaded]["safetensors_file"], - self.index[offloaded]["weight_name"], - self.index[offloaded]["dtype"], + self.index[offloaded_id]["safetensors_file"], + self.index[offloaded_id]["weight_name"], + self.index[offloaded_id]["dtype"], ] else: offloaded = send_tensors(tensor, device="meta") @@ -39,7 +40,7 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None: dist.broadcast_object_list(broadcast_obj, src=get_source_rank()) if not is_source_process(): - self.index[offloaded] = { + self.index[id(offloaded)] = { "safetensors_file": broadcast_obj[0], "weight_name": broadcast_obj[1], "dtype": broadcast_obj[2], diff --git a/src/compressed_tensors/offload/convert/from_accelerate.py b/src/compressed_tensors/offload/convert/from_accelerate.py index 3e136823b..33b093af4 100644 --- a/src/compressed_tensors/offload/convert/from_accelerate.py +++ b/src/compressed_tensors/offload/convert/from_accelerate.py @@ -186,7 +186,7 @@ def _save_ct_index_entry( offloaded: torch.Tensor, ): # already indexed from a previous round-trip (e.g. to_accelerate -> from_accelerate) - if offloaded in DiskCache.index: + if id(offloaded) in DiskCache.index: return entry: dict = dataset.index[name] diff --git a/src/compressed_tensors/offload/convert/to_accelerate.py b/src/compressed_tensors/offload/convert/to_accelerate.py index 12a477388..b86fd9e42 100644 --- a/src/compressed_tensors/offload/convert/to_accelerate.py +++ b/src/compressed_tensors/offload/convert/to_accelerate.py @@ -112,17 +112,20 @@ def to_accelerate_module( def _to_accelerate_disk_index( - model: torch.nn.Module, index: dict[torch.Tensor, dict[str, str]] + model: torch.nn.Module, index: dict[int, dict[str, str]] ) -> dict[str, dict[str, str]]: from compressed_tensors.offload import disable_onloading # circular dependency with disable_onloading(): - offloaded_to_key = _invert_dict(model.state_dict(keep_vars=True)) + inverse_state_dict = _invert_dict(model.state_dict(keep_vars=True)) + offloaded_id_to_name = { + id(offloaded): name for offloaded, name in inverse_state_dict.items() + } return { - key: weight_info - for offloaded, weight_info in index.items() - for key in offloaded_to_key[offloaded] + name: weight_info + for offloaded_id, weight_info in index.items() + for name in offloaded_id_to_name[offloaded_id] }