| 1 | +import jax |
| 2 | + |
| 3 | +from sgl_jax.srt.lora.utils import LoRABatchInfo |
| 4 | +from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch |
| 5 | + |
| 6 | + |
| 7 | +class BaseLoRABackend: |
| 8 | + """Base class for different Lora backends. |
| 9 | + Each backend has its own implementation of Lora kernels. |
| 10 | +
|
| 11 | + Args: |
| 12 | + max_loras_per_batch: maximum number of different lora weights |
| 13 | + that can be applied in a single forward batch. |
| 14 | + device: the device where the backend runs. |
| 15 | + """ |
| 16 | + |
| 17 | + def __init__(self, max_loras_per_batch: int): |
| 18 | + self.max_loras_per_batch = max_loras_per_batch |
| 19 | + |
| 20 | + def run_lora_a_gemm(self, x: jax.Array, weights: jax.Array, *args, **kwargs) -> jax.Array: |
| 21 | + """Run gemm of lora a modules with current backend. |
| 22 | +
|
| 23 | + Args: |
| 24 | + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths |
| 25 | + weights: a set of lora weights with shape (num_lora, r, input_dim), r is lora rank, |
| 26 | + usually input_dim is much larger than r |
| 27 | + Returns: |
| 28 | + result with shape (s, r) |
| 29 | + """ |
| 30 | + pass |
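For reference, a naive JAX formulation of this A-projection gathers each token's adapter matrix and contracts over input_dim. This is only an illustrative sketch, not the backend's actual kernel; `token_lora_indices`, a per-token adapter index of shape (s,), is a hypothetical array that a concrete backend would build in prepare_lora_batch.

import jax
import jax.numpy as jnp

def naive_lora_a_gemm(x: jax.Array, weights: jax.Array, token_lora_indices: jax.Array) -> jax.Array:
    # x: (s, input_dim), weights: (num_lora, r, input_dim) -> result: (s, r)
    w = weights[token_lora_indices]        # gather one (r, input_dim) matrix per token
    return jnp.einsum("si,sri->sr", x, w)  # contract over input_dim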
| 31 | + |
| 32 | + def run_lora_b_gemm(self, x: jax.Array, weights: jax.Array, *args, **kwargs) -> jax.Array: |
| 33 | + """Run gemm of lora b modules with current backend. |
| 34 | +
|
| 35 | + Args: |
| 36 | + x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank |
| 37 | + weights: a set of lora weights with shape (num_lora, output_dim, r) |
| 38 | + usually output_dim is much larger than r |
| 39 | + Returns: |
| 40 | + result with shape (s, output_dim) |
| 41 | + """ |
| 42 | + pass |
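The B-projection admits the same naive sketch; combined with the A GEMM above, the LoRA delta for a plain linear layer is scaling * B(A(x)), added to the base projection output. As before, `token_lora_indices` and a per-adapter `scalings` array are assumptions about what prepare_lora_batch would provide, not part of this interface.

import jax
import jax.numpy as jnp

def naive_lora_b_gemm(
    x: jax.Array, weights: jax.Array, token_lora_indices: jax.Array, scalings: jax.Array
) -> jax.Array:
    # x: (s, r), weights: (num_lora, output_dim, r) -> result: (s, output_dim)
    w = weights[token_lora_indices]                      # (s, output_dim, r)
    out = jnp.einsum("sr,sor->so", x, w)                 # contract over the LoRA rank
    return out * scalings[token_lora_indices][:, None]   # apply per-adapter scaling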
| 43 | + |
| 44 | + def run_qkv_lora( |
| 45 | + self, |
| 46 | + x: jax.Array, |
| 47 | + qkv_lora_a: jax.Array, |
| 48 | + qkv_lora_b: jax.Array | tuple[jax.Array, jax.Array], |
| 49 | + output_slices: tuple, |
| 50 | + *args, |
| 51 | + **kwargs, |
| 52 | + ) -> jax.Array: |
| 53 | + """Run the lora pass for QKV Layer. |
| 54 | +
|
| 55 | + Args: |
| 56 | + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths |
| 57 | + qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim) |
| 58 | + qkv_lora_b: lora_b module for qkv. |
| 59 | + If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r) |
| 60 | + If passed in as a tuple of two tensors, it should contain: |
| 61 | + a lora_b module for q, with shape (1, num_lora, output_dim_q, r) |
| 62 | + and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r) |
| 63 | + output_slices: a fixed tuple which has three item, (output_dim_q, output_dim_kv, output_dim_kv) |
| 64 | + Returns: |
| 65 | + result with shape (s, output_dim_q + 2 * output_dim_kv) |
| 66 | + """ |
| 67 | + pass |
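As a sketch of the math behind the fused QKV path (assuming the single-tensor form of qkv_lora_b, that the A weights stack the q/k/v rank blocks along the 3 * r dimension, and the same hypothetical token_lora_indices as above), one can slice the stacked A output into rank blocks, apply the matching rows of B, and concatenate; a real backend would fuse this into a single kernel.

import jax
import jax.numpy as jnp

def naive_qkv_lora(
    x: jax.Array,
    qkv_lora_a: jax.Array,    # (num_lora, 3 * r, input_dim)
    qkv_lora_b: jax.Array,    # (num_lora, output_dim_q + 2 * output_dim_kv, r)
    output_slices: tuple,     # (output_dim_q, output_dim_kv, output_dim_kv)
    token_lora_indices: jax.Array,
) -> jax.Array:
    r = qkv_lora_a.shape[1] // 3
    a = qkv_lora_a[token_lora_indices]    # (s, 3 * r, input_dim)
    b = qkv_lora_b[token_lora_indices]    # (s, sum(output_slices), r)
    xa = jnp.einsum("si,sri->sr", x, a)   # (s, 3 * r): stacked q/k/v rank projections
    outs, offset = [], 0
    for i, out_dim in enumerate(output_slices):
        xa_i = xa[:, i * r : (i + 1) * r]         # rank block for q, k, or v
        b_i = b[:, offset : offset + out_dim, :]  # matching rows of lora_b
        outs.append(jnp.einsum("sr,sor->so", xa_i, b_i))
        offset += out_dim
    return jnp.concatenate(outs, axis=-1)  # (s, output_dim_q + 2 * output_dim_kv)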
| 68 | + |
| 69 | + def run_gate_up_lora( |
| 70 | + self, |
| 71 | + x: jax.Array, |
| 72 | + gate_up_lora_a: jax.Array, |
| 73 | + gate_up_lora_b: jax.Array | tuple[jax.Array, jax.Array], |
| 74 | + *args, |
| 75 | + **kwargs, |
| 76 | + ) -> jax.Array: |
| 77 | + """Run the lora pass for gate_up_proj. |
| 78 | +
|
| 79 | + Args: |
| 80 | + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths |
| 81 | + gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim) |
| 82 | + gate_up_lora_b: lora_b module for qkv. |
| 83 | + If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r) |
| 84 | + If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r) |
| 85 | + output_slices: a fixed tuple which has three item, (output_dim_q, output_dim_kv, output_dim_kv) |
| 86 | + Returns: |
| 87 | + result with shape (s, 2 * output_dim) |
| 88 | + """ |
| 89 | + pass |
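The gate_up path follows the same pattern with two equal output slices. A sketch under the same assumptions (single-tensor gate_up_lora_b, gate and up rank blocks stacked along the 2 * r dimension, hypothetical token_lora_indices):

import jax
import jax.numpy as jnp

def naive_gate_up_lora(
    x: jax.Array,
    gate_up_lora_a: jax.Array,    # (num_lora, 2 * r, input_dim)
    gate_up_lora_b: jax.Array,    # (num_lora, 2 * output_dim, r)
    token_lora_indices: jax.Array,
) -> jax.Array:
    r = gate_up_lora_a.shape[1] // 2
    output_dim = gate_up_lora_b.shape[1] // 2
    a = gate_up_lora_a[token_lora_indices]   # (s, 2 * r, input_dim)
    b = gate_up_lora_b[token_lora_indices]   # (s, 2 * output_dim, r)
    xa = jnp.einsum("si,sri->sr", x, a)      # (s, 2 * r): stacked gate/up rank projections
    outs = []
    for i in range(2):                       # slice 0 = gate_proj, slice 1 = up_proj
        xa_i = xa[:, i * r : (i + 1) * r]
        b_i = b[:, i * output_dim : (i + 1) * output_dim, :]
        outs.append(jnp.einsum("sr,sor->so", xa_i, b_i))
    return jnp.concatenate(outs, axis=-1)    # (s, 2 * output_dim)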
| 90 | + |
| 91 | + def prepare_lora_batch( |
| 92 | + self, |
| 93 | + forward_batch: ForwardBatch, |
| 94 | + weight_indices: list[int], |
| 95 | + lora_ranks: list[int], |
| 96 | + scalings: list[float], |
| 97 | + batch_info: LoRABatchInfo | None = None, |
| 98 | + ): |
| 99 | + """Prepare the lora weights and batch info for current forward batch. |
| 100 | +
|
| 101 | + This method provides a hook for each backend to conduct its own preparation |
| 102 | + logic for each forward batch. |
| 103 | +
|
| 104 | + Args: |
| 105 | + forward_batch: the ForwardBatch object for current forward pass |
| 106 | + weight_indices: list of indices of lora weights to be applied for current batch |
| 107 | + lora_ranks: list of lora ranks corresponding to weight_indices |
| 108 | + scalings: list of scaling factors corresponding to weight_indices |
| 109 | + batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own |
| 110 | + internal batch info |
| 111 | + """ |
| 112 | + pass |
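To show how these hooks are meant to fit together, here is a hypothetical minimal subclass and call pattern, assuming the BaseLoRABackend class above is in scope. The cached attribute names and the call sequence are illustrative assumptions, not the project's actual plumbing.

import jax.numpy as jnp

class NaiveLoRABackend(BaseLoRABackend):
    """Illustrative subclass only; not one of the project's real backends."""

    def prepare_lora_batch(
        self, forward_batch, weight_indices, lora_ranks, scalings, batch_info=None
    ):
        # Cache per-request adapter metadata. A real backend would typically also
        # expand weight_indices into per-token indices using the batch's sequence
        # lengths, so the GEMM hooks can gather weights per token.
        self.weight_indices = jnp.asarray(weight_indices, dtype=jnp.int32)
        self.lora_ranks = jnp.asarray(lora_ranks, dtype=jnp.int32)
        self.scalings = jnp.asarray(scalings, dtype=jnp.float32)

# Typical call sequence inside a forward pass (names are illustrative):
#   backend.prepare_lora_batch(forward_batch, weight_indices, lora_ranks, scalings)
#   a_out = backend.run_lora_a_gemm(x, stacked_lora_a)      # (s, r)
#   delta = backend.run_lora_b_gemm(a_out, stacked_lora_b)  # (s, output_dim)
#   out = base_out + delta                                  # add the LoRA delta to the base projection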