Commit a7ae79f

add part codes for bgmv backend
1 parent 19d7b29 commit a7ae79f

File tree

3 files changed: +253 -4 lines changed


python/sgl_jax/srt/lora/backend/base_backend.py

Lines changed: 6 additions & 4 deletions
@@ -22,11 +22,10 @@ def run_lora_a_gemm(self, x: jax.Array, weights: jax.Array, *args, **kwargs) ->
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
-                here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+            weights: a set of lora weights with shape (num_lora, r, input_dim), where r is the lora rank;
                 usually input_dim is much larger than r
         Returns:
-            result with shape (s, c * r)
+            result with shape (s, r)
         """
         pass
 
@@ -47,6 +46,7 @@ def run_qkv_lora(
         x: jax.Array,
         qkv_lora_a: jax.Array,
         qkv_lora_b: jax.Array | tuple[jax.Array],
+        output_slices: tuple,
         *args,
         **kwargs,
     ) -> jax.Array:
@@ -56,10 +56,11 @@ def run_qkv_lora(
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
             qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
             qkv_lora_b: lora_b module for qkv.
-                If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
+                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
                 If passed in as a tuple of two tensors, it should contain:
                 a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
                 and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+            output_slices: a fixed tuple with three items, (output_dim_q, output_dim_kv, output_dim_kv)
         Returns:
             result with shape (s, output_dim_q + 2 * output_dim_kv)
         """
@@ -81,6 +82,7 @@ def run_gate_up_lora(
             gate_up_lora_b: lora_b module for gate_up_proj.
                 If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
                 If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+            output_slices: a fixed tuple with two items, (output_dim, output_dim)
         Returns:
             result with shape (s, 2 * output_dim)
         """
Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
+from typing import Optional
+
+from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend
+from sgl_jax.srt.lora.utils import LoRABatchInfo
+from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+MIN_CHUNK_SIZE = 16
+
+
+class BgmvLoRABackend(BaseLoRABackend):
+    """
+    Bgmv LoRA backend using batched grouped matrix-vector multiplication.
+    """
+
+    name = "bgmv"
+
+    def __init__(
+        self,
+        max_loras_per_batch: int,
+        max_lora_rank: int,
+    ):
+        super().__init__(max_loras_per_batch)
+        self.max_lora_rank = max_lora_rank
+
+    def run_lora_a_gemm(
+        self,
+        x: jax.Array,  # (s, input_dim)
+        weights: jax.Array,  # (num_lora, r, input_dim)
+        *args,
+        **kwargs,
+    ) -> jax.Array:
+        # Shrink every token into its adapter's low-rank space with a single bgmv call.
+        return bgmv_shrink(x, weights, self.batch_info.token_lora_indices, self.batch_info.scalings)
+
+    def run_lora_b_gemm(
+        self,
+        x: jax.Array,  # (s, r)
+        weights: jax.Array,  # (num_lora, output_dim, r)
+        base_output: jax.Array | None = None,
+        *args,
+        **kwargs,
+    ) -> jax.Array:
+        s = x.shape[0]
+        output_dim = weights.shape[1]
+
+        # A single output module, so the slice covers the whole output dimension.
+        return bgmv_expand_slice(
+            x,
+            weights,
+            base_output,
+            self.batch_info.token_lora_indices,
+            0,
+            output_dim,
+            (s, output_dim),
+        )
+
+    def run_qkv_lora(
+        self,
+        x: jax.Array,  # (s, input_dim)
+        qkv_lora_a: jax.Array,  # (num_lora, 3 * r, input_dim)
+        qkv_lora_b: jax.Array | tuple[jax.Array],  # (num_lora, output_dim_q + 2 * output_dim_kv, r) or ((1, num_lora, output_dim_q, r), (2, num_lora, output_dim_kv, r))
+        output_slices: tuple,  # (output_dim_q, output_dim_kv, output_dim_kv)
+        base_output: jax.Array | None = None,
+        *args,
+        **kwargs,
+    ) -> jax.Array:
+        """Run the lora pass for the QKV layer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
+            qkv_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
+                If passed in as a tuple of two tensors, it should contain:
+                a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
+                and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+            output_slices: a fixed tuple with three items, (output_dim_q, output_dim_kv, output_dim_kv)
+        Returns:
+            result with shape (s, output_dim_q + 2 * output_dim_kv)
+        """
+        pass
+
+    def run_gate_up_lora(
+        self,
+        x: jax.Array,
+        gate_up_lora_a: jax.Array,
+        gate_up_lora_b: jax.Array | tuple[jax.Array],
+        base_output: jax.Array | None = None,
+        *args,
+        **kwargs,
+    ) -> jax.Array:
+        """Run the lora pass for gate_up_proj.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
+            gate_up_lora_b: lora_b module for gate_up_proj.
+                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
+                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+            output_slices: a fixed tuple with two items, (output_dim, output_dim)
+        Returns:
+            result with shape (s, 2 * output_dim)
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],  # (bs,); use -1 for requests without a LoRA adapter
+        lora_ranks: list[int],  # (max_loras_per_batch,)
+        scalings: list[float],  # (max_loras_per_batch,)
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Gather per-request rank and scaling; requests without an adapter get 0 / 0.0.
+        lora_ranks_bs = []
+        scalings_bs = []
+        for indice in weight_indices:
+            if indice != -1:
+                lora_ranks_bs.append(lora_ranks[indice])
+                scalings_bs.append(scalings[indice])
+            else:
+                lora_ranks_bs.append(0)
+                scalings_bs.append(0.0)
+
+        assert len(forward_batch.seq_lens) == len(weight_indices)
+        assert len(forward_batch.seq_lens) == len(lora_ranks_bs)
+        assert len(forward_batch.seq_lens) == len(scalings_bs)
+
+        target_len = forward_batch.input_ids.shape[0]
+
+        if forward_batch.forward_mode == ForwardMode.EXTEND:
+            # Expand the per-request values to per-token values.
+            scalings_cpu = np.repeat(np.array(scalings_bs, dtype=np.float32), forward_batch.seq_lens)
+            token_lora_indices_cpu = np.repeat(np.array(weight_indices, dtype=np.int32), forward_batch.seq_lens)
+            lora_ranks_cpu = np.repeat(np.array(lora_ranks_bs, dtype=np.int32), forward_batch.seq_lens)
+
+            # Pad up to the (possibly padded) number of input tokens.
+            num_to_pad = target_len - int(jnp.sum(forward_batch.seq_lens))
+
+            if num_to_pad > 0:
+                padded_scalings_cpu = np.pad(scalings_cpu, [0, num_to_pad], mode="constant", constant_values=0.0)
+                padded_token_lora_indices_cpu = np.pad(token_lora_indices_cpu, [0, num_to_pad], mode="constant", constant_values=-1)
+                padded_lora_ranks_cpu = np.pad(lora_ranks_cpu, [0, num_to_pad], mode="constant", constant_values=0)
+            else:
+                padded_scalings_cpu = scalings_cpu
+                padded_token_lora_indices_cpu = token_lora_indices_cpu
+                padded_lora_ranks_cpu = lora_ranks_cpu
+        elif forward_batch.forward_mode == ForwardMode.DECODE:
+            # One token per request in decode mode, so no expansion is needed.
+            padded_scalings_cpu = np.array(scalings_bs, dtype=np.float32)
+            padded_token_lora_indices_cpu = np.array(weight_indices, dtype=np.int32)
+            padded_lora_ranks_cpu = np.array(lora_ranks_bs, dtype=np.int32)
+
+        batch_info = LoRABatchInfo(
+            bs=forward_batch.batch_size,
+            scalings=jnp.array(padded_scalings_cpu, dtype=jnp.float32),
+            token_lora_indices=jnp.array(padded_token_lora_indices_cpu, dtype=jnp.int32),
+            lora_ranks=jnp.array(padded_lora_ranks_cpu, dtype=jnp.int32),
+        )
+
+        self.batch_info = batch_info
+
+
+def bgmv_shrink(
+    inputs,
+    lora_weights,
+    lora_indices,
+    scaling: float | jax.Array = 1.0,
+):
+    """
+    Shrink operation: maps input to low-rank space.
+
+    Args:
+        inputs: (s, input_dim)
+        lora_weights: (num_lora, c * r, input_dim), c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+        lora_indices: (num_tokens,)
+        scaling: a scalar, or a per-token array with shape (num_tokens,)
+    Returns:
+        result with shape (s, c * r)
+    """
+    # Reshape so that both a scalar and a per-token (num_tokens,) scaling broadcast over the rank dimension.
+    scaling = jnp.reshape(jnp.asarray(scaling, dtype=inputs.dtype), (-1, 1))
+    return scaling * bgmv_jax(inputs, lora_weights, lora_indices)
+
+
+def bgmv_expand_slice(
+    inputs,  # (num_tokens, lora_rank)
+    lora_weights,  # (num_loras, 1, out_features, lora_rank) or (num_loras, out_features, lora_rank)
+    base_output,  # (num_tokens, total_out_features) or None
+    lora_indices,  # (num_tokens,)
+    slice_offset: int,
+    slice_size: int,
+    output_shape: tuple,
+):
+    """
+    Expand operation: maps from low-rank space to output space and places the
+    result in the [slice_offset, slice_offset + slice_size) columns of the output.
+
+    Args:
+        inputs: (num_tokens, lora_rank)
+        lora_weights: (num_loras, 1, out_features, lora_rank) or (num_loras, out_features, lora_rank)
+        base_output: (num_tokens, total_out_features), added to the result if not None
+        lora_indices: (num_tokens,)
+        slice_offset: start column of this module's slice in the full output
+        slice_size: width of this module's slice
+        output_shape: shape of the full output, (num_tokens, total_out_features)
+    """
+    if len(lora_weights.shape) == 4:
+        lora_weights = jnp.squeeze(lora_weights, axis=1)
+
+    outputs = bgmv_jax(inputs, lora_weights, lora_indices)
+
+    # Zero-pad the slice into its position in the full output.
+    pad_left = slice_offset
+    pad_right = output_shape[-1] - (slice_offset + slice_size)
+    outputs = jnp.pad(outputs, ((0, 0), (pad_left, pad_right)), mode="constant", constant_values=0)
+
+    if base_output is not None:
+        return base_output + outputs
+    else:
+        return outputs
+
+
+def bgmv_jax(
+    inputs,  # (s, input_dim)
+    loras,  # (num_lora, c * r, input_dim)
+    idxs,  # (num_tokens,)
+):
+    """
+    Batched grouped matrix-vector multiplication.
+    For each token, select the corresponding LoRA weight and apply the matrix multiplication.
+    """
+    return jnp.einsum(
+        "td,tX,Xld->tl",
+        inputs,
+        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+        loras,
+    )
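
A small, self-contained sketch of the shrink-then-expand path implemented above. The toy shapes and the per-token reference loop are illustrative assumptions; only the one-hot einsum mirrors bgmv_jax from the diff.

import jax
import jax.numpy as jnp

def bgmv(inputs, loras, idxs):
    # Same contraction as bgmv_jax in the diff: pick each token's adapter
    # with a one-hot matrix and contract over the input dimension.
    one_hot = jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype)
    return jnp.einsum("td,tX,Xld->tl", inputs, one_hot, loras)

s, num_lora, r, input_dim, output_dim = 5, 3, 4, 16, 8
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k0, (s, input_dim))
lora_a = jax.random.normal(k1, (num_lora, r, input_dim))   # shrink weights
lora_b = jax.random.normal(k2, (num_lora, output_dim, r))  # expand weights
idxs = jnp.array([0, 2, 1, 0, 2], dtype=jnp.int32)

shrunk = bgmv(x, lora_a, idxs)         # (s, r)
expanded = bgmv(shrunk, lora_b, idxs)  # (s, output_dim)

# Zero-pad into a slice of a wider output, as bgmv_expand_slice does for stacked modules.
slice_offset, total_out = 8, 24
padded = jnp.pad(expanded, ((0, 0), (slice_offset, total_out - slice_offset - output_dim)))
assert padded.shape == (s, total_out)

# Per-token reference check against a plain loop.
ref = jnp.stack([lora_b[idxs[t]] @ (lora_a[idxs[t]] @ x[t]) for t in range(s)])
assert jnp.allclose(expanded, ref, atol=1e-4)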

python/sgl_jax/srt/lora/utils.py

Lines changed: 11 additions & 0 deletions
@@ -1,12 +1,23 @@
 from dataclasses import dataclass
 from enum import Enum
 
+import jax
+
 
 @dataclass
 class LoRABatchInfo:
     # Batch size
     bs: int
 
+    # Per-token LoRA scaling factor, with shape (num_tokens,)
+    scalings: jax.Array
+
+    # Index of the LoRA adapter applied to each token, with shape (num_tokens,); -1 means no adapter
+    token_lora_indices: jax.Array
+
+    # Rank of the LoRA adapter applied to each token, with shape (num_tokens,)
+    lora_ranks: jax.Array
+
 
 class LoRAType(Enum):
     LORA_A = 0
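
A minimal example of filling the new LoRABatchInfo fields for a decode-style batch (one token per request); the concrete values below are made up for illustration.

import jax.numpy as jnp

from sgl_jax.srt.lora.utils import LoRABatchInfo

# Three requests; the last one runs without a LoRA adapter (index -1, rank 0, scaling 0.0).
batch_info = LoRABatchInfo(
    bs=3,
    scalings=jnp.array([1.0, 0.5, 0.0], dtype=jnp.float32),
    token_lora_indices=jnp.array([0, 1, -1], dtype=jnp.int32),
    lora_ranks=jnp.array([16, 8, 0], dtype=jnp.int32),
)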
