Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/compressed_tensors/offload/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class OffloadCache(MutableMapping, ABC):
# names -> offloaded tensors (populated from _parameters or _buffers)
offloaded_values: dict[Hashable, torch.Tensor]

# offloaded tensors -> onloaded tensors (only when offloading is disabled)
keep_onloaded_values: ClassVar[dict[torch.Tensor, torch.Tensor]] = dict()
# id(offloaded tensor) -> onloaded tensor (only when offloading is disabled)
keep_onloaded_values: ClassVar[dict[int, torch.Tensor]] = dict()

@classmethod
def cls_from_device(
Expand Down Expand Up @@ -180,15 +180,15 @@ def __getitem__(self, key: Hashable) -> torch.Tensor:
return offloaded

# check for cache hit
if offloaded in self.keep_onloaded_values:
return self.keep_onloaded_values[offloaded]
if id(offloaded) in self.keep_onloaded_values:
return self.keep_onloaded_values[id(offloaded)]

# onload value
onloaded = self.onload(offloaded)

# when offloading is disabled, populate cache
if self.offloading_disabled:
self.keep_onloaded_values[offloaded] = onloaded
self.keep_onloaded_values[id(offloaded)] = onloaded

return onloaded

Expand All @@ -212,7 +212,7 @@ def __setitem__(self, key: Hashable, value: torch.Tensor | None):
if offloaded is not None and torch.is_same_size(offloaded, value):
self.update_offload(offloaded, value)

onloaded = self.keep_onloaded_values.get(offloaded, None)
onloaded = self.keep_onloaded_values.get(id(offloaded), None)
if onloaded is not None and onloaded is not offloaded:
onloaded.copy_(value)

Expand All @@ -231,8 +231,8 @@ def __delitem__(self, key: Hashable):
del self.offloaded_values[key]

# remove strong ref
if offloaded in self.keep_onloaded_values:
del self.keep_onloaded_values[offloaded]
if id(offloaded) in self.keep_onloaded_values:
del self.keep_onloaded_values[id(offloaded)]

def __contains__(self, key) -> bool:
return key in self.offloaded_values
Expand Down
20 changes: 10 additions & 10 deletions src/compressed_tensors/offload/cache/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class DiskCache(OffloadCache):

offload_device = "disk"

# offloaded tensors -> weight info
index: dict[torch.Tensor, dict[str, str]] = dict()
# id(offloaded tensor) -> weight info
index: dict[int, dict[str, str]] = dict()
Comment thread
kylesayrs marked this conversation as resolved.

# directory where new tensors are written to
offload_dir: str
Expand Down Expand Up @@ -67,7 +67,7 @@ def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor | None:
if offloaded is None:
return None

weight_info = self.index[offloaded]
weight_info = self.index[id(offloaded)]
device = _get_safe_open_device(self.onload_device)

with safe_open(
Expand All @@ -92,14 +92,14 @@ def offload(
return None

if tensor.device.type == "meta":
assert tensor in self.index
assert id(tensor) in self.index
return tensor

if offloaded is None:
offloaded = send_tensors(tensor, device="meta")

file_path = self._get_ct_file_path(self.offload_dir, offloaded)
self.index[offloaded] = {
self.index[id(offloaded)] = {
"safetensors_file": file_path,
"weight_name": "weight",
"dtype": str(tensor.dtype).removeprefix("torch."),
Expand All @@ -119,10 +119,10 @@ def __delitem__(self, key: str):
:param key: name of tensor to invalidate
"""
offloaded = self.offloaded_values[key]
file_path = self.index[offloaded]["safetensors_file"]
file_path = self.index[id(offloaded)]["safetensors_file"]
if self._is_ct_file_path(file_path):
os.remove(file_path)
del self.index[offloaded]
del self.index[id(offloaded)]
super().__delitem__(key)

def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None):
Expand All @@ -133,8 +133,8 @@ def update_offload(self, offloaded: torch.Tensor, data: torch.Tensor | None):
:param data: new data
"""
# get weight info from index
assert offloaded in self.index, "Cannot find offload to update"
weight_info = self.index[offloaded]
assert id(offloaded) in self.index, "Cannot find offload to update"
weight_info = self.index[id(offloaded)]
file_path = weight_info["safetensors_file"]
weight_name = weight_info["weight_name"]
dtype = getattr(torch, weight_info["dtype"])
Expand Down Expand Up @@ -169,7 +169,7 @@ def create_checkpoint_symlink(
file_path = cls._get_ct_file_path(offload_dir, offloaded)

os.symlink(source_path, file_path)
cls.index[offloaded] = {
cls.index[id(offloaded)] = {
"safetensors_file": file_path,
"weight_name": weight_info["weight_name"],
"dtype": weight_info["dtype"],
Expand Down
9 changes: 5 additions & 4 deletions src/compressed_tensors/offload/cache/dist_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None:
if is_source_process():
# write to disk
offloaded = super().offload(tensor)
offloaded_id = id(offloaded)
broadcast_obj = [
self.index[offloaded]["safetensors_file"],
self.index[offloaded]["weight_name"],
self.index[offloaded]["dtype"],
self.index[offloaded_id]["safetensors_file"],
self.index[offloaded_id]["weight_name"],
self.index[offloaded_id]["dtype"],
]
else:
offloaded = send_tensors(tensor, device="meta")
Expand All @@ -39,7 +40,7 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None:
dist.broadcast_object_list(broadcast_obj, src=get_source_rank())

if not is_source_process():
self.index[offloaded] = {
self.index[id(offloaded)] = {
"safetensors_file": broadcast_obj[0],
"weight_name": broadcast_obj[1],
"dtype": broadcast_obj[2],
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/offload/convert/from_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _save_ct_index_entry(
offloaded: torch.Tensor,
):
# already indexed from a previous round-trip (e.g. to_accelerate -> from_accelerate)
if offloaded in DiskCache.index:
if id(offloaded) in DiskCache.index:
return

entry: dict = dataset.index[name]
Expand Down
13 changes: 8 additions & 5 deletions src/compressed_tensors/offload/convert/to_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,20 @@ def to_accelerate_module(


def _to_accelerate_disk_index(
model: torch.nn.Module, index: dict[torch.Tensor, dict[str, str]]
model: torch.nn.Module, index: dict[int, dict[str, str]]
) -> dict[str, dict[str, str]]:
from compressed_tensors.offload import disable_onloading # circular dependency

with disable_onloading():
offloaded_to_key = _invert_dict(model.state_dict(keep_vars=True))
inverse_state_dict = _invert_dict(model.state_dict(keep_vars=True))
offloaded_id_to_name = {
id(offloaded): name for offloaded, name in inverse_state_dict.items()
}

return {
key: weight_info
for offloaded, weight_info in index.items()
for key in offloaded_to_key[offloaded]
name: weight_info
for offloaded_id, weight_info in index.items()
for name in offloaded_id_to_name[offloaded_id]
}


Expand Down
Loading