@@ -310,57 +310,28 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         gralora_rank = r - hybrid_r
         subblock_gralora_rank = gralora_rank // gralora_k
 
-        # Simulate the forward pass computation to get equivalent weight matrix
-        # We need to compute: W_delta such that W_delta @ x = gralora_forward(x) - base_forward(x)
-
-        # Create an identity matrix for each input dimension and compute output
-        # This gives us the columns of the weight matrix
-        delta_weight = torch.zeros(out_features, in_features, device=device, dtype=gralora_A.dtype)
-
-        # Process in batches to avoid memory issues
-        batch_size = min(256, in_features)
-        for start_idx in range(0, in_features, batch_size):
-            end_idx = min(start_idx + batch_size, in_features)
-            batch_len = end_idx - start_idx
-
-            # Create identity input: [batch_len, in_features]
-            x = torch.zeros(batch_len, in_features, device=device, dtype=gralora_A.dtype)
-            for i in range(batch_len):
-                x[i, start_idx + i] = 1.0
-
-            # Apply GraLoRA transformation (following forward logic)
-            # x shape: [batch_len, in_features]
-            N = gralora_k
-
-            # Reshape x: [batch_len, N, in_features//N]
-            x_reshaped = x.view(batch_len, N, in_features // N)
-
-            # Apply gralora_A: [batch_len, N, in_features//N] @ [N, in_features//N, rank]
-            # Result: [batch_len, N, rank]
-            temp = torch.einsum("bni, nir -> bnr", x_reshaped, gralora_A)
-
-            # Reshape and permute for information exchange
-            # [batch_len, N, rank] -> [batch_len, N, N, subblock_rank]
-            temp = temp.view(batch_len, N, N, subblock_gralora_rank)
-            # Permute: [batch_len, N, N, subblock_rank] -> [batch_len, N, N, subblock_rank]
-            temp = temp.permute(0, 2, 1, 3)
-            # Reshape: [batch_len, N, N * subblock_rank]
-            temp = temp.reshape(batch_len, N, N * subblock_gralora_rank)
-
-            # Apply gralora_B: [batch_len, N, N*subblock_rank] @ [N, rank, out_features//N]
-            # Note: rank here is actually gralora_rank = N * subblock_gralora_rank
-            # Result: [batch_len, N, out_features//N]
-            output = torch.einsum("bnr, nro -> bno", temp, gralora_B)
-
-            # Reshape to [batch_len, out_features]
-            output = output.reshape(batch_len, out_features)
-
-            # Store in delta_weight (transpose because weight is [out, in])
-            delta_weight[:, start_idx:end_idx] = output.T
+        # scatter gralora_A to get the scattered weight matrix
+        l_indices = torch.arange(in_features, device=device)
+        n_indices = l_indices // (in_features // gralora_k)
+        i_indices = l_indices % (in_features // gralora_k)
+        gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank, device=device, dtype=dtype)
+        gralora_A_scattered.scatter_(1,
+            n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),
+            gralora_A[n_indices, i_indices, :].unsqueeze(1)
+        )
+
+        # compute the delta weight
+        delta_weight = torch.einsum(
+            "ikr, kro -> iko",
+            gralora_A_scattered
+            .view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
+            .permute(0, 2, 1, 3)
+            .reshape(in_features, gralora_k, gralora_rank),
+            gralora_B,
+        ).reshape(in_features, out_features).T
 
         # Add hybrid LoRA component if present
         if hybrid_r > 0:
-            # general_A: [in_features, hybrid_r], general_B: [hybrid_r, out_features]
            weight_A_general = gralora_A_general.weight  # [hybrid_r, in_features]
            weight_B_general = gralora_B_general.weight  # [out_features, hybrid_r]
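For reference, the two formulations compute the same tensor: probing the forward pass with identity inputs (the deleted loop) materializes exactly the block-scattered copy of gralora_A that the new code builds directly with scatter_, after which both paths apply the same permute/reshape and contraction with gralora_B. The sketch below checks this equivalence; the small dimensions and random weights are hypothetical, and only the shapes follow the diff (gralora_A: [k, in_features // k, gralora_rank], gralora_B: [k, gralora_rank, out_features // k], with gralora_rank = k * subblock_gralora_rank).

```python
import torch

# Hypothetical toy dimensions; shapes follow the diff, not any fixed config.
in_features, out_features = 16, 12
gralora_k = 4
subblock_gralora_rank = 2
gralora_rank = gralora_k * subblock_gralora_rank

torch.manual_seed(0)
gralora_A = torch.randn(gralora_k, in_features // gralora_k, gralora_rank)
gralora_B = torch.randn(gralora_k, gralora_rank, out_features // gralora_k)

# Old path: probe the forward logic with an identity batch, one column per input.
x = torch.eye(in_features)
temp = torch.einsum("bni, nir -> bnr", x.view(in_features, gralora_k, -1), gralora_A)
temp = (
    temp.view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
    .permute(0, 2, 1, 3)
    .reshape(in_features, gralora_k, gralora_rank)
)
delta_old = torch.einsum("bnr, nro -> bno", temp, gralora_B).reshape(in_features, out_features).T

# New path: scatter each row of gralora_A into its block slot, then contract once.
l_indices = torch.arange(in_features)
n_indices = l_indices // (in_features // gralora_k)  # which block each input index belongs to
i_indices = l_indices % (in_features // gralora_k)   # offset of that index within its block
gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank)
gralora_A_scattered.scatter_(
    1,
    n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),
    gralora_A[n_indices, i_indices, :].unsqueeze(1),
)
delta_new = (
    torch.einsum(
        "ikr, kro -> iko",
        gralora_A_scattered.view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
        .permute(0, 2, 1, 3)
        .reshape(in_features, gralora_k, gralora_rank),
        gralora_B,
    )
    .reshape(in_features, out_features)
    .T
)

assert torch.allclose(delta_old, delta_new, atol=1e-6)
print("old and new delta weights match:", delta_old.shape)
```

The identity probe makes the equivalence easy to see: row l of x.view(in_features, gralora_k, -1) has a single 1 at block n_l, offset i_l, so the first einsum of the old path reproduces gralora_A_scattered[l] exactly; everything downstream is identical in both paths, the new one just skips the batched probing entirely.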