Commit b26fb75

Fix minor issues in pack_cuda.py
- Use .detach() instead of .data when moving the packed INT4 weight to CPU, to preserve tensor subclass identity safely
- Remove unused loaded_keys set in load_and_pack_for_cuda
- Handle top-level tensor keys (no dot) in load_and_pack_for_cuda
1 parent 53612c9 commit b26fb75

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

examples/models/gemma4_31b/quant/pack_cuda.py

@@ -112,7 +112,7 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) ->
         # Pack on CUDA (required by _convert_weight_to_int4pack), move back
         # to CPU for assembly. The model moves to CUDA later at runtime.
         packed = pack_int4_for_cuda(w, device="cuda")
-        module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False)
+        module.weight = nn.Parameter(packed.detach().to("cpu"), requires_grad=False)
         torch.cuda.empty_cache()
     elif isinstance(w, IntxUnpackedToInt8Tensor):
         module.weight = nn.Parameter(w, requires_grad=False)
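
Why the switch matters: per the commit message, .detach() preserves the tensor subclass identity of the packed INT4 weight, which .data does not guarantee. A minimal sketch of that behavior, using a purely illustrative TaggedTensor subclass in place of torchao's packed tensor (nothing here is taken from pack_cuda.py beyond the pattern in the diff):

import torch
import torch.nn as nn

class TaggedTensor(torch.Tensor):
    """Illustrative stand-in for the packed INT4 tensor subclass."""
    pass

packed = torch.randn(8, 8).as_subclass(TaggedTensor)

# .detach() drops autograd history but keeps the subclass type.
detached = packed.detach()
assert type(detached) is TaggedTensor
assert not detached.requires_grad

# Same assembly pattern as the diff: wrap the detached, CPU-resident tensor.
weight = nn.Parameter(detached.to("cpu"), requires_grad=False)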
@@ -166,15 +166,17 @@ def load_and_pack_for_cuda(
 
         # Stream one logical weight at a time: load its inner tensors,
         # reconstruct the subclass, pack, then release before the next.
-        loaded_keys: set[str] = set()
         for name in tensor_names:
-            module_fqn, weight_name = name.rsplit(".", 1)
-            prefix = f"{module_fqn}._{weight_name}_"
+            parts = name.rsplit(".", 1)
+            module_fqn = parts[0] if len(parts) > 1 else ""
+            weight_name = parts[-1]
+            prefix = (
+                f"{module_fqn}._{weight_name}_" if module_fqn else f"_{weight_name}_"
+            )
             partial = {}
             for key in all_keys:
                 if key.startswith(prefix) or key == name:
                     partial[key] = f.get_tensor(key)
-                    loaded_keys.add(key)
             result, _ = unflatten_tensor_state_dict(partial, metadata)
             for fqn, value in result.items():
                 pack_one(model, fqn, value, _packers)
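
A standalone sketch of the prefix logic added above, showing why the fallback is needed for top-level (dotless) tensor names; the example names are made up, not taken from the checkpoint:

for name in ["model.layers.0.q_proj.weight", "lm_head.weight", "weight"]:
    parts = name.rsplit(".", 1)
    module_fqn = parts[0] if len(parts) > 1 else ""
    weight_name = parts[-1]
    prefix = f"{module_fqn}._{weight_name}_" if module_fqn else f"_{weight_name}_"
    print(name, "->", prefix)

# model.layers.0.q_proj.weight -> model.layers.0.q_proj._weight_
# lm_head.weight               -> lm_head._weight_
# weight                       -> _weight_
# The previous `module_fqn, weight_name = name.rsplit(".", 1)` raised
# ValueError on the dotless "weight" key, because rsplit returns one element.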
