Commit 5ddbcee

feat: lora manager
1 parent ec17280 commit 5ddbcee

File tree

7 files changed: +1319 -22 lines changed

python/sgl_jax/srt/lora/__init__.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

python/sgl_jax/srt/lora/layers.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
# Copyright 2023-2024 SGLang Team
# Modifications copyright 2025 SGLang-JAX Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LoRA layer wrappers using Flax Model Surgery."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from flax import nnx


class LoRALinear(nnx.Module):
    """
    LoRA wrapper for Linear layers using Flax NNX.

    This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation)
    computation. Uses Model Surgery to preserve the original weights and sharding.

    The forward pass computes:
        output = base_layer(x) + scaling * (x @ lora_A @ lora_B)

    where the LoRA term is only added when `enabled=True`.

    Attributes:
        base_layer: Original Linear layer (preserves weights and sharding)
        lora_A: LoRA A matrix (in_features, lora_rank)
        lora_B: LoRA B matrix (lora_rank, out_features)
        scaling: LoRA scaling factor (typically alpha / rank)
        enabled: Whether LoRA computation is active
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_rank: int,
        base_layer: nnx.Linear | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """
        Initialize LoRA Linear layer.

        Args:
            in_features: Input dimension
            out_features: Output dimension
            lora_rank: Rank of LoRA matrices
            base_layer: Existing Linear layer to wrap (optional)
            rngs: Random number generators for initialization
        """
        self.in_features = in_features
        self.out_features = out_features
        self.lora_rank = lora_rank

        # Base layer - will be populated via nnx.update() during surgery
        if base_layer is not None:
            self.base_layer = base_layer
        else:
            # Create placeholder base layer
            if rngs is None:
                rngs = nnx.Rngs(0)
            self.base_layer = nnx.Linear(
                in_features,
                out_features,
                use_bias=True,
                rngs=rngs,
            )

        # LoRA parameters (initialized to small random values)
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Initialize lora_A with normal distribution scaled by 1/sqrt(rank)
        self.lora_A = nnx.Param(
            jax.random.normal(rngs(), (in_features, lora_rank)) / jnp.sqrt(lora_rank)
        )

        # Initialize lora_B to zeros (standard LoRA initialization)
        self.lora_B = nnx.Param(jnp.zeros((lora_rank, out_features)))

        # Control variables (not trainable)
        self.scaling = nnx.Variable(1.0)  # Will be set to alpha / rank
        self.enabled = nnx.Variable(False)  # Whether LoRA is active

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass with optional LoRA computation.

        Args:
            x: Input tensor

        Returns:
            Output tensor with LoRA delta added (if enabled)
        """
        # Base layer computation (preserves original behavior)
        output = self.base_layer(x)

        # Add LoRA delta if enabled
        if self.enabled.value:
            # Compute: x @ lora_A @ lora_B
            lora_delta = (x @ self.lora_A.value) @ self.lora_B.value
            output = output + self.scaling.value * lora_delta

        return output


class LoRAEmbedding(nnx.Module):
    """
    LoRA wrapper for Embedding layers.

    Similar to LoRALinear but for embedding layers.
    Currently a placeholder for future implementation.
    """

    def __init__(
        self,
        num_embeddings: int,
        features: int,
        lora_rank: int,
        base_layer: nnx.Embed | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """
        Initialize LoRA Embedding layer.

        Args:
            num_embeddings: Size of vocabulary
            features: Embedding dimension
            lora_rank: Rank of LoRA matrices
            base_layer: Existing Embed layer to wrap (optional)
            rngs: Random number generators
        """
        self.num_embeddings = num_embeddings
        self.features = features
        self.lora_rank = lora_rank

        # Base layer
        if base_layer is not None:
            self.base_layer = base_layer
        else:
            if rngs is None:
                rngs = nnx.Rngs(0)
            self.base_layer = nnx.Embed(
                num_embeddings,
                features,
                rngs=rngs,
            )

        # LoRA parameters for embeddings
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.lora_A = nnx.Param(jax.random.normal(rngs(), (num_embeddings, lora_rank)))
        self.lora_B = nnx.Param(jnp.zeros((lora_rank, features)))

        self.scaling = nnx.Variable(1.0)
        self.enabled = nnx.Variable(False)

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass for embedding with LoRA.

        Args:
            x: Input token indices

        Returns:
            Embedded output with LoRA delta
        """
        output = self.base_layer(x)

        if self.enabled.value:
            # Embedding LoRA: lookup lora_A then multiply by lora_B
            lora_a_embed = self.lora_A.value[x]  # Shape: [..., lora_rank]
            lora_delta = lora_a_embed @ self.lora_B.value  # Shape: [..., features]
            output = output + self.scaling.value * lora_delta

        return output
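
For reference, a minimal usage sketch for LoRALinear (not part of the commit; the feature dimensions, rank, alpha value, and variable names below are illustrative assumptions):

# Hypothetical usage of LoRALinear; shapes and alpha are made up for illustration.
import jax.numpy as jnp
from flax import nnx

rank, alpha = 4, 8
rngs = nnx.Rngs(0)
base = nnx.Linear(16, 32, rngs=rngs)  # existing layer whose weights should be preserved
lora = LoRALinear(16, 32, lora_rank=rank, base_layer=base, rngs=rngs)

lora.scaling.value = alpha / rank  # matches the "alpha / rank" convention noted in the class
lora.enabled.value = True          # activate the LoRA delta

x = jnp.ones((2, 16))
y = lora(x)  # base_layer(x) + scaling * (x @ lora_A @ lora_B)
assert y.shape == (2, 32)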
