from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np

from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend
from sgl_jax.srt.lora.utils import LoRABatchInfo
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

MIN_CHUNK_SIZE = 16


class BgmvLoRABackend(BaseLoRABackend):
    """
    Bgmv LoRA backend using batched grouped matrix-vector multiplication.
    """

    name = "bgmv"

    def __init__(
        self,
        max_loras_per_batch: int,
        max_lora_rank: int,
    ):
        super().__init__(max_loras_per_batch)
        self.max_lora_rank = max_lora_rank
    def run_lora_a_gemm(
        self,
        x: jax.Array,  # (s, input_dim)
        weights: jax.Array,  # (num_lora, r, input_dim)
        *args,
        **kwargs,
    ) -> jax.Array:
        # Project each token into its adapter's low-rank space with a single
        # bgmv_shrink call; per-token scaling is applied inside bgmv_shrink.
        return bgmv_shrink(
            x,
            weights,
            self.batch_info.token_lora_indices,
            self.batch_info.scalings,
        )

    def run_lora_b_gemm(
        self,
        x: jax.Array,  # (s, r)
        weights: jax.Array,  # (num_lora, output_dim, r)
        base_output: jax.Array | None = None,
        *args,
        **kwargs,
    ) -> jax.Array:
        s = x.shape[0]
        output_dim = weights.shape[1]

        return bgmv_expand_slice(
            x,
            weights,
            base_output,
            self.batch_info.token_lora_indices,
            0,
            output_dim,
            (s, output_dim),
        )

    def run_qkv_lora(
        self,
        x: jax.Array,  # (s, input_dim)
        qkv_lora_a: jax.Array,  # (num_lora, 3 * r, input_dim)
        qkv_lora_b: jax.Array | tuple[jax.Array],  # (num_lora, output_dim_q + 2 * output_dim_kv, r) or ((1, num_lora, output_dim_q, r), (2, num_lora, output_dim_kv, r))
        output_slices: tuple,  # (output_dim_q, output_dim_kv, output_dim_kv)
        base_output: jax.Array | None = None,
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: lora_b module for qkv.
                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
                If passed in as a tuple of two tensors, it should contain:
                a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
                and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
            output_slices: a fixed tuple of three items, (output_dim_q, output_dim_kv, output_dim_kv)
        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        # Sketch implementation; assumes prepare_lora_batch has populated self.batch_info.
        if isinstance(qkv_lora_b, tuple):
            # Stack ((1, N, dq, r), (2, N, dkv, r)) into (N, dq + 2 * dkv, r).
            q_b, kv_b = qkv_lora_b
            qkv_lora_b = jnp.concatenate([q_b[0], kv_b[0], kv_b[1]], axis=1)
        s = x.shape[0]
        r = qkv_lora_a.shape[1] // 3
        total_output_dim = sum(output_slices)
        # Project into the stacked low-rank space: (s, 3 * r), then expand each slice.
        lora_a_output = bgmv_shrink(
            x, qkv_lora_a, self.batch_info.token_lora_indices, self.batch_info.scalings
        )
        output = base_output
        slice_offset = 0
        for i, slice_size in enumerate(output_slices):
            output = bgmv_expand_slice(
                lora_a_output[:, i * r : (i + 1) * r],
                qkv_lora_b[:, slice_offset : slice_offset + slice_size, :],
                output,
                self.batch_info.token_lora_indices,
                slice_offset,
                slice_size,
                (s, total_output_dim),
            )
            slice_offset += slice_size
        return output

    def run_gate_up_lora(
        self,
        x: jax.Array,
        gate_up_lora_a: jax.Array,
        gate_up_lora_b: jax.Array | tuple[jax.Array],
        base_output: jax.Array | None = None,
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for gate_up_proj.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
            gate_up_lora_b: lora_b module for gate_up_proj.
                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
        Returns:
            result with shape (s, 2 * output_dim)
        """
        # Sketch implementation; assumes prepare_lora_batch has populated self.batch_info.
        if isinstance(gate_up_lora_b, tuple):
            # Stack two (num_lora, output_dim, r) tensors into (num_lora, 2 * output_dim, r).
            gate_up_lora_b = jnp.concatenate(gate_up_lora_b, axis=1)
        s = x.shape[0]
        r = gate_up_lora_a.shape[1] // 2
        output_dim = gate_up_lora_b.shape[1] // 2
        # Project into the stacked low-rank space: (s, 2 * r), then expand each half.
        lora_a_output = bgmv_shrink(
            x, gate_up_lora_a, self.batch_info.token_lora_indices, self.batch_info.scalings
        )
        output = base_output
        for i in range(2):
            output = bgmv_expand_slice(
                lora_a_output[:, i * r : (i + 1) * r],
                gate_up_lora_b[:, i * output_dim : (i + 1) * output_dim, :],
                output,
                self.batch_info.token_lora_indices,
                i * output_dim,
                output_dim,
                (s, 2 * output_dim),
            )
        return output

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],  # (bs,), please pad with -1
        lora_ranks: list[int],  # (max_loras_per_batch,)
        scalings: list[float],  # (max_loras_per_batch,)
        batch_info: Optional[LoRABatchInfo] = None,
    ):
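        # Illustrative example (hypothetical values): for an EXTEND batch with
        # weight_indices = [0, 1], seq_lens = [3, 2], and input_ids padded to 8
        # tokens, the per-token metadata becomes
        #   token_lora_indices = [0, 0, 0, 1, 1, -1, -1, -1]
        # with scalings and lora_ranks expanded and padded in the same way.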
        lora_ranks_bs = []
        scalings_bs = []
        for index in weight_indices:
            if index != -1:
                lora_ranks_bs.append(lora_ranks[index])
                scalings_bs.append(scalings[index])
            else:
                lora_ranks_bs.append(0)
                scalings_bs.append(0.0)

        assert len(forward_batch.seq_lens) == len(weight_indices)
        assert len(forward_batch.seq_lens) == len(lora_ranks_bs)
        assert len(forward_batch.seq_lens) == len(scalings_bs)

        target_len = forward_batch.input_ids.shape[0]

        if forward_batch.forward_mode == ForwardMode.EXTEND:
            # Expand per-request metadata to per-token metadata.
            scalings_cpu = np.repeat(np.array(scalings_bs, dtype=np.float32), forward_batch.seq_lens)
            token_lora_indices_cpu = np.repeat(np.array(weight_indices, dtype=np.int32), forward_batch.seq_lens)
            lora_ranks_cpu = np.repeat(np.array(lora_ranks_bs, dtype=np.int32), forward_batch.seq_lens)

            num_to_pad = target_len - jnp.sum(forward_batch.seq_lens)

            if num_to_pad > 0:
                padded_scalings_cpu = np.pad(scalings_cpu, [0, num_to_pad], mode="constant", constant_values=0.0)
                padded_token_lora_indices_cpu = np.pad(token_lora_indices_cpu, [0, num_to_pad], mode="constant", constant_values=-1)
                padded_lora_ranks_cpu = np.pad(lora_ranks_cpu, [0, num_to_pad], mode="constant", constant_values=0)
            else:
                padded_scalings_cpu = scalings_cpu
                padded_token_lora_indices_cpu = token_lora_indices_cpu
                padded_lora_ranks_cpu = lora_ranks_cpu
        elif forward_batch.forward_mode == ForwardMode.DECODE:
            # One token per request in decode mode, so no expansion is needed.
            padded_scalings_cpu = np.array(scalings_bs, dtype=np.float32)
            padded_token_lora_indices_cpu = np.array(weight_indices, dtype=np.int32)
            padded_lora_ranks_cpu = np.array(lora_ranks_bs, dtype=np.int32)

        batch_info = LoRABatchInfo(
            bs=forward_batch.batch_size,
            scalings=jnp.array(padded_scalings_cpu, dtype=jnp.float32),
            token_lora_indices=jnp.array(padded_token_lora_indices_cpu, dtype=jnp.int32),
            lora_ranks=jnp.array(padded_lora_ranks_cpu, dtype=jnp.int32),
        )

        self.batch_info = batch_info


def bgmv_shrink(
    inputs,  # (s, input_dim)
    lora_weights,  # (num_lora, c * r, input_dim)
    lora_indices,  # (num_tokens,)
    scaling: float | jax.Array = 1.0,  # scalar or per-token array of shape (s,)
):
    """
    Shrink operation: maps the input to the low-rank space.

    Args:
        inputs: (s, input_dim)
        lora_weights: (num_lora, c * r, input_dim), c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
        lora_indices: (num_tokens,)
        scaling: LoRA scaling factor, either a scalar or a per-token array of shape (s,)
    Returns:
        (s, c * r)
    """
    outputs = bgmv_jax(inputs, lora_weights, lora_indices)
    scaling = jnp.asarray(scaling, dtype=outputs.dtype)
    if scaling.ndim == 1:
        # Per-token scaling: broadcast over the rank dimension.
        scaling = scaling[:, None]
    return scaling * outputs


def bgmv_expand_slice(
    inputs,  # (num_tokens, lora_rank)
    lora_weights,  # (num_loras, out_features, lora_rank) or (num_loras, 1, out_features, lora_rank)
    base_output,  # (num_tokens, total_out_features) or None
    lora_indices,  # (num_tokens,)
    slice_offset: int,
    slice_size: int,
    output_shape: tuple,
):
    """
    Expand operation: maps from the low-rank space back to the output space.

    Args:
        inputs: (num_tokens, lora_rank)
        lora_weights: (num_loras, out_features, lora_rank), optionally with an extra singleton axis
        base_output: (num_tokens, total_out_features) or None
        lora_indices: (num_tokens,)
        slice_offset: column offset of this slice inside the full output
        slice_size: width of this slice (must equal out_features)
        output_shape: (num_tokens, total_out_features)
    Returns:
        (num_tokens, total_out_features)
    """
    if len(lora_weights.shape) == 4:
        lora_weights = jnp.squeeze(lora_weights, axis=1)

    outputs = bgmv_jax(inputs, lora_weights, lora_indices)

    # Zero-pad the slice into the full output width.
    pad_left = slice_offset
    pad_right = output_shape[-1] - (slice_offset + slice_size)
    outputs = jnp.pad(outputs, ((0, 0), (pad_left, pad_right)), mode="constant", constant_values=0)

    if base_output is not None:
        return base_output + outputs
    else:
        return outputs
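
# Example (illustrative): with output_shape = (s, output_dim_q + 2 * output_dim_kv),
# calling bgmv_expand_slice three times with slice_offset = 0, output_dim_q and
# output_dim_q + output_dim_kv accumulates the q, k and v LoRA deltas into
# disjoint column ranges of the same output tensor.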

def bgmv_jax(
    inputs,  # (s, input_dim)
    loras,  # (num_lora, c * r, input_dim)
    idxs,  # (num_tokens,)
):
    """
    Batched grouped matrix-vector multiplication.
    For each token, select the corresponding LoRA and apply matrix multiplication.
    """
    # one_hot(idxs) gathers each token's LoRA weight; the einsum then computes a
    # per-token matrix-vector product: out[t] = loras[idxs[t]] @ inputs[t].
    return jnp.einsum(
        "td,tX,Xld->tl",
        inputs,
        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
        loras,
    )
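

# Minimal usage sketch (illustrative only; the shapes, values and __main__ guard
# below are assumptions for demonstration, not part of the backend API). It shows
# how bgmv_shrink and bgmv_expand_slice compose into a per-token LoRA update,
# i.e. loras_b[idx] @ (loras_a[idx] @ x[t]) for each token t.
if __name__ == "__main__":
    s, input_dim, r, output_dim, num_lora = 4, 8, 2, 6, 3
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    x = jax.random.normal(k1, (s, input_dim))
    lora_a = jax.random.normal(k2, (num_lora, r, input_dim))
    lora_b = jax.random.normal(k3, (num_lora, output_dim, r))
    token_lora_indices = jnp.array([0, 0, 2, 1], dtype=jnp.int32)

    hidden = bgmv_shrink(x, lora_a, token_lora_indices)  # (s, r)
    delta = bgmv_expand_slice(
        hidden, lora_b, None, token_lora_indices, 0, output_dim, (s, output_dim)
    )  # (s, output_dim)
    print(delta.shape)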