from typing import Optional, Tuple, Union

import jax

from sgl_jax.srt.lora.utils import LoRABatchInfo
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch


class BaseLoRABackend:
10+ """Base class for different Lora backends.
11+ Each backend has its own implementation of Lora kernels.
12+
13+ Args:
14+ max_loras_per_batch: maximum number of different lora weights
15+ that can be applied in a single forward batch.
16+ device: the device where the backend runs.
17+ """
18+
    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch

    def run_lora_a_gemm(
        self, x: jax.Array, weights: jax.Array, *args, **kwargs
    ) -> jax.Array:
        """Run the GEMM of LoRA A modules with the current backend.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
                where r is the LoRA rank and c is a multiplier for stacked modules
                (e.g., c=3 for qkv_proj, c=2 for gate_up_proj);
                usually input_dim is much larger than r
        Returns:
            result with shape (s, c * r)
        """
        pass
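        # Illustrative sketch only (not part of the base API): a simple backend
        # could realize this as a gathered einsum, assuming a hypothetical
        # `token_weight_indices` array of shape (s,) mapping each token to its
        # adapter slot (and with jnp = jax.numpy):
        #   lora_a = weights[token_weight_indices]        # (s, c * r, input_dim)
        #   return jnp.einsum("si,sri->sr", x, lora_a)    # (s, c * r)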

    def run_lora_b_gemm(
        self, x: jax.Array, weights: jax.Array, *args, **kwargs
    ) -> jax.Array:
        """Run the GEMM of LoRA B modules with the current backend.

        Args:
            x: input matrix with shape (s, r), where s is the sum of all sequence lengths and r is the LoRA rank
            weights: a set of LoRA weights with shape (num_lora, output_dim, r);
                usually output_dim is much larger than r
        Returns:
            result with shape (s, output_dim)
        """
        pass
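        # Illustrative sketch only, mirroring the A-side sketch and reusing the
        # hypothetical `token_weight_indices`; per-adapter scaling factors would
        # typically be folded in here as well:
        #   lora_b = weights[token_weight_indices]        # (s, output_dim, r)
        #   return jnp.einsum("sr,sor->so", x, lora_b)    # (s, output_dim)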

    def run_qkv_lora(
        self,
        x: jax.Array,
        qkv_lora_a: jax.Array,
        qkv_lora_b: Union[jax.Array, Tuple[jax.Array]],
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: lora_b module for qkv.
                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r).
                If passed in as a tuple of two tensors, it should contain
                a lora_b module for q, with shape (1, num_lora, output_dim_q, r),
                and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r).
        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        pass
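        # Illustrative sketch only, assuming the stacked-tensor form of qkv_lora_b;
        # `r`, `output_dim_q`, and `output_dim_kv` are illustrative names, not
        # parameters of this method:
        #   a_out = self.run_lora_a_gemm(x, qkv_lora_a)   # (s, 3 * r)
        #   q = self.run_lora_b_gemm(a_out[:, :r], qkv_lora_b[:, :output_dim_q])
        #   k = self.run_lora_b_gemm(
        #       a_out[:, r : 2 * r],
        #       qkv_lora_b[:, output_dim_q : output_dim_q + output_dim_kv],
        #   )
        #   v = self.run_lora_b_gemm(a_out[:, 2 * r :], qkv_lora_b[:, -output_dim_kv:])
        #   return jnp.concatenate([q, k, v], axis=-1)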

    def run_gate_up_lora(
        self,
        x: jax.Array,
        gate_up_lora_a: jax.Array,
        gate_up_lora_b: Union[jax.Array, Tuple[jax.Array]],
        *args,
        **kwargs,
    ) -> jax.Array:
        """Run the LoRA pass for gate_up_proj.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
            gate_up_lora_b: lora_b module for gate_up_proj.
                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r).
                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r).
        Returns:
            result with shape (s, 2 * output_dim)
        """
        pass
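        # Illustrative sketch only, analogous to the QKV sketch but with two
        # stacked slices (c=2); `r` and `output_dim` are illustrative names:
        #   a_out = self.run_lora_a_gemm(x, gate_up_lora_a)   # (s, 2 * r)
        #   gate = self.run_lora_b_gemm(a_out[:, :r], gate_up_lora_b[:, :output_dim])
        #   up = self.run_lora_b_gemm(a_out[:, r:], gate_up_lora_b[:, output_dim:])
        #   return jnp.concatenate([gate, up], axis=-1)       # (s, 2 * output_dim)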

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        """Prepare the LoRA weights and batch info for the current forward batch.

        This method provides a hook for each backend to run its own preparation
        logic for each forward batch.

        Args:
            forward_batch: the ForwardBatch object for the current forward pass
            weight_indices: list of indices of the LoRA weights to be applied in the current batch
            lora_ranks: list of LoRA ranks corresponding to weight_indices
            scalings: list of scaling factors corresponding to weight_indices
            batch_info: optional LoRABatchInfo object; if not provided, the backend
                should use its own internal batch info
        """
        pass
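        # Illustrative sketch only: a backend could expand the per-sequence adapter
        # indices into the per-token `token_weight_indices` used by the GEMM
        # sketches above; `seq_lens` stands for a per-sequence token-count array
        # taken from forward_batch (the exact attribute name is backend-specific):
        #   token_weight_indices = jnp.repeat(jnp.asarray(weight_indices), seq_lens)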