Commit f5f83a9

feat: lora manager
1 parent ec17280 commit f5f83a9

File tree

7 files changed: +1606 -22 lines changed

python/sgl_jax/srt/lora/__init__.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

python/sgl_jax/srt/lora/layers.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
# Copyright 2023-2024 SGLang Team
# Modifications copyright 2025 SGLang-JAX Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LoRA layer wrappers using Flax Model Surgery."""

from __future__ import annotations

from typing import TYPE_CHECKING

import jax
from flax import nnx

if TYPE_CHECKING:
    from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend


class LoRALinear(nnx.Module):
    """
    LoRA wrapper for Linear layers using Flax NNX.

    This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation)
    computation. Uses Model Surgery to preserve the original weights and sharding.

    The V1 implementation uses the backend to perform the LoRA computation:

        output = base_layer(x)
        if enabled:
            lora_output = backend.run_lora_a_gemm(x, lora_A_weights)
            output = backend.run_lora_b_gemm(lora_output, lora_B_weights, output)

    Attributes:
        base_layer: Original Linear layer (preserves weights and sharding)
        lora_rank: LoRA rank dimension
        backend: LoRA backend for efficient computation
        enabled: Whether LoRA computation is active
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_rank: int,
        base_layer: nnx.Linear | None = None,
        backend: BaseLoRABackend | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """
        Initialize the LoRA Linear layer.

        Args:
            in_features: Input dimension
            out_features: Output dimension
            lora_rank: Rank of the LoRA matrices
            base_layer: Existing Linear layer to wrap (optional)
            backend: LoRA backend for computation (optional)
            rngs: Random number generators for initialization
        """
        self.in_features = in_features
        self.out_features = out_features
        self.lora_rank = lora_rank
        self.backend = backend

        # Base layer - will be populated via nnx.update() during surgery
        if base_layer is not None:
            self.base_layer = base_layer
        else:
            # Create a placeholder base layer
            if rngs is None:
                rngs = nnx.Rngs(0)
            self.base_layer = nnx.Linear(
                in_features,
                out_features,
                use_bias=True,
                rngs=rngs,
            )

        # Control variable (not trainable)
        self.enabled = nnx.Variable(False)  # Whether LoRA is active

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass with optional LoRA computation using the backend.

        Args:
            x: Input tensor (shape: [seq_len, in_features])

        Returns:
            Output tensor with the LoRA delta added (if enabled)
        """
        # Base layer computation (preserves original behavior)
        output = self.base_layer(x)

        # Add the LoRA delta if enabled and a backend is available
        if self.enabled.value and self.backend is not None:
            # The backend fetches LoRA weights from the memory pool and
            # handles batched LoRA computation for multiple adapters.

            # Step 1: Shrink - project into the low-rank space.
            # lora_A_weights are fetched from the memory pool based on batch_info.
            lora_a_output = self.backend.run_lora_a_gemm(
                x, None
            )  # Backend manages weights internally

            # Step 2: Expand - project back to the output space and add to the base output.
            output = self.backend.run_lora_b_gemm(lora_a_output, None, output)

        return output


class LoRAEmbedding(nnx.Module):
    """
    LoRA wrapper for Embedding layers.

    Similar to LoRALinear but for embedding layers.
    The V1 implementation uses the backend for computation.
    """

    def __init__(
        self,
        num_embeddings: int,
        features: int,
        lora_rank: int,
        base_layer: nnx.Embed | None = None,
        backend: BaseLoRABackend | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """
        Initialize the LoRA Embedding layer.

        Args:
            num_embeddings: Size of the vocabulary
            features: Embedding dimension
            lora_rank: Rank of the LoRA matrices
            base_layer: Existing Embed layer to wrap (optional)
            backend: LoRA backend for computation (optional)
            rngs: Random number generators
        """
        self.num_embeddings = num_embeddings
        self.features = features
        self.lora_rank = lora_rank
        self.backend = backend

        # Base layer
        if base_layer is not None:
            self.base_layer = base_layer
        else:
            if rngs is None:
                rngs = nnx.Rngs(0)
            self.base_layer = nnx.Embed(
                num_embeddings,
                features,
                rngs=rngs,
            )

        # Control variable
        self.enabled = nnx.Variable(False)

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass for the embedding with LoRA using the backend.

        Args:
            x: Input token indices

        Returns:
            Embedded output with the LoRA delta (if enabled)
        """
        output = self.base_layer(x)

        # V1: Embedding LoRA computation via the backend.
        # TODO: Implement embedding-specific backend methods if needed.
        # For now, embeddings use a simple pass-through.
        if self.enabled.value and self.backend is not None:
            # Backend handles embedding LoRA computation
            pass

        return output
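
For context on how these wrappers are meant to be driven, below is a minimal, hypothetical sketch. The real BaseLoRABackend (sgl_jax.srt.lora.backend.base_backend) and the LoRA manager are not part of this file, so NaiveLoRABackend and the shapes here are illustrative stand-ins that only mirror the run_lora_a_gemm / run_lora_b_gemm call pattern used in LoRALinear.__call__; the actual backend resolves per-request adapter weights from its memory pool rather than holding a single A/B pair.

    import jax.numpy as jnp
    from flax import nnx

    from sgl_jax.srt.lora.layers import LoRALinear


    class NaiveLoRABackend:
        """Hypothetical single-adapter stand-in for BaseLoRABackend (illustration only)."""

        def __init__(self, lora_a, lora_b, scaling: float = 1.0):
            self.lora_a = lora_a  # [in_features, lora_rank]
            self.lora_b = lora_b  # [lora_rank, out_features]
            self.scaling = scaling

        def run_lora_a_gemm(self, x, lora_a_weights):
            # Shrink: project activations into the low-rank space.
            # (The real backend looks the weights up in its memory pool.)
            return x @ self.lora_a

        def run_lora_b_gemm(self, lora_a_output, lora_b_weights, base_output):
            # Expand: project back and add the scaled delta to the base output.
            return base_output + self.scaling * (lora_a_output @ self.lora_b)


    in_features, out_features, rank = 16, 32, 4
    backend = NaiveLoRABackend(
        lora_a=jnp.zeros((in_features, rank)),
        lora_b=jnp.zeros((rank, out_features)),
    )

    # Wrap an existing Linear layer, reusing its weights and sharding directly ...
    base = nnx.Linear(in_features, out_features, use_bias=True, rngs=nnx.Rngs(0))
    layer = LoRALinear(in_features, out_features, rank, base_layer=base, backend=backend)
    layer.enabled.value = True
    y = layer(jnp.ones((8, in_features)))  # -> shape (8, out_features)

    # ... or build a placeholder and copy the weights in afterwards, matching the
    # "populated via nnx.update() during surgery" comment in __init__.
    wrapped = LoRALinear(in_features, out_features, rank, backend=backend, rngs=nnx.Rngs(1))
    nnx.update(wrapped.base_layer, nnx.state(base))

With enabled left at False, or with no backend attached, the wrapper is a pass-through around the base layer, which presumably lets the LoRA manager toggle adapters per layer without touching the base weights.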
