From 077b0c0e8a768901e9186fb510bd7620ca6f46a2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 22 Jun 2026 11:59:53 -0400 Subject: [PATCH 1/3] do thing Signed-off-by: Kyle Sayers --- src/compressed_tensors/offload/cache/base.py | 16 ++++++++-------- src/compressed_tensors/offload/cache/disk.py | 16 ++++++++-------- .../offload/convert/to_accelerate.py | 13 ++++++++----- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/compressed_tensors/offload/cache/base.py b/src/compressed_tensors/offload/cache/base.py index 6d1431b4a..060139e2b 100644 --- a/src/compressed_tensors/offload/cache/base.py +++ b/src/compressed_tensors/offload/cache/base.py @@ -42,8 +42,8 @@ class OffloadCache(MutableMapping, ABC): # names -> offloaded tensors (populated from _parameters or _buffers) offloaded_values: dict[Hashable, torch.Tensor] - # offloaded tensors -> onloaded tensors (only when offloading is disabled) - keep_onloaded_values: ClassVar[dict[torch.Tensor, torch.Tensor]] = dict() + # id(offloaded tensor) -> onloaded tensor (only when offloading is disabled) + keep_onloaded_values: ClassVar[dict[int, torch.Tensor]] = dict() @classmethod def cls_from_device( @@ -180,15 +180,15 @@ def __getitem__(self, key: Hashable) -> torch.Tensor: return offloaded # check for cache hit - if offloaded in self.keep_onloaded_values: - return self.keep_onloaded_values[offloaded] + if id(offloaded) in self.keep_onloaded_values: + return self.keep_onloaded_values[id(offloaded)] # onload value onloaded = self.onload(offloaded) # when offloading is disabled, populate cache if self.offloading_disabled: - self.keep_onloaded_values[offloaded] = onloaded + self.keep_onloaded_values[id(offloaded)] = onloaded return onloaded @@ -212,7 +212,7 @@ def __setitem__(self, key: Hashable, value: torch.Tensor | None): if offloaded is not None and torch.is_same_size(offloaded, value): self.update_offload(offloaded, value) - onloaded = self.keep_onloaded_values.get(offloaded, None) + onloaded = self.keep_onloaded_values.get(id(offloaded), None) if onloaded is not None and onloaded is not offloaded: onloaded.copy_(value) @@ -231,8 +231,8 @@ def __delitem__(self, key: Hashable): del self.offloaded_values[key] # remove strong ref - if offloaded in self.keep_onloaded_values: - del self.keep_onloaded_values[offloaded] + if id(offloaded) in self.keep_onloaded_values: + del self.keep_onloaded_values[id(offloaded)] def __contains__(self, key) -> bool: return key in self.offloaded_values diff --git a/src/compressed_tensors/offload/cache/disk.py b/src/compressed_tensors/offload/cache/disk.py index f623febd8..a4e9c0cd3 100644 --- a/src/compressed_tensors/offload/cache/disk.py +++ b/src/compressed_tensors/offload/cache/disk.py @@ -32,8 +32,8 @@ class DiskCache(OffloadCache): offload_device = "disk" - # offloaded tensors -> weight info - index: dict[torch.Tensor, dict[str, str]] = dict() + # id(offloaded tensor) -> weight info + index: dict[int, dict[str, str]] = dict() # directory where new tensors are written to offload_dir: str @@ -67,7 +67,7 @@ def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor | None: if offloaded is None: return None - weight_info = self.index[offloaded] + weight_info = self.index[id(offloaded)] device = _get_safe_open_device(self.onload_device) with safe_open( @@ -99,7 +99,7 @@ def offload( offloaded = send_tensors(tensor, device="meta") file_path = self._get_ct_file_path(self.offload_dir, offloaded) - self.index[offloaded] = { + self.index[id(offloaded)] = { "safetensors_file": file_path, "weight_name": "weight", "dtype": str(tensor.dtype).removeprefix("torch."), @@ -119,10 +119,10 @@ def __delitem__(self, key: str): :param key: name of tensor to invalidate """ offloaded = self.offloaded_values[key] - file_path = self.index[offloaded]["safetensors_file"] + file_path = self.index[id(offloaded)]["safetensors_file"] if self._is_ct_file_path(file_path): os.remove(file_path) - del self.index[offloaded] + del self.index[id(offloaded)] super().__delitem__(key) def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None): @@ -134,7 +134,7 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None): """ # get weight info from index assert offloaded in self.index, "Cannot find offload to update" - weight_info = self.index[offloaded] + weight_info = self.index[id(offloaded)] file_path = weight_info["safetensors_file"] weight_name = weight_info["weight_name"] dtype = getattr(torch, weight_info["dtype"]) @@ -169,7 +169,7 @@ def create_checkpoint_symlink( file_path = cls._get_ct_file_path(offload_dir, offloaded) os.symlink(source_path, file_path) - cls.index[offloaded] = { + cls.index[id(offloaded)] = { "safetensors_file": file_path, "weight_name": weight_info["weight_name"], "dtype": weight_info["dtype"], diff --git a/src/compressed_tensors/offload/convert/to_accelerate.py b/src/compressed_tensors/offload/convert/to_accelerate.py index 12a477388..b86fd9e42 100644 --- a/src/compressed_tensors/offload/convert/to_accelerate.py +++ b/src/compressed_tensors/offload/convert/to_accelerate.py @@ -112,17 +112,20 @@ def to_accelerate_module( def _to_accelerate_disk_index( - model: torch.nn.Module, index: dict[torch.Tensor, dict[str, str]] + model: torch.nn.Module, index: dict[int, dict[str, str]] ) -> dict[str, dict[str, str]]: from compressed_tensors.offload import disable_onloading # circular dependency with disable_onloading(): - offloaded_to_key = _invert_dict(model.state_dict(keep_vars=True)) + inverse_state_dict = _invert_dict(model.state_dict(keep_vars=True)) + offloaded_id_to_name = { + id(offloaded): name for offloaded, name in inverse_state_dict.items() + } return { - key: weight_info - for offloaded, weight_info in index.items() - for key in offloaded_to_key[offloaded] + name: weight_info + for offloaded_id, weight_info in index.items() + for name in offloaded_id_to_name[offloaded_id] } From 2cd5e88959a5c00afa100c1e2c8c92380831c714 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 22 Jun 2026 12:14:39 -0400 Subject: [PATCH 2/3] do thing better Signed-off-by: Kyle Sayers --- src/compressed_tensors/offload/cache/disk.py | 4 ++-- src/compressed_tensors/offload/convert/from_accelerate.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/offload/cache/disk.py b/src/compressed_tensors/offload/cache/disk.py index a4e9c0cd3..f3e00710f 100644 --- a/src/compressed_tensors/offload/cache/disk.py +++ b/src/compressed_tensors/offload/cache/disk.py @@ -92,7 +92,7 @@ def offload( return None if tensor.device.type == "meta": - assert tensor in self.index + assert id(tensor) in self.index return tensor if offloaded is None: @@ -133,7 +133,7 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None): :param data: new data """ # get weight info from index - assert offloaded in self.index, "Cannot find offload to update" + assert id(offloaded) in self.index, "Cannot find offload to update" weight_info = self.index[id(offloaded)] file_path = weight_info["safetensors_file"] weight_name = weight_info["weight_name"] diff --git a/src/compressed_tensors/offload/convert/from_accelerate.py b/src/compressed_tensors/offload/convert/from_accelerate.py index 3e136823b..33b093af4 100644 --- a/src/compressed_tensors/offload/convert/from_accelerate.py +++ b/src/compressed_tensors/offload/convert/from_accelerate.py @@ -186,7 +186,7 @@ def _save_ct_index_entry( offloaded: torch.Tensor, ): # already indexed from a previous round-trip (e.g. to_accelerate -> from_accelerate) - if offloaded in DiskCache.index: + if id(offloaded) in DiskCache.index: return entry: dict = dataset.index[name] From 06cc2246469d79b107b32831b86f7961bc9bfc31 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 22 Jun 2026 14:58:54 -0400 Subject: [PATCH 3/3] dist disk Signed-off-by: Kyle Sayers --- src/compressed_tensors/offload/cache/dist_disk.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/offload/cache/dist_disk.py b/src/compressed_tensors/offload/cache/dist_disk.py index d8db1d460..fc58a699f 100644 --- a/src/compressed_tensors/offload/cache/dist_disk.py +++ b/src/compressed_tensors/offload/cache/dist_disk.py @@ -27,10 +27,11 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None: if is_source_process(): # write to disk offloaded = super().offload(tensor) + offloaded_id = id(offloaded) broadcast_obj = [ - self.index[offloaded]["safetensors_file"], - self.index[offloaded]["weight_name"], - self.index[offloaded]["dtype"], + self.index[offloaded_id]["safetensors_file"], + self.index[offloaded_id]["weight_name"], + self.index[offloaded_id]["dtype"], ] else: offloaded = send_tensors(tensor, device="meta") @@ -39,7 +40,7 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None: dist.broadcast_object_list(broadcast_obj, src=get_source_rank()) if not is_source_process(): - self.index[offloaded] = { + self.index[id(offloaded)] = { "safetensors_file": broadcast_obj[0], "weight_name": broadcast_obj[1], "dtype": broadcast_obj[2],