Commit 19d513d

feat: lora manager
1 parent 91bf332 commit 19d513d

File tree

6 files changed: +1317 -101 lines changed


python/sgl_jax/srt/lora/__init__.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

python/sgl_jax/srt/lora/layers.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
# Copyright 2023-2024 SGLang Team
# Modifications copyright 2025 SGLang-JAX Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LoRA layer wrappers using Flax Model Surgery."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from flax import nnx


class LoRALinear(nnx.Module):
    """
    LoRA wrapper for Linear layers using Flax NNX.

    This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation)
    computation. Uses Model Surgery to preserve the original weights and sharding.

    The forward pass computes:
        output = base_layer(x) + scaling * (x @ lora_A @ lora_B)

    where the LoRA term is only added when `enabled=True`.

    Attributes:
        base_layer: Original Linear layer (preserves weights and sharding)
        lora_A: LoRA A matrix (in_features, lora_rank)
        lora_B: LoRA B matrix (lora_rank, out_features)
        scaling: LoRA scaling factor (typically alpha / rank)
        enabled: Whether LoRA computation is active
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_rank: int,
        base_layer: nnx.Linear | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """
        Initialize LoRA Linear layer.

        Args:
            in_features: Input dimension
            out_features: Output dimension
            lora_rank: Rank of LoRA matrices
            base_layer: Existing Linear layer to wrap (optional)
            rngs: Random number generators for initialization
        """
        self.in_features = in_features
        self.out_features = out_features
        self.lora_rank = lora_rank

        # Base layer - will be populated via nnx.update() during surgery
        if base_layer is not None:
            self.base_layer = base_layer
        else:
            # Create placeholder base layer
            if rngs is None:
                rngs = nnx.Rngs(0)
            self.base_layer = nnx.Linear(
                in_features,
                out_features,
                use_bias=True,
                rngs=rngs,
            )

        # LoRA parameters (initialized to small random values)
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Initialize lora_A with normal distribution scaled by 1/sqrt(rank)
        self.lora_A = nnx.Param(
            jax.random.normal(rngs(), (in_features, lora_rank)) / jnp.sqrt(lora_rank)
        )

        # Initialize lora_B to zeros (standard LoRA initialization)
        self.lora_B = nnx.Param(jnp.zeros((lora_rank, out_features)))

        # Control variables (not trainable)
        self.scaling = nnx.Variable(1.0)  # Will be set to alpha / rank
        self.enabled = nnx.Variable(False)  # Whether LoRA is active

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass with optional LoRA computation.

        Args:
            x: Input tensor

        Returns:
            Output tensor with LoRA delta added (if enabled)
        """
        # Base layer computation (preserves original behavior)
        output = self.base_layer(x)

        # Add LoRA delta if enabled
        if self.enabled.value:
            # Compute: x @ lora_A @ lora_B
            lora_delta = (x @ self.lora_A.value) @ self.lora_B.value
            output = output + self.scaling.value * lora_delta

        return output


class LoRAEmbedding(nnx.Module):
    """
    LoRA wrapper for Embedding layers.

    Similar to LoRALinear but for embedding layers.
    Currently a placeholder for future implementation.
    """

    def __init__(
        self,
        num_embeddings: int,
        features: int,
        lora_rank: int,
        base_layer: nnx.Embed | None = None,
        rngs: nnx.Rngs | None = None,
    ):
        """
        Initialize LoRA Embedding layer.

        Args:
            num_embeddings: Size of vocabulary
            features: Embedding dimension
            lora_rank: Rank of LoRA matrices
            base_layer: Existing Embed layer to wrap (optional)
            rngs: Random number generators
        """
        self.num_embeddings = num_embeddings
        self.features = features
        self.lora_rank = lora_rank

        # Base layer
        if base_layer is not None:
            self.base_layer = base_layer
        else:
            if rngs is None:
                rngs = nnx.Rngs(0)
            self.base_layer = nnx.Embed(
                num_embeddings,
                features,
                rngs=rngs,
            )

        # LoRA parameters for embeddings
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.lora_A = nnx.Param(jax.random.normal(rngs(), (num_embeddings, lora_rank)))
        self.lora_B = nnx.Param(jnp.zeros((lora_rank, features)))

        self.scaling = nnx.Variable(1.0)
        self.enabled = nnx.Variable(False)

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass for embedding with LoRA.

        Args:
            x: Input token indices

        Returns:
            Embedded output with LoRA delta
        """
        output = self.base_layer(x)

        if self.enabled.value:
            # Embedding LoRA: lookup lora_A then multiply by lora_B
            lora_a_embed = self.lora_A.value[x]  # Shape: [..., lora_rank]
            lora_delta = lora_a_embed @ self.lora_B.value  # Shape: [..., features]
            output = output + self.scaling.value * lora_delta

        return output
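
Below is a minimal usage sketch, not part of this commit, showing how the new LoRALinear wrapper might be exercised on its own; the layer sizes, alpha value, and input are illustrative assumptions rather than values taken from the diff.

import jax.numpy as jnp
from flax import nnx

from sgl_jax.srt.lora.layers import LoRALinear

rngs = nnx.Rngs(0)
base = nnx.Linear(16, 32, rngs=rngs)  # stands in for a pretrained layer
lora = LoRALinear(16, 32, lora_rank=8, base_layer=base, rngs=rngs)

alpha = 16.0  # hypothetical LoRA alpha
lora.scaling.value = alpha / lora.lora_rank  # scaling = alpha / rank, as the docstring notes
lora.enabled.value = True  # with enabled=False the wrapper is a no-op around base_layer

x = jnp.ones((4, 16))
y = lora(x)  # base_layer(x) + scaling * (x @ lora_A @ lora_B)
assert y.shape == (4, 32)

The same pattern would apply to LoRAEmbedding, except that the input is a batch of token indices and the delta comes from indexing lora_A before multiplying by lora_B.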

python/sgl_jax/srt/lora/lora.py

Lines changed: 0 additions & 81 deletions
@@ -22,7 +22,6 @@
 import re
 
 import jax
-import jax.numpy as jnp
 from flax import nnx
 
 from sgl_jax.srt.configs.load_config import LoadConfig
@@ -93,83 +92,3 @@ def initialize_weights(self):
                 self.layers[layer_id].weights[name] = loaded_weight
             else:
                 self.weights[name] = loaded_weight
-
-        # normalize kv_proj and gate_up_proj
-        for layer in self.layers:
-            weight_names = list(layer.weights.keys())
-            self.normalize_qkv_proj(weight_names, layer.weights)
-            self.normalize_gate_up_proj(weight_names, layer.weights)
-
-    def normalize_qkv_proj(self, weight_names: list[str], weights: dict[str, jax.Array]):
-        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
-        target_module = set()
-        for weight_name in weight_names:
-            if "k_proj" in weight_name:
-                target_module.add("k_proj")
-            if "q_proj" in weight_name:
-                target_module.add("q_proj")
-            if "v_proj" in weight_name:
-                target_module.add("v_proj")
-            if "qkv_proj" in weight_name:
-                target_module.add("qkv_proj")
-        if len(target_module) == 0:
-            return
-
-        for weight_name in weight_names:
-            # We assume every lora adaptor should contain lora modules for q_proj
-            if "q_proj" in weight_name:
-                q_name = weight_name
-                k_name = weight_name.replace("q_proj", "k_proj")
-                v_name = weight_name.replace("q_proj", "v_proj")
-                qkv_name = weight_name.replace("q_proj", "qkv_proj")
-
-                # If k_proj doesn't have lora, initialize it to zero
-                k_proj_weight = (
-                    weights[k_name]
-                    if "k_proj" in target_module
-                    else jnp.zeros_like(weights[v_name])
-                )
-                weights[qkv_name] = jnp.concatenate(
-                    (
-                        weights[q_name],
-                        k_proj_weight,
-                        weights[v_name],
-                    ),
-                    0,
-                )
-                weights.pop(q_name)
-                if "k_proj" in target_module:
-                    weights.pop(k_name)
-                weights.pop(v_name)
-            elif "qkv_proj" in weight_name:
-                # If qkv_proj is already stacked, we normalize it following the SGL convention.
-                qkv_name = weight_name
-                q_name = weight_name.replace("qkv_proj", "q_proj")
-                k_name = weight_name.replace("qkv_proj", "k_proj")
-                v_name = weight_name.replace("qkv_proj", "v_proj")
-                if "lora_A" in weight_name:
-                    weights[qkv_name] = weights[qkv_name].repeat(3, 1)
-                # else: no-op as LoRA B weight is already stacked.
-
-    def normalize_gate_up_proj(self, weight_names: list[str], weights: dict[str, jax.Array]):
-        for weight_name in weight_names:
-            if "gate_proj" in weight_name:
-                up_name = weight_name.replace("gate_proj", "up_proj")
-                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-                if up_name not in weights:
-                    weights[up_name] = jax.zeros_like(weights[weight_name])
-                    assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
-                        f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
-                        f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
-                        f"or consider implementing custom initialization logic for other backends."
-                    )
-                weights[gate_up_name] = jnp.concatenate((weights[weight_name], weights[up_name]), 0)
-                weights.pop(weight_name)
-                if up_name in weights:
-                    weights.pop(up_name)
-            elif "gate_up_proj" in weight_name:
-                # If gate_up_proj is already stacked, we normalize it following the SGL convention
-                gate_up_name = weight_name
-                if "lora_A" in weight_name:
-                    weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                # else: no-op as LoRA B weight is already stacked.
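
For context, the following standalone sketch, not part of the diff, illustrates the stacking behavior that the removed normalize_qkv_proj implemented: per-projection LoRA B weights are concatenated along axis 0 into a single qkv_proj entry, with a zero-filled block standing in for k_proj when the adapter carries no k_proj module. The weight names and shapes here are assumptions for illustration only.

import jax.numpy as jnp

rank, head_dim = 8, 64
# Hypothetical adapter with q_proj and v_proj LoRA B weights but no k_proj.
weights = {
    "layers.0.self_attn.q_proj.lora_B.weight": jnp.ones((head_dim, rank)),
    "layers.0.self_attn.v_proj.lora_B.weight": jnp.ones((head_dim, rank)),
}

q = weights.pop("layers.0.self_attn.q_proj.lora_B.weight")
v = weights.pop("layers.0.self_attn.v_proj.lora_B.weight")
k = jnp.zeros_like(v)  # missing k_proj is padded with zeros, mirroring the removed code

weights["layers.0.self_attn.qkv_proj.lora_B.weight"] = jnp.concatenate((q, k, v), 0)
assert weights["layers.0.self_attn.qkv_proj.lora_B.weight"].shape == (3 * head_dim, rank)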

0 commit comments
