Commit 23d8fce

[Feat] add bgmv backend for LoRA (#381)
1 parent d316da6 commit 23d8fce

File tree

5 files changed, +482 -1 lines changed

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
import jax

from sgl_jax.srt.lora.utils import LoRABatchInfo
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch


class BaseLoRABackend:
    """Base class for different LoRA backends.

    Each backend has its own implementation of LoRA kernels.

    Args:
        max_loras_per_batch: maximum number of different LoRA weights
            that can be applied in a single forward batch.
    """

    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch

    def run_lora_a_gemm(self, x: jax.Array, weights: jax.Array, *args, **kwargs) -> jax.Array:
        """Run the GEMM of LoRA A modules with the current backend.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            weights: a set of LoRA weights with shape (num_lora, r, input_dim), where r is the LoRA rank;
                usually input_dim is much larger than r

        Returns:
            result with shape (s, r)
        """
        pass

    def run_lora_b_gemm(self, x: jax.Array, weights: jax.Array, *args, **kwargs) -> jax.Array:
        """Run the GEMM of LoRA B modules with the current backend.

        Args:
            x: input matrix with shape (s, r), where s is the sum of all sequence lengths and r is the LoRA rank
            weights: a set of LoRA weights with shape (num_lora, output_dim, r);
                usually output_dim is much larger than r

        Returns:
            result with shape (s, output_dim)
        """
        pass

    def run_qkv_lora(
        self,
        x: jax.Array,
        qkv_lora_a: jax.Array,
        qkv_lora_b: jax.Array | tuple[jax.Array, ...],
        output_slices: tuple,
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            qkv_lora_a: LoRA A module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: LoRA B module for qkv.
                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r).
                If passed in as a tuple of two tensors, it should contain
                a LoRA B module for q, with shape (1, num_lora, output_dim_q, r),
                and a combined LoRA B module for kv, with shape (2, num_lora, output_dim_kv, r).
            output_slices: a fixed tuple with three items, (output_dim_q, output_dim_kv, output_dim_kv)

        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        pass

    def run_gate_up_lora(
        self,
        x: jax.Array,
        gate_up_lora_a: jax.Array,
        gate_up_lora_b: jax.Array | tuple[jax.Array, ...],
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for gate_up_proj.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            gate_up_lora_a: LoRA A module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
            gate_up_lora_b: LoRA B module for gate_up_proj.
                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r).
                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r).

        Returns:
            result with shape (s, 2 * output_dim)
        """
        pass

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: LoRABatchInfo | None = None,
    ):
        """Prepare the LoRA weights and batch info for the current forward batch.

        This method provides a hook for each backend to conduct its own preparation
        logic for each forward batch.

        Args:
            forward_batch: the ForwardBatch object for the current forward pass
            weight_indices: list of indices of LoRA weights to be applied in the current batch
            lora_ranks: list of LoRA ranks corresponding to weight_indices
            scalings: list of scaling factors corresponding to weight_indices
            batch_info: optional LoRABatchInfo object; if not provided, the backend should use
                its own internal batch info
        """
        pass
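
This hunk only shows the backend interface; the bgmv kernel added by the commit lives in the other changed files, which are not reproduced here. As a rough illustration of the contract run_lora_a_gemm and run_lora_b_gemm describe, a bgmv-style ("batched gather matrix-vector") implementation gathers each token's adapter weights by index and contracts them per token. The sketch below is not code from this commit: the function names and the per-token token_lora_indices array are assumptions for illustration.

import jax
import jax.numpy as jnp


def bgmv_lora_a(x: jax.Array, weights: jax.Array, token_lora_indices: jax.Array) -> jax.Array:
    # x: (s, input_dim); weights: (num_lora, r, input_dim);
    # token_lora_indices: (s,) adapter id per token (hypothetical layout).
    w = jnp.take(weights, token_lora_indices, axis=0)  # gather: (s, r, input_dim)
    return jnp.einsum("si,sri->sr", x, w)              # per-token matvec -> (s, r)


def bgmv_lora_b(x: jax.Array, weights: jax.Array, token_lora_indices: jax.Array) -> jax.Array:
    # x: (s, r); weights: (num_lora, output_dim, r).
    w = jnp.take(weights, token_lora_indices, axis=0)  # gather: (s, output_dim, r)
    return jnp.einsum("sr,sor->so", x, w)              # per-token matvec -> (s, output_dim)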

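run_qkv_lora and run_gate_up_lora differ from the plain GEMMs only in how the stacked A output is split into rank-r slices and matched against the corresponding B slices. A reference sketch under the same assumptions as above, using the tensor form of qkv_lora_b; gate_up_proj is the two-slice analogue with equal output_dim slices.

def qkv_lora_reference(x, qkv_lora_a, qkv_lora_b, output_slices, token_lora_indices):
    # x: (s, input_dim); qkv_lora_a: (num_lora, 3 * r, input_dim);
    # qkv_lora_b (tensor form): (num_lora, output_dim_q + 2 * output_dim_kv, r).
    a = jnp.take(qkv_lora_a, token_lora_indices, axis=0)  # (s, 3r, input_dim)
    xa = jnp.einsum("si,sri->sr", x, a)                   # (s, 3r)
    r = xa.shape[-1] // 3
    b = jnp.take(qkv_lora_b, token_lora_indices, axis=0)  # (s, sum(output_slices), r)
    outs, start = [], 0
    for i, out_dim in enumerate(output_slices):           # (output_dim_q, output_dim_kv, output_dim_kv)
        xa_i = xa[:, i * r : (i + 1) * r]                 # (s, r) slice for q, k, or v
        b_i = b[:, start : start + out_dim, :]            # (s, out_dim, r)
        outs.append(jnp.einsum("sr,sor->so", xa_i, b_i))
        start += out_dim
    return jnp.concatenate(outs, axis=-1)                 # (s, output_dim_q + 2 * output_dim_kv)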
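
prepare_lora_batch is where a backend can precompute whatever its kernels need for the batch, for example expanding the per-request weight_indices into the per-token index array used by the gathers above. The sketch below assumes per-request sequence lengths are available; seq_lens is a stand-in here, not a confirmed ForwardBatch field, and the contents of LoRABatchInfo are likewise not shown in this hunk.

def expand_batch_indices(weight_indices, seq_lens):
    # weight_indices: (bs,) adapter id per request; seq_lens: (bs,) tokens per request.
    # Returns (s,) adapter id per token, where s = sum(seq_lens).
    return jnp.repeat(jnp.asarray(weight_indices), jnp.asarray(seq_lens))

The per-request scalings list can be expanded the same way and folded into the B-side output.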