
Commit 9e6c476

add base_backend
1 parent 3b879f0 commit 9e6c476

File tree

3 files changed: +128, -0 lines changed


python/sgl_jax/srt/layers/linear.py

Lines changed: 1 addition & 0 deletions
@@ -85,4 +85,5 @@ def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
         if bias is not None:
             output = output + bias.value
         output_bias = self.bias if self.skip_bias_add else None
+        jax.lax.select()
         return output, output_bias
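
Note: the bare jax.lax.select() added in this hunk takes no arguments and would raise a TypeError when executed; jax.lax.select expects a boolean predicate and two operands of matching shape. A minimal sketch of the signature, with illustrative values that are not part of the commit:

    import jax
    import jax.numpy as jnp

    # jax.lax.select picks elements from on_true / on_false based on pred
    # (all three arrays share the same shape; values here are illustrative only).
    pred = jnp.array([True, False, True])
    on_true = jnp.ones(3)
    on_false = jnp.zeros(3)
    out = jax.lax.select(pred, on_true, on_false)  # -> [1., 0., 1.]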
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
from typing import Optional, Tuple, Union

import jax

from sgl_jax.srt.lora.utils import LoRABatchInfo
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch


class BaseLoRABackend:
    """Base class for different LoRA backends.

    Each backend has its own implementation of LoRA kernels.

    Args:
        max_loras_per_batch: maximum number of different LoRA weights
            that can be applied in a single forward batch.
        device: the device where the backend runs.
    """

    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch

    def run_lora_a_gemm(
        self, x: jax.Array, weights: jax.Array, *args, **kwargs
    ) -> jax.Array:
        """Run the GEMM of LoRA A modules with the current backend.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
                where r is the LoRA rank and c is a multiplier for stacked modules
                (e.g., c=3 for qkv_proj, c=2 for gate_up_proj);
                usually input_dim is much larger than r

        Returns:
            result with shape (s, c * r)
        """
        pass

    def run_lora_b_gemm(
        self, x: jax.Array, weights: jax.Array, *args, **kwargs
    ) -> jax.Array:
        """Run the GEMM of LoRA B modules with the current backend.

        Args:
            x: input matrix with shape (s, r), where s is the sum of all sequence lengths and r is the LoRA rank
            weights: a set of LoRA weights with shape (num_lora, output_dim, r);
                usually output_dim is much larger than r

        Returns:
            result with shape (s, output_dim)
        """
        pass

    def run_qkv_lora(
        self,
        x: jax.Array,
        qkv_lora_a: jax.Array,
        qkv_lora_b: Union[jax.Array, Tuple[jax.Array]],
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: lora_b module for qkv.
                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r).
                If passed in as a tuple of two tensors, it should contain
                a lora_b module for q, with shape (1, num_lora, output_dim_q, r),
                and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r).

        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        pass

    def run_gate_up_lora(
        self,
        x: jax.Array,
        gate_up_lora_a: jax.Array,
        gate_up_lora_b: Union[jax.Array, Tuple[jax.Array]],
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for gate_up_proj.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
            gate_up_lora_b: lora_b module for gate_up_proj.
                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r).
                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r).

        Returns:
            result with shape (s, 2 * output_dim)
        """
        pass

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        """Prepare the LoRA weights and batch info for the current forward batch.

        This method provides a hook for each backend to conduct its own preparation
        logic for each forward batch.

        Args:
            forward_batch: the ForwardBatch object for the current forward pass
            weight_indices: list of indices of the LoRA weights to be applied in the current batch
            lora_ranks: list of LoRA ranks corresponding to weight_indices
            scalings: list of scaling factors corresponding to weight_indices
            batch_info: optional LoRABatchInfo object; if not provided, the backend should use its
                own internal batch info
        """
        pass
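
For context on how these hooks fit together, below is a minimal sketch of what a concrete backend could look like, assuming a naive einsum-based implementation and a hypothetical per-token token_lora_indices keyword argument. The class name, the kwarg, and the gather strategy are illustrative assumptions, not part of this commit:

    import jax.numpy as jnp

    class NaiveEinsumLoRABackend(BaseLoRABackend):
        """Hypothetical reference backend: plain einsum contractions, no fused kernels."""

        def run_lora_a_gemm(self, x, weights, token_lora_indices=None, **kwargs):
            # x: (s, input_dim); weights: (num_lora, c * r, input_dim)
            # Gather each token's adapter slice, then contract over input_dim -> (s, c * r).
            w = weights[token_lora_indices]  # (s, c * r, input_dim)
            return jnp.einsum("sd,skd->sk", x, w)

        def run_lora_b_gemm(self, x, weights, token_lora_indices=None, **kwargs):
            # x: (s, r); weights: (num_lora, output_dim, r)
            w = weights[token_lora_indices]  # (s, output_dim, r)
            return jnp.einsum("sr,sor->so", x, w)

A production backend would typically replace the gather-plus-einsum with a fused grouped-GEMM kernel, but the shape contract stays the one documented in the base class.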

python/sgl_jax/srt/lora/utils.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from dataclasses import dataclass
from enum import Enum


@dataclass
class LoRABatchInfo:
    # Batch size
    bs: int


class LoRAType(Enum):
    LORA_A = 0
    LORA_B = 1
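
A brief illustration of how a backend might use these helpers; the weight-stack dictionary and the shapes below are assumptions for illustration, not part of the commit:

    import jax.numpy as jnp

    from sgl_jax.srt.lora.utils import LoRABatchInfo, LoRAType

    # Illustrative: record per-batch metadata and key the A/B weight stacks by LoRAType.
    batch_info = LoRABatchInfo(bs=8)                # placeholder batch size
    weight_stacks = {
        LoRAType.LORA_A: jnp.zeros((4, 16, 1024)),  # (num_lora, c * r, input_dim)
        LoRAType.LORA_B: jnp.zeros((4, 1024, 16)),  # (num_lora, output_dim, r)
    }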
