diff --git a/python/sgl_jax/srt/lora/__init__.py b/python/sgl_jax/srt/lora/__init__.py deleted file mode 100644 index 7a25d7709..000000000 --- a/python/sgl_jax/srt/lora/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Modifications copyright 2025 SGLang-JAX Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""LoRA support for sgl-jax.""" - -from sgl_jax.srt.lora.lora_config import LoRAConfig -from sgl_jax.srt.lora.lora_registry import LoRARef, LoRARegistry - -__all__ = ["LoRAConfig", "LoRARef", "LoRARegistry"] diff --git a/python/sgl_jax/srt/lora/layers.py b/python/sgl_jax/srt/lora/layers.py new file mode 100644 index 000000000..82f8dad9a --- /dev/null +++ b/python/sgl_jax/srt/lora/layers.py @@ -0,0 +1,187 @@ +# Copyright 2023-2024 SGLang Team +# Modifications copyright 2025 SGLang-JAX Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""LoRA layer wrappers using Flax Model Surgery.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import jax +from flax import nnx + +if TYPE_CHECKING: + from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend + + +class LoRALinear(nnx.Module): + """ + LoRA wrapper for Linear layers using Flax NNX. + + This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation) + computation. Uses Model Surgery to preserve the original weights and sharding. + + V1 implementation uses backend to perform LoRA computation: + output = base_layer(x) + if enabled: + lora_output = backend.run_lora_a_gemm(x, lora_A_weights) + output = backend.run_lora_b_gemm(lora_output, lora_B_weights, output) + + Attributes: + base_layer: Original Linear layer (preserves weights and sharding) + lora_rank: LoRA rank dimension + backend: LoRA backend for efficient computation + enabled: Whether LoRA computation is active + """ + + def __init__( + self, + in_features: int, + out_features: int, + lora_rank: int, + base_layer: nnx.Linear | None = None, + backend: BaseLoRABackend | None = None, + rngs: nnx.Rngs | None = None, + ): + """ + Initialize LoRA Linear layer. 
+ + Args: + in_features: Input dimension + out_features: Output dimension + lora_rank: Rank of LoRA matrices + base_layer: Existing Linear layer to wrap (optional) + backend: LoRA backend for computation (optional) + rngs: Random number generators for initialization + """ + self.in_features = in_features + self.out_features = out_features + self.lora_rank = lora_rank + self.backend = backend + + # Base layer - will be populated via nnx.update() during surgery + if base_layer is not None: + self.base_layer = base_layer + else: + # Create placeholder base layer + if rngs is None: + rngs = nnx.Rngs(0) + self.base_layer = nnx.Linear( + in_features, + out_features, + use_bias=True, + rngs=rngs, + ) + + # Control variable (not trainable) + self.enabled = nnx.Variable(False) # Whether LoRA is active + + def __call__(self, x: jax.Array) -> jax.Array: + """ + Forward pass with optional LoRA computation using backend. + + Args: + x: Input tensor (shape: [seq_len, in_features]) + + Returns: + Output tensor with LoRA delta added (if enabled) + """ + # Base layer computation (preserves original behavior) + output = self.base_layer(x) + + # Add LoRA delta if enabled and backend is available + if self.enabled.value and self.backend is not None: + # Get LoRA weights from memory pool via backend + # Backend handles batched LoRA computation for multiple adapters + + # Step 1: Shrink - project to low-rank space + # lora_A_weights fetched from memory pool based on batch_info + lora_a_output = self.backend.run_lora_a_gemm( + x, None + ) # Backend manages weights internally + + # Step 2: Expand - project back to output space and add to base output + output = self.backend.run_lora_b_gemm(lora_a_output, None, output) + + return output + + +class LoRAEmbedding(nnx.Module): + """ + LoRA wrapper for Embedding layers. + + Similar to LoRALinear but for embedding layers. + V1 implementation uses backend for computation. + """ + + def __init__( + self, + num_embeddings: int, + features: int, + lora_rank: int, + base_layer: nnx.Embed | None = None, + backend: BaseLoRABackend | None = None, + rngs: nnx.Rngs | None = None, + ): + """ + Initialize LoRA Embedding layer. + + Args: + num_embeddings: Size of vocabulary + features: Embedding dimension + lora_rank: Rank of LoRA matrices + base_layer: Existing Embed layer to wrap (optional) + backend: LoRA backend for computation (optional) + rngs: Random number generators + """ + self.num_embeddings = num_embeddings + self.features = features + self.lora_rank = lora_rank + self.backend = backend + + # Base layer + if base_layer is not None: + self.base_layer = base_layer + else: + if rngs is None: + rngs = nnx.Rngs(0) + self.base_layer = nnx.Embed( + num_embeddings, + features, + rngs=rngs, + ) + + # Control variable + self.enabled = nnx.Variable(False) + + def __call__(self, x: jax.Array) -> jax.Array: + """ + Forward pass for embedding with LoRA using backend. 
+ + Args: + x: Input token indices + + Returns: + Embedded output with LoRA delta (if enabled) + """ + output = self.base_layer(x) + + # V1: Embedding LoRA computation via backend + # TODO: Implement embedding-specific backend methods if needed + # For now, embeddings use simple pass-through + if self.enabled.value and self.backend is not None: + # Backend handles embedding LoRA computation + pass + + return output diff --git a/python/sgl_jax/srt/lora/lora_manager.py b/python/sgl_jax/srt/lora/lora_manager.py new file mode 100644 index 000000000..2a60304fb --- /dev/null +++ b/python/sgl_jax/srt/lora/lora_manager.py @@ -0,0 +1,641 @@ +# Copyright 2023-2024 SGLang Team +# Modifications copyright 2025 SGLang-JAX Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""LoRA manager implementation for JAX - Phase 3 placeholder.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import jax.numpy as jnp +from jax.sharding import Mesh + +from sgl_jax.srt.lora.lora import ChunkedSgmvLoRABackend, LoRAAdapter +from sgl_jax.srt.lora.lora_config import LoRAConfig +from sgl_jax.srt.lora.lora_memory_pool import LoRAMemoryPool +from sgl_jax.srt.lora.lora_registry import LoRARef + +if TYPE_CHECKING: + from sgl_jax.srt.managers.schedule_batch import ScheduleBatch + +logger = logging.getLogger(__name__) + + +class LoRAManager: + """ + LoRA manager for JAX-based inference. + + V1 implementation: Simplified version with static adapter loading. 
+ - All LoRA adapters are loaded at initialization time + - No dynamic loading/unloading during inference + - No eviction policy (all adapters stay in CPU memory) + - Adapters are transferred to device memory pool on-demand per batch + + Key differences from PyTorch/SGLang version: + - Uses JAX arrays instead of PyTorch tensors + - Memory pool uses JAX sharding for multi-device support + - Layer wrapping uses Flax NNX Model Surgery + - No kernel backend (uses JAX/XLA compilation instead) + + Future enhancements (V2+): + - Dynamic adapter loading/unloading + - LRU/FIFO eviction policies + - Adapter registry with async loading + - Support for larger number of adapters than memory pool slots + + Attributes: + max_loras_per_batch: Maximum number of LoRA adapters per batch + max_lora_rank: Maximum LoRA rank supported + num_layers: Number of transformer layers + target_modules: Set of target module names + mesh: JAX device mesh for sharding + dtype: Data type for LoRA weights + configs: Dict mapping lora_id -> LoRAConfig + loras: Dict mapping lora_id -> LoRAAdapter (CPU-side weights) + lora_refs: Dict mapping lora_id -> LoRARef + memory_pool: LoRAMemoryPool instance + """ + + def __init__( + self, + base_model, + base_hf_config, + max_loras_per_batch: int, + dtype: jnp.dtype, + mesh: Mesh, + tp_size: int = 1, + max_lora_rank: int | None = None, + target_modules: set[str] | None = None, + lora_paths: list[LoRARef] | None = None, + server_args=None, + ): + """ + Initialize LoRA manager. + + Args: + base_model: The base model to apply LoRA to + base_hf_config: HuggingFace config of the base model + max_loras_per_batch: Maximum number of LoRA adapters in a batch + dtype: Data type for LoRA weights + mesh: JAX device mesh for sharding + tp_size: Tensor parallelism size + max_lora_rank: Maximum LoRA rank to support (or None to infer) + target_modules: Set of target module names (or None to infer) + lora_paths: Optional list of LoRARef to preload + server_args: Server arguments (for future use) + """ + self.base_model = base_model + self.base_hf_config = base_hf_config + self.max_loras_per_batch = max_loras_per_batch + self.dtype = dtype + self.mesh = mesh + self.tp_size = tp_size + self.server_args = server_args + + # Extract model architecture from hf_config + self.num_layers = base_hf_config.num_hidden_layers + self.hidden_size = base_hf_config.hidden_size + self.intermediate_size = getattr(base_hf_config, "intermediate_size", self.hidden_size * 4) + self.num_attention_heads = base_hf_config.num_attention_heads + self.num_kv_heads = getattr(base_hf_config, "num_key_value_heads", self.num_attention_heads) + + # Initialize mutable state + self.init_state( + max_lora_rank=max_lora_rank, + target_modules=target_modules, + lora_paths=lora_paths, + ) + + def init_state( + self, + max_lora_rank: int | None = None, + target_modules: set[str] | None = None, + lora_paths: list[LoRARef] | None = None, + ): + """ + Initialize internal state of LoRAManager. 
+ + Args: + max_lora_rank: Maximum LoRA rank (or None to infer from lora_paths) + target_modules: Target module names (or None to infer from lora_paths) + lora_paths: Optional list of LoRARef to preload + """ + # Validate arguments + if not lora_paths and (max_lora_rank is None or target_modules is None): + raise ValueError( + "When no lora_paths provided, must specify both max_lora_rank and target_modules" + ) + + # Initialize adapter storage + self.init_lora_adapters(lora_paths) + + # Infer or validate shapes + self.init_lora_shapes( + max_lora_rank=max_lora_rank, + target_modules=target_modules, + ) + + # Apply Model Surgery to add LoRA layers (if base_model provided) + if self.base_model is not None: + self.apply_lora_surgery() + + # Initialize memory pool + self.init_memory_pool() + + logger.info( + "LoRA manager initialized: max_rank=%d, target_modules=%s, max_loras=%d", + self.max_lora_rank, + self.target_modules, + self.max_loras_per_batch, + ) + + def init_lora_adapters(self, lora_paths: list[LoRARef] | None = None): + """ + Initialize adapter storage and optionally load adapters. + + Args: + lora_paths: Optional list of LoRARef to preload + """ + # Configs of all active LoRA adapters, indexed by LoRA ID + self.configs: dict[str, LoRAConfig] = {} + + # LoRA adapter weights cached in CPU memory, indexed by LoRA ID + self.loras: dict[str, LoRAAdapter] = {} + + # Mapping from LoRA ID to LoRARef object + self.lora_refs: dict[str, LoRARef] = {} + + # Count of pinned LoRA adapters + self.num_pinned_loras: int = 0 + + if lora_paths: + for lora_ref in lora_paths: + self.load_lora_adapter(lora_ref) + + def init_lora_shapes( + self, + max_lora_rank: int | None = None, + target_modules: set[str] | None = None, + ): + """ + Infer LoRA target modules and max_lora_rank from loaded adapters if not provided. + + Args: + max_lora_rank: Maximum LoRA rank (or None to infer) + target_modules: Target module names (or None to infer) + """ + # Initialize target_modules + if target_modules is not None: + self.target_modules = target_modules + else: + self.target_modules = set() + + # Infer from loaded adapters + for lora_id, config in self.configs.items(): + adapter_target_modules = set(config.target_modules) + + if target_modules is not None: + # Validate adapter is compatible + if not adapter_target_modules.issubset(self.target_modules): + unsupported = adapter_target_modules - self.target_modules + lora_name = self.lora_refs[lora_id].lora_name + raise ValueError( + "LoRA adapter '%s' contains unsupported modules: %s. 
" + "Specified target_modules: %s", + lora_name, + unsupported, + self.target_modules, + ) + else: + # Infer target_modules from adapter + self.target_modules.update(adapter_target_modules) + + # Infer or use max_lora_rank + if max_lora_rank is not None: + self.max_lora_rank = max_lora_rank + else: + self.max_lora_rank = max( + [config.r for config in self.configs.values()], + default=8, # Default rank if no adapters loaded + ) + + def init_memory_pool(self): + """Initialize the LoRA memory pool with proper sharding.""" + self.memory_pool = LoRAMemoryPool( + max_loras_per_batch=self.max_loras_per_batch, + max_lora_rank=self.max_lora_rank, + num_layers=self.num_layers, + target_modules=self.target_modules, + mesh=self.mesh, + dtype=self.dtype, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_attention_heads=self.num_attention_heads, + num_kv_heads=self.num_kv_heads, + tp_size=self.tp_size, + ) + self.memory_pool.init_buffers() + + def load_lora_adapter(self, lora_ref: LoRARef): + """ + Load a single LoRA adapter. + + V1 implementation: Loads config and weights from disk to CPU memory once. + No dynamic loading/unloading. + + Args: + lora_ref: LoRARef object with lora_id, lora_name, lora_path + + Raises: + ValueError: If adapter already loaded or incompatible + """ + if lora_ref.lora_id in self.loras: + raise ValueError(f"LoRA adapter {lora_ref.lora_id} already loaded") + + if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1: + raise ValueError( + f"Cannot pin adapter {lora_ref.lora_name}: already have {self.num_pinned_loras} " + f"pinned adapters (max {self.max_loras_per_batch - 1}, reserving 1 slot for dynamic use)" + ) + + # Load config + config = LoRAConfig(lora_ref.lora_path) + self.configs[lora_ref.lora_id] = config + + # Load adapter weights to CPU + self.load_lora_weights(lora_ref) + + # Store metadata + self.lora_refs[lora_ref.lora_id] = lora_ref + if lora_ref.pinned: + self.num_pinned_loras += 1 + + logger.info( + "Loaded LoRA adapter: %s (id=%s, rank=%d, pinned=%s)", + lora_ref.lora_name, + lora_ref.lora_id, + config.r, + lora_ref.pinned, + ) + + def load_lora_weights(self, lora_ref: LoRARef): + """ + Load LoRA weights from disk to CPU memory. + + V1 implementation: Creates LoRAAdapter and calls initialize_weights() + to load weights from checkpoint files. + + Args: + lora_ref: LoRARef object with lora_id and lora_path + """ + from sgl_jax.srt.configs.load_config import LoadConfig + + # Get load config (TODO: get from server_args if available) + load_config = LoadConfig() + + # Create LoRA backend (placeholder for v1) + lora_backend = ChunkedSgmvLoRABackend() + + # Create adapter + adapter = LoRAAdapter( + uid=lora_ref.lora_id, + config=self.configs[lora_ref.lora_id], + base_hf_config=self.base_hf_config, + load_config=load_config, + lora_backend=lora_backend, + ) + + # Load weights from disk to CPU + adapter.initialize_weights() + + # Store adapter + self.loras[lora_ref.lora_id] = adapter + + logger.info( + "Loaded weights for LoRA adapter: %s (%d layers)", + lora_ref.lora_name, + len(adapter.layers), + ) + + def can_support(self, config: LoRAConfig) -> bool: + """Check if memory pool can support the given LoRA config.""" + return self.memory_pool.can_support(config) + + def prepare_lora_batch(self, schedule_batch: ScheduleBatch): + """ + Prepare LoRA batch for inference. + + V1 implementation: Transfers required adapter weights from CPU to device memory pool. 
+ All adapters are pre-loaded at initialization, no dynamic loading. + + Args: + schedule_batch: ScheduleBatch containing requests with lora_ids + + Raises: + ValueError: If batch exceeds max_loras_per_batch or adapter not loaded + """ + # Collect unique lora_ids from batch + cur_uids = set() + for req in schedule_batch.reqs: + if hasattr(req, "lora_id") and req.lora_id is not None: + cur_uids.add(req.lora_id) + else: + # Base model (no LoRA) + cur_uids.add(None) + + # Validate batch size + if len(cur_uids) > self.max_loras_per_batch: + raise ValueError( + f"Batch has {len(cur_uids)} unique LoRA adapters, exceeds max {self.max_loras_per_batch}" + ) + + # Validate all adapters are loaded + for uid in cur_uids: + if uid is not None and uid not in self.loras: + raise ValueError(f"LoRA adapter {uid} not loaded") + + # Load adapters into device memory pool (CPU -> device transfer) + self.memory_pool.prepare_lora_batch( + cur_uids=cur_uids, + lora_adapters=self.loras, + ) + + # Control LoRA layer enable/disable and set scaling + if hasattr(self, "lora_modules") and self.lora_modules: + self._update_lora_layers(cur_uids) + + logger.debug("Prepared LoRA batch: %d unique adapters", len(cur_uids)) + + def get_buffer_id(self, lora_id: str | None) -> int: + """Get buffer slot ID for a given LoRA adapter ID.""" + return self.memory_pool.get_buffer_id(lora_id) + + def apply_lora_surgery(self): + """ + Apply Flax Model Surgery to add LoRA layers to the base model. + + This method uses Flax NNX's Model Surgery technique to dynamically + replace Linear layers with LoRALinear wrappers without modifying + the original model definition. + + Steps: + 1. Save original model state (including sharding information) + 2. Replace target Linear layers with LoRALinear wrappers + 3. Restore original weights via nnx.update() (preserves sharding) + + The surgery preserves: + - Original model weights + - Sharding specifications + - Model structure compatibility with JIT compilation + """ + from flax import nnx + + from sgl_jax.srt.lora.layers import LoRALinear + + if self.base_model is None: + logger.warning("No base_model provided, skipping LoRA surgery") + return + + logger.info("Applying LoRA surgery to base model...") + + # Step 1: Save original state (with sharding!) 
+ original_state = nnx.state(self.base_model) + + # Step 2: Track replaced modules + self.lora_modules: list[tuple[str, LoRALinear]] = [] + + # Step 3: Replace Linear layers with LoRALinear + # We need to iterate through the model and find target modules + # For now, use a simple approach: check common layer names + try: + # Try to access model.layers (common structure) + if hasattr(self.base_model, "layers"): + layers = self.base_model.layers + elif hasattr(self.base_model, "model") and hasattr(self.base_model.model, "layers"): + layers = self.base_model.model.layers + else: + logger.warning("Could not find model.layers, skipping surgery") + return + + # Iterate through layers + for layer_idx in range(len(layers)): + layer = layers[layer_idx] + + # Check for attention layers + if hasattr(layer, "self_attn"): + attn = layer.self_attn + for module_name in self.target_modules: + if hasattr(attn, module_name): + self._replace_with_lora( + attn, + module_name, + f"layers.{layer_idx}.self_attn.{module_name}", + ) + + # Check for MLP layers + if hasattr(layer, "mlp"): + mlp = layer.mlp + for module_name in self.target_modules: + if hasattr(mlp, module_name): + self._replace_with_lora( + mlp, + module_name, + f"layers.{layer_idx}.mlp.{module_name}", + ) + + except Exception as e: + logger.error("Error during LoRA surgery: %s", e) + logger.warning("LoRA surgery failed, continuing without LoRA layers") + return + + # Step 4: Restore original weights (preserves sharding) + try: + nnx.update(self.base_model, original_state) + logger.info( + "LoRA surgery completed: replaced %d modules", + len(self.lora_modules), + ) + except Exception as e: + logger.error("Error restoring original state: %s", e) + raise + + def _replace_with_lora( + self, + parent_module, + attr_name: str, + full_path: str, + ): + """ + Replace a Linear layer with LoRALinear wrapper. + + Args: + parent_module: Parent module containing the layer + attr_name: Attribute name of the layer (e.g., "q_proj") + full_path: Full path for logging (e.g., "layers.0.self_attn.q_proj") + """ + from flax import nnx + + from sgl_jax.srt.lora.layers import LoRALinear + + original_layer = getattr(parent_module, attr_name, None) + if original_layer is None: + return + + # Check if it's a Linear layer + if not isinstance(original_layer, nnx.Linear): + return + + # Get or create backend + if not hasattr(self, "lora_backend"): + from sgl_jax.srt.lora.backend.bgmv_backend import BgmvLoRABackend + + self.lora_backend = BgmvLoRABackend( + max_loras_per_batch=self.max_loras_per_batch, + max_lora_rank=self.max_lora_rank, + ) + + # Create LoRALinear wrapper with backend + lora_layer = LoRALinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + lora_rank=self.max_lora_rank, + base_layer=original_layer, + backend=self.lora_backend, + rngs=nnx.Rngs(42), # Fixed seed for reproducibility + ) + + # Replace the layer + setattr(parent_module, attr_name, lora_layer) + + # Track the replacement + self.lora_modules.append((full_path, lora_layer)) + + logger.debug("Replaced %s with LoRALinear", full_path) + + def _get_nested_attr(self, obj, attr_path: str): + """ + Get nested attribute using dot notation. + + Args: + obj: Object to traverse + attr_path: Dot-separated path (e.g., "layers.0.self_attn.q_proj") + + Returns: + The nested attribute + """ + for attr in attr_path.split("."): + obj = getattr(obj, attr) + return obj + + def _set_nested_attr(self, obj, attr_path: str, value): + """ + Set nested attribute using dot notation. 
+ + Args: + obj: Object to traverse + attr_path: Dot-separated path + value: Value to set + """ + parts = attr_path.split(".") + for attr in parts[:-1]: + obj = getattr(obj, attr) + setattr(obj, parts[-1], value) + + def verify_sharding_preserved(self): + """ + Verify that model surgery preserved sharding information. + + Checks that base layer weights still have their original sharding specs. + """ + if not hasattr(self, "lora_modules") or not self.lora_modules: + logger.warning("No LoRA modules to verify") + return + + logger.info("Verifying sharding preservation...") + + for module_path, lora_layer in self.lora_modules: + try: + # Check if base layer kernel has sharding + if hasattr(lora_layer.base_layer, "kernel"): + kernel = lora_layer.base_layer.kernel + if hasattr(kernel, "value"): + kernel_value = kernel.value + if hasattr(kernel_value, "sharding"): + sharding = kernel_value.sharding + logger.info( + "%s base_layer.kernel sharding: %s", + module_path, + sharding, + ) + else: + logger.warning( + "%s base_layer.kernel has no sharding attribute", + module_path, + ) + + # Check LoRA parameters sharding + if hasattr(lora_layer.lora_A, "value"): + lora_a_value = lora_layer.lora_A.value + if hasattr(lora_a_value, "sharding"): + logger.info( + "%s lora_A sharding: %s", + module_path, + lora_a_value.sharding, + ) + + except Exception as e: + logger.warning( + "Error checking sharding for %s: %s", + module_path, + e, + ) + + def _update_lora_layers(self, cur_uids: set[str | None]): + """ + Update LoRA layers based on current batch. + + Enables/disables LoRA computation and sets scaling factors based on + whether the batch contains LoRA requests. + + Args: + cur_uids: Set of lora_ids in the current batch + """ + # Determine if batch has any LoRA requests + has_lora = any(uid is not None for uid in cur_uids) + + # TODO: Currently simplified - enables all layers if any request has LoRA + # Phase 4 should implement per-request LoRA handling with batched computation + + for module_path, lora_layer in self.lora_modules: + # Enable/disable LoRA + lora_layer.enabled.value = has_lora + + if has_lora: + # Set scaling based on active adapter + # Simplified: use first non-None adapter's scaling + active_uid = next((uid for uid in cur_uids if uid is not None), None) + if active_uid and active_uid in self.loras: + adapter = self.loras[active_uid] + lora_layer.scaling.value = adapter.scaling + logger.debug( + "Enabled LoRA for %s with scaling=%.4f", + module_path, + adapter.scaling, + ) + else: + # Fallback scaling + lora_layer.scaling.value = 1.0 + else: + logger.debug("Disabled LoRA for %s", module_path) diff --git a/python/sgl_jax/srt/lora/lora_memory_pool.py b/python/sgl_jax/srt/lora/lora_memory_pool.py new file mode 100644 index 000000000..c5336d297 --- /dev/null +++ b/python/sgl_jax/srt/lora/lora_memory_pool.py @@ -0,0 +1,753 @@ +# Copyright 2023-2024 SGLang Team +# Modifications copyright 2025 SGLang-JAX Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""LoRA memory pool implementation for JAX.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from jax.tree_util import register_pytree_node_class + +if TYPE_CHECKING: + from sgl_jax.srt.lora.lora import LoRAAdapter + from sgl_jax.srt.lora.lora_config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class EmptySlot: + """ + Singleton class to represent an empty slot in the memory pool. + This improves readability by not using special str as a placeholder. + """ + + __slots__ = () + + def __repr__(self): + return "|EMPTY|" + + def __new__(cls): + if not hasattr(cls, "_instance"): + cls._instance = super().__new__(cls) + return cls._instance + + +EMPTY_SLOT = EmptySlot() + + +@register_pytree_node_class +class LoRAMemoryPool: + """ + JAX-based memory pool for LoRA adapters. + + Unlike PyTorch version, this uses functional updates and pytree registration + for JAX jit compatibility. No eviction policy is implemented - uses simple + incremental buffer allocation. + + Key differences from PyTorch version: + - Pytree-compatible for JAX jit + - Functional updates with .at[].set() instead of in-place mutations + - Sharding specs for distributed inference + - Simple buffer_id allocation without eviction + - CPU-side tracking (uid mappings) separate from JAX arrays + + Attributes: + max_loras_per_batch: Maximum number of LoRA adapters per batch + max_lora_rank: Maximum LoRA rank supported + num_layers: Number of transformer layers + target_modules: Set of target module names (e.g., {"qkv_proj", "o_proj"}) + mesh: JAX device mesh for sharding + dtype: Data type for LoRA weights + A_buffer: Dict[module_name, List[jax.Array]] - A matrices per layer + B_buffer: Dict[module_name, List[jax.Array]] - B matrices per layer + uid_to_buffer_id: Mapping from lora_id to buffer slot (CPU-side) + buffer_id_to_uid: Mapping from buffer slot to lora_id (CPU-side) + """ + + def __init__( + self, + max_loras_per_batch: int, + max_lora_rank: int, + num_layers: int, + target_modules: set[str], + mesh: Mesh, + dtype: jnp.dtype = jnp.float16, + hidden_size: int = 4096, + intermediate_size: int = 11008, + num_attention_heads: int = 32, + num_kv_heads: int = 32, + tp_size: int = 1, + ): + """ + Initialize LoRA memory pool. 
+ + Args: + max_loras_per_batch: Maximum number of LoRA adapters in a batch + max_lora_rank: Maximum LoRA rank to support + num_layers: Number of transformer layers + target_modules: Set of target module names + mesh: JAX device mesh for sharding + dtype: Data type for LoRA weights + hidden_size: Model hidden dimension + intermediate_size: FFN intermediate dimension + num_attention_heads: Number of attention heads + num_kv_heads: Number of KV heads (for GQA) + tp_size: Tensor parallelism size + """ + self.max_loras_per_batch = max_loras_per_batch + self.max_lora_rank = max_lora_rank + self.num_layers = num_layers + self.target_modules = target_modules + self.mesh = mesh + self.dtype = dtype + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_kv_heads = num_kv_heads + self.tp_size = tp_size + + # CPU-side tracking (not in pytree) + # These are mutable Python objects used for bookkeeping + self.uid_to_buffer_id: dict[str | None, int] = {} + self.buffer_id_to_uid: list[str | None | EmptySlot] = [EMPTY_SLOT] * max_loras_per_batch + + # GPU buffers (in pytree) - initialized in init_buffers() + self.A_buffer: dict[str, list[jax.Array]] = {} + self.B_buffer: dict[str, list[jax.Array]] = {} + + def tree_flatten(self): + """Flatten for pytree registration - only JAX arrays are children.""" + # Flatten A_buffer and B_buffer into lists + a_buffer_flat = [] + b_buffer_flat = [] + module_names = sorted(self.A_buffer.keys()) + + for module_name in module_names: + a_buffer_flat.extend(self.A_buffer[module_name]) + b_buffer_flat.extend(self.B_buffer[module_name]) + + children = (a_buffer_flat, b_buffer_flat) + aux_data = { + "max_loras_per_batch": self.max_loras_per_batch, + "max_lora_rank": self.max_lora_rank, + "num_layers": self.num_layers, + "target_modules": self.target_modules, + "mesh": self.mesh, + "dtype": self.dtype, + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "num_attention_heads": self.num_attention_heads, + "num_kv_heads": self.num_kv_heads, + "tp_size": self.tp_size, + "module_names": module_names, + "uid_to_buffer_id": self.uid_to_buffer_id, + "buffer_id_to_uid": self.buffer_id_to_uid, + } + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Unflatten from pytree.""" + obj = object.__new__(cls) + + # Restore attributes + obj.max_loras_per_batch = aux_data["max_loras_per_batch"] + obj.max_lora_rank = aux_data["max_lora_rank"] + obj.num_layers = aux_data["num_layers"] + obj.target_modules = aux_data["target_modules"] + obj.mesh = aux_data["mesh"] + obj.dtype = aux_data["dtype"] + obj.hidden_size = aux_data["hidden_size"] + obj.intermediate_size = aux_data["intermediate_size"] + obj.num_attention_heads = aux_data["num_attention_heads"] + obj.num_kv_heads = aux_data["num_kv_heads"] + obj.tp_size = aux_data["tp_size"] + obj.uid_to_buffer_id = aux_data["uid_to_buffer_id"] + obj.buffer_id_to_uid = aux_data["buffer_id_to_uid"] + + # Reconstruct A_buffer and B_buffer + a_buffer_flat, b_buffer_flat = children + module_names = aux_data["module_names"] + + obj.A_buffer = {} + obj.B_buffer = {} + + a_idx = 0 + b_idx = 0 + for module_name in module_names: + obj.A_buffer[module_name] = a_buffer_flat[a_idx : a_idx + obj.num_layers] + obj.B_buffer[module_name] = b_buffer_flat[b_idx : b_idx + obj.num_layers] + a_idx += obj.num_layers + b_idx += obj.num_layers + + return obj + + def can_support(self, config: LoRAConfig) -> bool: + """Check 
if the memory pool can support the given LoRA config.""" + if config.r > self.max_lora_rank: + return False + # Check if target modules are supported + config_modules = set(config.target_modules) + return config_modules.issubset(self.target_modules) + + def _get_lora_a_shape(self, module_name: str) -> tuple[int, int, int]: + """ + Get shape for LoRA A matrix. + + Returns: (max_loras_per_batch, max_lora_rank, input_dim) + + Sharding strategy (for row-parallel layers like o_proj, down_proj): + - Shard input dimension across TP + """ + if module_name == "qkv_proj": + # Input: hidden_size + input_dim = self.hidden_size + if self.tp_size > 1: + # Column-parallel: input NOT sharded + pass + elif module_name == "o_proj": + # Input: hidden_size (from concatenated heads) + input_dim = self.hidden_size + if self.tp_size > 1: + # Row-parallel: input sharded + input_dim = input_dim // self.tp_size + elif module_name == "gate_up_proj": + # Input: hidden_size + input_dim = self.hidden_size + if self.tp_size > 1: + # Column-parallel: input NOT sharded + pass + elif module_name == "down_proj": + # Input: intermediate_size + input_dim = self.intermediate_size + if self.tp_size > 1: + # Row-parallel: input sharded + input_dim = input_dim // self.tp_size + else: + # Default: hidden_size + input_dim = self.hidden_size + + return (self.max_loras_per_batch, self.max_lora_rank, input_dim) + + def _get_lora_b_shape(self, module_name: str) -> tuple[int, int, int]: + """ + Get shape for LoRA B matrix. + + Returns: (max_loras_per_batch, output_dim, max_lora_rank) + + Sharding strategy (for column-parallel layers like qkv_proj, gate_up_proj): + - Shard output dimension across TP + """ + if module_name == "qkv_proj": + # Output: (num_heads * head_dim) for Q, K, V combined + # head_dim = hidden_size // num_attention_heads + head_dim = self.hidden_size // self.num_attention_heads + # Q heads + KV heads + output_dim = (self.num_attention_heads + 2 * self.num_kv_heads) * head_dim + if self.tp_size > 1: + # Column-parallel: output sharded + output_dim = output_dim // self.tp_size + elif module_name == "o_proj": + # Output: hidden_size + output_dim = self.hidden_size + if self.tp_size > 1: + # Row-parallel: output NOT sharded + pass + elif module_name == "gate_up_proj": + # Output: intermediate_size * 2 (gate and up combined) + output_dim = self.intermediate_size * 2 + if self.tp_size > 1: + # Column-parallel: output sharded + output_dim = output_dim // self.tp_size + elif module_name == "down_proj": + # Output: hidden_size + output_dim = self.hidden_size + if self.tp_size > 1: + # Row-parallel: output NOT sharded + pass + else: + # Default: hidden_size + output_dim = self.hidden_size + + return (self.max_loras_per_batch, output_dim, self.max_lora_rank) + + def _get_lora_a_sharding(self, module_name: str) -> NamedSharding: + """Get sharding spec for LoRA A matrix.""" + # Row-parallel layers: shard input dimension + if module_name in {"o_proj", "down_proj"}: + # Shape: (batch, rank, input_dim) + # Shard input_dim across tensor axis + return NamedSharding(self.mesh, P(None, None, "tensor")) + else: + # Column-parallel: no sharding for A + return NamedSharding(self.mesh, P(None, None, None)) + + def _get_lora_b_sharding(self, module_name: str) -> NamedSharding: + """Get sharding spec for LoRA B matrix.""" + # Column-parallel layers: shard output dimension + if module_name in {"qkv_proj", "gate_up_proj"}: + # Shape: (batch, output_dim, rank) + # Shard output_dim across tensor axis + return NamedSharding(self.mesh, P(None, 
"tensor", None)) + else: + # Row-parallel: no sharding for B + return NamedSharding(self.mesh, P(None, None, None)) + + def init_buffers(self): + """ + Initialize GPU buffers for LoRA weights. + + Creates A_buffer and B_buffer with proper sharding. + """ + logger.info("Initializing LoRA memory pool buffers for %d layers", self.num_layers) + + with self.mesh: + for module_name in self.target_modules: + a_shape = self._get_lora_a_shape(module_name) + b_shape = self._get_lora_b_shape(module_name) + a_sharding = self._get_lora_a_sharding(module_name) + b_sharding = self._get_lora_b_sharding(module_name) + + self.A_buffer[module_name] = [] + self.B_buffer[module_name] = [] + + for _ in range(self.num_layers): + # Create sharded A buffer + a_buf = jax.jit( + lambda shape=a_shape, dt=self.dtype: jnp.zeros(shape, dtype=dt), + out_shardings=a_sharding, + )() + self.A_buffer[module_name].append(a_buf) + + # Create sharded B buffer + b_buf = jax.jit( + lambda shape=b_shape, dt=self.dtype: jnp.zeros(shape, dtype=dt), + out_shardings=b_sharding, + )() + self.B_buffer[module_name].append(b_buf) + + logger.info( + "Created LoRA buffers for %s: A=%s, B=%s", + module_name, + a_shape, + b_shape, + ) + + logger.info("LoRA memory pool initialization complete") + + def prepare_lora_batch( + self, + cur_uids: set[str | None], + lora_adapters: dict[str | None, LoRAAdapter], + ): + """ + Prepare LoRA batch by loading adapters into buffer slots. + + Simplified version without eviction policy - uses incremental allocation. + + Args: + cur_uids: Set of lora_ids needed for current batch + lora_adapters: Dict mapping lora_id to LoRAAdapter + + Raises: + ValueError: If no buffer slots available + """ + + def get_available_buffer_slot() -> int: + """Find next available buffer slot (simple incremental allocation).""" + for buffer_id in range(self.max_loras_per_batch): + if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT: + return buffer_id + + raise ValueError( + "No available buffer slots. Max %d LoRA adapters per batch exceeded.", + self.max_loras_per_batch, + ) + + # Load each adapter that's not already loaded + for uid in cur_uids: + if uid not in self.uid_to_buffer_id: + buffer_id = get_available_buffer_slot() + lora_adapter = lora_adapters.get(uid) + self.load_lora_weight_to_buffer(uid, buffer_id, lora_adapter) + self.uid_to_buffer_id[uid] = buffer_id + self.buffer_id_to_uid[buffer_id] = uid + logger.debug("Loaded LoRA %s into buffer slot %d", uid, buffer_id) + + def load_lora_weight_to_buffer( + self, + uid: str | None, + buffer_id: int, + lora_adapter: LoRAAdapter | None, + ): + """ + Load LoRA weights into buffer slot. 
+ + Args: + uid: LoRA adapter ID (None for base model) + buffer_id: Buffer slot index + lora_adapter: LoRA adapter object (None for base model) + """ + if uid is None: + # Base model: zero out the buffer slot + logger.debug("Loading base model (zeros) into buffer slot %d", buffer_id) + for module_name in self.target_modules: + for layer_id in range(self.num_layers): + # Zero out A buffer + self.A_buffer[module_name][layer_id] = ( + self.A_buffer[module_name][layer_id].at[buffer_id].set(0) + ) + # Zero out B buffer + self.B_buffer[module_name][layer_id] = ( + self.B_buffer[module_name][layer_id].at[buffer_id].set(0) + ) + return + + if lora_adapter is None: + logger.warning("LoRA adapter %s is None, loading zeros", uid) + # Treat as base model if adapter is None + for module_name in self.target_modules: + for layer_id in range(self.num_layers): + self.A_buffer[module_name][layer_id] = ( + self.A_buffer[module_name][layer_id].at[buffer_id].set(0) + ) + self.B_buffer[module_name][layer_id] = ( + self.B_buffer[module_name][layer_id].at[buffer_id].set(0) + ) + return + + logger.debug("Loading LoRA adapter %s into buffer slot %d", uid, buffer_id) + + # Get TP rank for distributed inference + tp_rank = jax.process_index() if self.tp_size > 1 else 0 + + # Process each layer + for layer_id in range(self.num_layers): + layer_weights = lora_adapter.layers[layer_id].weights + + # Process each target module + for module_name in self.target_modules: + # Extract and load weights for this module + lora_a, lora_b = self._extract_module_weights(layer_weights, layer_id, module_name) + + if lora_a is not None and lora_b is not None: + # Handle rank padding/slicing + lora_a = self._handle_rank_mismatch(lora_a, is_lora_a=True) + lora_b = self._handle_rank_mismatch(lora_b, is_lora_a=False) + + # Apply TP slicing if needed + if self.tp_size > 1: + lora_a = self._apply_tp_slice(lora_a, module_name, tp_rank, is_lora_a=True) + lora_b = self._apply_tp_slice(lora_b, module_name, tp_rank, is_lora_a=False) + + # Load into buffer + self.A_buffer[module_name][layer_id] = ( + self.A_buffer[module_name][layer_id].at[buffer_id].set(lora_a) + ) + self.B_buffer[module_name][layer_id] = ( + self.B_buffer[module_name][layer_id].at[buffer_id].set(lora_b) + ) + else: + # Module not found in adapter weights, zero out + logger.debug( + "Module %s not found in layer %d weights, zeroing buffer", + module_name, + layer_id, + ) + self.A_buffer[module_name][layer_id] = ( + self.A_buffer[module_name][layer_id].at[buffer_id].set(0) + ) + self.B_buffer[module_name][layer_id] = ( + self.B_buffer[module_name][layer_id].at[buffer_id].set(0) + ) + + def _extract_module_weights( + self, + layer_weights: dict[str, jax.Array], + layer_id: int, + module_name: str, + ) -> tuple[jax.Array | None, jax.Array | None]: + """ + Extract LoRA A and B weights for a specific module from layer weights. 
+ + Args: + layer_weights: Dictionary of weight tensors for this layer + layer_id: Layer index + module_name: Target module name (qkv_proj, o_proj, gate_up_proj, down_proj) + + Returns: + Tuple of (lora_a, lora_b) tensors, or (None, None) if not found + + Weight naming convention in LoRA adapters: + base_model.model.layers.{layer_id}.{module_path}.lora_A.weight + base_model.model.layers.{layer_id}.{module_path}.lora_B.weight + + Module name mapping: + - qkv_proj: Concatenate q_proj, k_proj, v_proj + - gate_up_proj: Concatenate gate_proj, up_proj + - o_proj: Direct mapping to o_proj + - down_proj: Direct mapping to down_proj + """ + # Handle composite modules (qkv_proj, gate_up_proj) + if module_name == "qkv_proj": + # Need to concatenate q_proj, k_proj, v_proj + return self._extract_and_concat_qkv(layer_weights, layer_id) + elif module_name == "gate_up_proj": + # Need to concatenate gate_proj, up_proj + return self._extract_and_concat_gate_up(layer_weights, layer_id) + else: + # Direct mapping (o_proj, down_proj) + return self._extract_single_module(layer_weights, layer_id, module_name) + + def _extract_single_module( + self, + layer_weights: dict[str, jax.Array], + layer_id: int, + module_name: str, + ) -> tuple[jax.Array | None, jax.Array | None]: + """Extract weights for a single module (o_proj, down_proj).""" + lora_a = None + lora_b = None + + # Search for matching weight keys + for key, weight in layer_weights.items(): + # Match pattern: layers.{layer_id}.{path}.{module_name}.lora_{A|B}.weight + if f"layers.{layer_id}." in key and module_name in key: + if "lora_A.weight" in key: + lora_a = weight + elif "lora_B.weight" in key: + lora_b = weight + + return lora_a, lora_b + + def _extract_and_concat_qkv( + self, + layer_weights: dict[str, jax.Array], + layer_id: int, + ) -> tuple[jax.Array | None, jax.Array | None]: + """ + Extract and concatenate q_proj, k_proj, v_proj weights. + + For attention QKV projection, we need to concatenate: + - lora_A: Concatenate along rank dimension (axis 0) + - lora_B: Concatenate along output dimension (axis 0) + + Returns: + Concatenated (lora_a_qkv, lora_b_qkv) + """ + # Extract individual components + q_a, q_b = self._extract_single_module(layer_weights, layer_id, "q_proj") + k_a, k_b = self._extract_single_module(layer_weights, layer_id, "k_proj") + v_a, v_b = self._extract_single_module(layer_weights, layer_id, "v_proj") + + # Check if all components are present + if all(x is not None for x in [q_a, k_a, v_a, q_b, k_b, v_b]): + # Concatenate A matrices along rank dimension (axis 0) + # Shape: (rank, hidden_size) -> (3*rank, hidden_size) or similar + lora_a_qkv = jnp.concatenate([q_a, k_a, v_a], axis=0) + + # Concatenate B matrices along output dimension (axis 0) + # Shape: (head_dim, rank) -> (3*head_dim, rank) or similar + lora_b_qkv = jnp.concatenate([q_b, k_b, v_b], axis=0) + + return lora_a_qkv, lora_b_qkv + else: + # Not all components found + logger.warning( + "Incomplete QKV weights in layer %d: q_proj=%s, k_proj=%s, v_proj=%s", + layer_id, + q_a is not None, + k_a is not None, + v_a is not None, + ) + return None, None + + def _extract_and_concat_gate_up( + self, + layer_weights: dict[str, jax.Array], + layer_id: int, + ) -> tuple[jax.Array | None, jax.Array | None]: + """ + Extract and concatenate gate_proj, up_proj weights. 
+ + For FFN gate/up projection, we need to concatenate: + - lora_A: Concatenate along rank dimension (axis 0) + - lora_B: Concatenate along output dimension (axis 0) + + Returns: + Concatenated (lora_a_gate_up, lora_b_gate_up) + """ + # Extract individual components + gate_a, gate_b = self._extract_single_module(layer_weights, layer_id, "gate_proj") + up_a, up_b = self._extract_single_module(layer_weights, layer_id, "up_proj") + + # Check if both components are present + if all(x is not None for x in [gate_a, up_a, gate_b, up_b]): + # Concatenate A matrices along rank dimension (axis 0) + lora_a_gate_up = jnp.concatenate([gate_a, up_a], axis=0) + + # Concatenate B matrices along output dimension (axis 0) + lora_b_gate_up = jnp.concatenate([gate_b, up_b], axis=0) + + return lora_a_gate_up, lora_b_gate_up + else: + # Not all components found + logger.warning( + "Incomplete gate_up weights in layer %d: gate_proj=%s, up_proj=%s", + layer_id, + gate_a is not None, + up_a is not None, + ) + return None, None + + def _handle_rank_mismatch( + self, + weight: jax.Array, + is_lora_a: bool, + ) -> jax.Array: + """ + Handle rank mismatch between adapter and buffer. + + For lora_A: shape is (rank, input_dim) + For lora_B: shape is (output_dim, rank) + + If adapter rank < max_lora_rank: Pad with zeros + If adapter rank > max_lora_rank: Slice (shouldn't happen with proper config) + + Args: + weight: LoRA weight tensor + is_lora_a: True for A matrix, False for B matrix + + Returns: + Adjusted weight tensor + """ + if is_lora_a: + # lora_A shape: (rank, input_dim) + current_rank = weight.shape[0] + if current_rank < self.max_lora_rank: + # Pad along rank dimension (axis 0) + pad_size = self.max_lora_rank - current_rank + weight = jnp.pad( + weight, + ((0, pad_size), (0, 0)), + mode="constant", + constant_values=0, + ) + elif current_rank > self.max_lora_rank: + # Slice to max_lora_rank (shouldn't happen normally) + logger.warning( + "LoRA rank %d exceeds max_lora_rank %d, slicing", + current_rank, + self.max_lora_rank, + ) + weight = weight[: self.max_lora_rank, :] + else: + # lora_B shape: (output_dim, rank) + current_rank = weight.shape[1] + if current_rank < self.max_lora_rank: + # Pad along rank dimension (axis 1) + pad_size = self.max_lora_rank - current_rank + weight = jnp.pad( + weight, + ((0, 0), (0, pad_size)), + mode="constant", + constant_values=0, + ) + elif current_rank > self.max_lora_rank: + # Slice to max_lora_rank + logger.warning( + "LoRA rank %d exceeds max_lora_rank %d, slicing", + current_rank, + self.max_lora_rank, + ) + weight = weight[:, : self.max_lora_rank] + + return weight + + def _apply_tp_slice( + self, + weight: jax.Array, + module_name: str, + tp_rank: int, + is_lora_a: bool, + ) -> jax.Array: + """ + Apply tensor parallel slicing to LoRA weights. 
+ + Sharding strategy: + - Row-parallel (o_proj, down_proj): Shard input dimension of lora_A + - Column-parallel (qkv_proj, gate_up_proj): Shard output dimension of lora_B + + Args: + weight: LoRA weight tensor + module_name: Target module name + tp_rank: Tensor parallel rank + is_lora_a: True for A matrix, False for B matrix + + Returns: + Sliced weight tensor for this TP rank + """ + # Row-parallel modules: shard input dimension of lora_A + if module_name in {"o_proj", "down_proj"}: + if is_lora_a: + # lora_A shape: (rank, input_dim) + # Shard input_dim across TP ranks + input_dim = weight.shape[1] + chunk_size = input_dim // self.tp_size + start_idx = tp_rank * chunk_size + end_idx = start_idx + chunk_size + weight = weight[:, start_idx:end_idx] + # lora_B: no slicing for row-parallel + + # Column-parallel modules: shard output dimension of lora_B + elif module_name in {"qkv_proj", "gate_up_proj"} and not is_lora_a: + # lora_B shape: (output_dim, rank) + # Shard output_dim across TP ranks + output_dim = weight.shape[0] + chunk_size = output_dim // self.tp_size + start_idx = tp_rank * chunk_size + end_idx = start_idx + chunk_size + weight = weight[start_idx:end_idx, :] + # lora_A: no slicing for column-parallel + + return weight + + def get_buffer_id(self, lora_uid: str | None) -> int: + """Get buffer slot ID for a given LoRA adapter ID.""" + return self.uid_to_buffer_id[lora_uid] + + def get_tensor(self, module_name: str, layer_id: int, is_lora_a: bool) -> jax.Array: + """ + Get LoRA tensor for a specific module and layer. + + Args: + module_name: Target module name (e.g., "qkv_proj") + layer_id: Layer index + is_lora_a: True for A matrix, False for B matrix + + Returns: + JAX array with shape: + - A: (max_loras_per_batch, max_lora_rank, input_dim) + - B: (max_loras_per_batch, output_dim, max_lora_rank) + """ + if is_lora_a: + return self.A_buffer[module_name][layer_id] + else: + return self.B_buffer[module_name][layer_id] diff --git a/python/sgl_jax/srt/managers/schedule_batch.py b/python/sgl_jax/srt/managers/schedule_batch.py index ac4abc6e4..4c677c8a3 100644 --- a/python/sgl_jax/srt/managers/schedule_batch.py +++ b/python/sgl_jax/srt/managers/schedule_batch.py @@ -186,6 +186,9 @@ def __init__( # LoRA info self.lora_id = lora_id + # LoRA info + self.lora_id = lora_id + # Memory pool info self.req_pool_idx: int | None = None diff --git a/python/sgl_jax/srt/managers/tokenizer_manager.py b/python/sgl_jax/srt/managers/tokenizer_manager.py index 897f1c516..42c1193b1 100644 --- a/python/sgl_jax/srt/managers/tokenizer_manager.py +++ b/python/sgl_jax/srt/managers/tokenizer_manager.py @@ -28,7 +28,7 @@ from sgl_jax.srt.configs.model_config import ModelConfig from sgl_jax.srt.hf_transformers_utils import get_tokenizer -from sgl_jax.srt.lora import LoRARegistry +from sgl_jax.srt.lora.lora_registry import LoRARegistry from sgl_jax.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py index cc87e4505..3392f9c35 100644 --- a/python/sgl_jax/srt/model_executor/model_runner.py +++ b/python/sgl_jax/srt/model_executor/model_runner.py @@ -147,6 +147,12 @@ def initialize(self): ): self.is_hybrid = True + # Init lora + if server_args.enable_lora: + self.init_lora_manager() + + self.initialize_jit() + # Init memory pool and attention backends self.init_memory_pool( server_args.max_running_requests, @@ -626,6 +632,23 @@ def set_num_token_hybrid(self): self.swa_max_total_num_tokens, ) + 
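    # NOTE (annotation, not part of this patch): the method below is invoked
    # from ModelRunner.initialize() when server_args.enable_lora is set (see
    # the hunk above). It assumes ServerArgs carries max_loras_per_batch,
    # max_lora_rank, lora_target_modules and lora_paths; check_lora_server_args()
    # in server_args.py (end of this diff) is responsible for validating and
    # normalizing those values before the runner is constructed.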
def init_lora_manager(self): + """Initialize LoRA manager for LoRA adapter support.""" + from sgl_jax.srt.lora.lora_manager import LoRAManager + + self.lora_manager = LoRAManager( + base_model=self.model, + base_hf_config=self.model_config.hf_config, + max_loras_per_batch=self.server_args.max_loras_per_batch, + dtype=self.dtype, + mesh=self.mesh, + tp_size=self.tp_size, + max_lora_rank=self.server_args.max_lora_rank, + target_modules=self.server_args.lora_target_modules, + lora_paths=self.server_args.lora_paths, + server_args=self.server_args, + ) + class MockModelRunner(ModelRunner): def __init__( diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py index 0bdf6267d..6ff931384 100644 --- a/python/sgl_jax/srt/server_args.py +++ b/python/sgl_jax/srt/server_args.py @@ -954,6 +954,9 @@ def check_server_args(self): self.chunked_prefill_size % self.page_size == 0 ), "chunked_prefill_size must be divisible by page_size" + # Check LoRA configuration + self.check_lora_server_args() + # Disallow overlap scheduler when speculative decoding is enabled if self.speculative_algorithm is not None and not self.disable_overlap_schedule: raise ValueError( @@ -961,13 +964,10 @@ def check_server_args(self): "Please pass --disable-overlap-schedule when using --speculative-algorithm." ) - # Check LoRA configuration - self.check_lora_server_args() - def check_lora_server_args(self): """Validate and normalize LoRA-related server arguments.""" # Import LoRARef here to avoid circular imports - from sgl_jax.srt.lora import LoRARef + from sgl_jax.srt.lora.lora_registry import LoRARef # Validate max_loras_per_batch assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
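For reviewers, a minimal sketch of how the pieces added in this diff fit together. This is illustrative only and not part of the patch: the mesh layout, the adapter path, and the `base_model` / `hf_config` / `schedule_batch` objects are placeholders, and the `LoRARef` constructor arguments are assumed from how its fields (`lora_id`, `lora_name`, `lora_path`, `pinned`) are used above (its definition in `lora_registry.py` is not shown here).

```python
# Illustrative wiring of the new LoRA components (assumptions noted above).
import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from sgl_jax.srt.lora.lora_manager import LoRAManager
from sgl_jax.srt.lora.lora_registry import LoRARef


def stage_lora(base_model, hf_config, schedule_batch):
    # 1-D device mesh; "tensor" matches the axis name used by LoRAMemoryPool's
    # PartitionSpecs, e.g. P(None, None, "tensor").
    mesh = Mesh(jax.devices(), axis_names=("tensor",))

    # Hypothetical adapter reference; constructor arguments are assumed.
    ref = LoRARef(lora_name="demo-adapter", lora_path="/path/to/adapter")

    # With lora_paths given, max_lora_rank and target_modules are inferred
    # from the adapter configs in init_lora_shapes().
    manager = LoRAManager(
        base_model=base_model,
        base_hf_config=hf_config,
        max_loras_per_batch=4,
        dtype=jnp.bfloat16,
        mesh=mesh,
        tp_size=1,
        lora_paths=[ref],
    )

    # Per scheduling step: stage the adapters referenced by the batch into the
    # device-side pool and toggle the wrapped LoRALinear layers.
    manager.prepare_lora_batch(schedule_batch)

    # If the batch contained requests tagged with ref.lora_id, its pool slot
    # can be looked up afterwards.
    slot = manager.get_buffer_id(ref.lora_id)
    return manager, slot
```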