|
| 1 | +# Copyright 2023-2024 SGLang Team |
| 2 | +# Modifications copyright 2025 SGLang-JAX Team |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""LoRA layer wrappers using Flax Model Surgery.""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +from typing import TYPE_CHECKING |
| 20 | + |
| 21 | +import jax |
| 22 | +from flax import nnx |
| 23 | + |
| 24 | +if TYPE_CHECKING: |
| 25 | + from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend |
| 26 | + |
| 27 | + |
| 28 | +class LoRALinear(nnx.Module): |
| 29 | + """ |
| 30 | + LoRA wrapper for Linear layers using Flax NNX. |
| 31 | +
|
| 32 | + This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation) |
| 33 | + computation. Uses Model Surgery to preserve the original weights and sharding. |
| 34 | +
|
| 35 | + V1 implementation uses backend to perform LoRA computation: |
| 36 | + output = base_layer(x) |
| 37 | + if enabled: |
| 38 | + lora_output = backend.run_lora_a_gemm(x, lora_A_weights) |
| 39 | + output = backend.run_lora_b_gemm(lora_output, lora_B_weights, output) |
| 40 | +
|
| 41 | + Attributes: |
| 42 | + base_layer: Original Linear layer (preserves weights and sharding) |
| 43 | + lora_rank: LoRA rank dimension |
| 44 | + backend: LoRA backend for efficient computation |
| 45 | + enabled: Whether LoRA computation is active |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + in_features: int, |
| 51 | + out_features: int, |
| 52 | + lora_rank: int, |
| 53 | + base_layer: nnx.Linear | None = None, |
| 54 | + backend: BaseLoRABackend | None = None, |
| 55 | + rngs: nnx.Rngs | None = None, |
| 56 | + ): |
| 57 | + """ |
| 58 | + Initialize LoRA Linear layer. |
| 59 | +
|
| 60 | + Args: |
| 61 | + in_features: Input dimension |
| 62 | + out_features: Output dimension |
| 63 | + lora_rank: Rank of LoRA matrices |
| 64 | + base_layer: Existing Linear layer to wrap (optional) |
| 65 | + backend: LoRA backend for computation (optional) |
| 66 | + rngs: Random number generators for initialization |
| 67 | + """ |
| 68 | + self.in_features = in_features |
| 69 | + self.out_features = out_features |
| 70 | + self.lora_rank = lora_rank |
| 71 | + self.backend = backend |
| 72 | + |
| 73 | + # Base layer - will be populated via nnx.update() during surgery |
| 74 | + if base_layer is not None: |
| 75 | + self.base_layer = base_layer |
| 76 | + else: |
| 77 | + # Create placeholder base layer |
| 78 | + if rngs is None: |
| 79 | + rngs = nnx.Rngs(0) |
| 80 | + self.base_layer = nnx.Linear( |
| 81 | + in_features, |
| 82 | + out_features, |
| 83 | + use_bias=True, |
| 84 | + rngs=rngs, |
| 85 | + ) |
| 86 | + |
| 87 | + # Control variable (not trainable) |
| 88 | + self.enabled = nnx.Variable(False) # Whether LoRA is active |
| 89 | + |
| 90 | + def __call__(self, x: jax.Array) -> jax.Array: |
| 91 | + """ |
| 92 | + Forward pass with optional LoRA computation using backend. |
| 93 | +
|
| 94 | + Args: |
| 95 | + x: Input tensor (shape: [seq_len, in_features]) |
| 96 | +
|
| 97 | + Returns: |
| 98 | + Output tensor with LoRA delta added (if enabled) |
| 99 | + """ |
| 100 | + # Base layer computation (preserves original behavior) |
| 101 | + output = self.base_layer(x) |
| 102 | + |
| 103 | + # Add LoRA delta if enabled and backend is available |
| 104 | + if self.enabled.value and self.backend is not None: |
| 105 | + # Get LoRA weights from memory pool via backend |
| 106 | + # Backend handles batched LoRA computation for multiple adapters |
| 107 | + |
| 108 | + # Step 1: Shrink - project to low-rank space |
| 109 | + # lora_A_weights fetched from memory pool based on batch_info |
| 110 | + lora_a_output = self.backend.run_lora_a_gemm( |
| 111 | + x, None |
| 112 | + ) # Backend manages weights internally |
| 113 | + |
| 114 | + # Step 2: Expand - project back to output space and add to base output |
| 115 | + output = self.backend.run_lora_b_gemm(lora_a_output, None, output) |
| 116 | + |
| 117 | + return output |
| 118 | + |
| 119 | + |
| 120 | +class LoRAEmbedding(nnx.Module): |
| 121 | + """ |
| 122 | + LoRA wrapper for Embedding layers. |
| 123 | +
|
| 124 | + Similar to LoRALinear but for embedding layers. |
| 125 | + V1 implementation uses backend for computation. |
| 126 | + """ |
| 127 | + |
| 128 | + def __init__( |
| 129 | + self, |
| 130 | + num_embeddings: int, |
| 131 | + features: int, |
| 132 | + lora_rank: int, |
| 133 | + base_layer: nnx.Embed | None = None, |
| 134 | + backend: BaseLoRABackend | None = None, |
| 135 | + rngs: nnx.Rngs | None = None, |
| 136 | + ): |
| 137 | + """ |
| 138 | + Initialize LoRA Embedding layer. |
| 139 | +
|
| 140 | + Args: |
| 141 | + num_embeddings: Size of vocabulary |
| 142 | + features: Embedding dimension |
| 143 | + lora_rank: Rank of LoRA matrices |
| 144 | + base_layer: Existing Embed layer to wrap (optional) |
| 145 | + backend: LoRA backend for computation (optional) |
| 146 | + rngs: Random number generators |
| 147 | + """ |
| 148 | + self.num_embeddings = num_embeddings |
| 149 | + self.features = features |
| 150 | + self.lora_rank = lora_rank |
| 151 | + self.backend = backend |
| 152 | + |
| 153 | + # Base layer |
| 154 | + if base_layer is not None: |
| 155 | + self.base_layer = base_layer |
| 156 | + else: |
| 157 | + if rngs is None: |
| 158 | + rngs = nnx.Rngs(0) |
| 159 | + self.base_layer = nnx.Embed( |
| 160 | + num_embeddings, |
| 161 | + features, |
| 162 | + rngs=rngs, |
| 163 | + ) |
| 164 | + |
| 165 | + # Control variable |
| 166 | + self.enabled = nnx.Variable(False) |
| 167 | + |
| 168 | + def __call__(self, x: jax.Array) -> jax.Array: |
| 169 | + """ |
| 170 | + Forward pass for embedding with LoRA using backend. |
| 171 | +
|
| 172 | + Args: |
| 173 | + x: Input token indices |
| 174 | +
|
| 175 | + Returns: |
| 176 | + Embedded output with LoRA delta (if enabled) |
| 177 | + """ |
| 178 | + output = self.base_layer(x) |
| 179 | + |
| 180 | + # V1: Embedding LoRA computation via backend |
| 181 | + # TODO: Implement embedding-specific backend methods if needed |
| 182 | + # For now, embeddings use simple pass-through |
| 183 | + if self.enabled.value and self.backend is not None: |
| 184 | + # Backend handles embedding LoRA computation |
| 185 | + pass |
| 186 | + |
| 187 | + return output |
0 commit comments