
Commit a24d156

Refactor GraLoRA weight computation to improve efficiency in delta-weight calculation.
1 parent 151933a

File tree

1 file changed (+19, −48)

src/peft/tuners/gralora/layer.py

Lines changed: 19 additions & 48 deletions
@@ -310,57 +310,28 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         gralora_rank = r - hybrid_r
         subblock_gralora_rank = gralora_rank // gralora_k
 
-        # Simulate the forward pass computation to get equivalent weight matrix
-        # We need to compute: W_delta such that W_delta @ x = gralora_forward(x) - base_forward(x)
-
-        # Create an identity matrix for each input dimension and compute output
-        # This gives us the columns of the weight matrix
-        delta_weight = torch.zeros(out_features, in_features, device=device, dtype=gralora_A.dtype)
-
-        # Process in batches to avoid memory issues
-        batch_size = min(256, in_features)
-        for start_idx in range(0, in_features, batch_size):
-            end_idx = min(start_idx + batch_size, in_features)
-            batch_len = end_idx - start_idx
-
-            # Create identity input: [batch_len, in_features]
-            x = torch.zeros(batch_len, in_features, device=device, dtype=gralora_A.dtype)
-            for i in range(batch_len):
-                x[i, start_idx + i] = 1.0
-
-            # Apply GraLoRA transformation (following forward logic)
-            # x shape: [batch_len, in_features]
-            N = gralora_k
-
-            # Reshape x: [batch_len, N, in_features//N]
-            x_reshaped = x.view(batch_len, N, in_features // N)
-
-            # Apply gralora_A: [batch_len, N, in_features//N] @ [N, in_features//N, rank]
-            # Result: [batch_len, N, rank]
-            temp = torch.einsum("bni, nir -> bnr", x_reshaped, gralora_A)
-
-            # Reshape and permute for information exchange
-            # [batch_len, N, rank] -> [batch_len, N, N, subblock_rank]
-            temp = temp.view(batch_len, N, N, subblock_gralora_rank)
-            # Permute: [batch_len, N, N, subblock_rank] -> [batch_len, N, N, subblock_rank]
-            temp = temp.permute(0, 2, 1, 3)
-            # Reshape: [batch_len, N, N * subblock_rank]
-            temp = temp.reshape(batch_len, N, N * subblock_gralora_rank)
-
-            # Apply gralora_B: [batch_len, N, N*subblock_rank] @ [N, rank, out_features//N]
-            # Note: rank here is actually gralora_rank = N * subblock_gralora_rank
-            # Result: [batch_len, N, out_features//N]
-            output = torch.einsum("bnr, nro -> bno", temp, gralora_B)
-
-            # Reshape to [batch_len, out_features]
-            output = output.reshape(batch_len, out_features)
-
-            # Store in delta_weight (transpose because weight is [out, in])
-            delta_weight[:, start_idx:end_idx] = output.T
+        # scatter gralora_A to get the scattered weight matrix
+        l_indices = torch.arange(in_features, device=device)
+        n_indices = l_indices // (in_features // gralora_k)
+        i_indices = l_indices % (in_features // gralora_k)
+        gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank, device=device, dtype=dtype)
+        gralora_A_scattered.scatter_(1,
+            n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),
+            gralora_A[n_indices, i_indices, :].unsqueeze(1)
+        )
+
+        # compute the delta weight
+        delta_weight = torch.einsum(
+            "ikr, kro -> iko",
+            gralora_A_scattered
+            .view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
+            .permute(0, 2, 1, 3)
+            .reshape(in_features, gralora_k, gralora_rank),
+            gralora_B,
+        ).reshape(in_features, out_features).T
 
         # Add hybrid LoRA component if present
         if hybrid_r > 0:
-            # general_A: [in_features, hybrid_r], general_B: [hybrid_r, out_features]
             weight_A_general = gralora_A_general.weight  # [hybrid_r, in_features]
             weight_B_general = gralora_B_general.weight  # [out_features, hybrid_r]

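The refactor replaces the batched identity-pass simulation with a direct scatter-plus-einsum construction of the delta weight. Below is a minimal, self-contained equivalence sketch (not part of the commit): it assumes gralora_A has shape [gralora_k, in_features // gralora_k, gralora_rank] and gralora_B has shape [gralora_k, gralora_rank, out_features // gralora_k], as implied by the einsum signatures in the diff, and uses small random stand-in tensors rather than real adapter parameters.

import torch

torch.manual_seed(0)
in_features, out_features, gralora_k = 32, 24, 4   # hypothetical sizes
subblock_gralora_rank = 2
gralora_rank = gralora_k * subblock_gralora_rank

# Stand-in adapter factors; shapes follow the diff's einsum signatures.
gralora_A = torch.randn(gralora_k, in_features // gralora_k, gralora_rank)
gralora_B = torch.randn(gralora_k, gralora_rank, out_features // gralora_k)

# Old path: push an identity batch through the block-wise forward logic,
# so each output column corresponds to one input dimension.
x = torch.eye(in_features)
temp = torch.einsum("bni, nir -> bnr", x.view(in_features, gralora_k, -1), gralora_A)
temp = (
    temp.view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
    .permute(0, 2, 1, 3)
    .reshape(in_features, gralora_k, gralora_rank)
)
ref = torch.einsum("bnr, nro -> bno", temp, gralora_B).reshape(in_features, out_features).T

# New path: scatter gralora_A so row l holds gralora_A[n_l, i_l, :] in
# block n_l and zeros elsewhere, then contract once with gralora_B.
l_indices = torch.arange(in_features)
n_indices = l_indices // (in_features // gralora_k)
i_indices = l_indices % (in_features // gralora_k)
gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank)
gralora_A_scattered.scatter_(
    1,
    n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),
    gralora_A[n_indices, i_indices, :].unsqueeze(1),
)
new = torch.einsum(
    "ikr, kro -> iko",
    gralora_A_scattered
    .view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
    .permute(0, 2, 1, 3)
    .reshape(in_features, gralora_k, gralora_rank),
    gralora_B,
).reshape(in_features, out_features).T

assert torch.allclose(ref, new, atol=1e-6)
print("old batched simulation and new scatter+einsum path agree")

The scattered tensor is exactly what the old code obtained by multiplying one-hot inputs through gralora_A, so the Python loop over column batches and the [batch_len, in_features] identity buffers disappear; a single einsum over gralora_B then yields the full [out_features, in_features] delta in one shot.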