Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cmd/llm-d-routing-sidecar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (

func main() {
port := flag.String("port", "8000", "the port the sidecar is listening on")
vLLMPort := flag.String("vllm-port", "8001", "the port vLLM is listening on")
connector := flag.String("connector", "nixlv2", "the P/D connector being used. Either nixl, nixlv2 or lmcache")
vLLMPort := flag.String("vllm-port", "8001", "the port vLLM is listening on (also used for SGLang)")
connector := flag.String("connector", "nixlv2", "the P/D connector being used. Either nixl, nixlv2, lmcache, or sglang")
prefillerUseTLS := flag.Bool("prefiller-use-tls", false, "whether to use TLS when sending requests to prefillers")
decoderUseTLS := flag.Bool("decoder-use-tls", false, "whether to use TLS when sending requests to the decoder")
prefillerInsecureSkipVerify := flag.Bool("prefiller-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to prefiller")
Expand All @@ -53,8 +53,8 @@ func main() {
ctx := signals.SetupSignalHandler(context.Background())
logger := klog.FromContext(ctx)

if *connector != proxy.ConnectorNIXLV1 && *connector != proxy.ConnectorNIXLV2 && *connector != proxy.ConnectorLMCache {
logger.Info("Error: --connector must either be 'nixl', 'nixlv2' or 'lmcache'")
if *connector != proxy.ConnectorNIXLV1 && *connector != proxy.ConnectorNIXLV2 && *connector != proxy.ConnectorLMCache && *connector != proxy.ConnectorSGLang {
logger.Info("Error: --connector must either be 'nixl', 'nixlv2', 'lmcache', or 'sglang'")
return
}
if *connector == proxy.ConnectorNIXLV1 {
Expand All @@ -81,6 +81,7 @@ func main() {
if *decoderUseTLS {
scheme = "https"
}

targetURL, err := url.Parse(scheme + "://localhost:" + *vLLMPort)
if err != nil {
logger.Error(err, "failed to create targetURL")
Expand Down
148 changes: 148 additions & 0 deletions internal/proxy/connector_sglang.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
Copyright 2025 The llm-d Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package proxy

import (
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strconv"
"strings"
"time"
)

func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) {
s.logger.V(4).Info("running SGLang protocol", "url", prefillPodHostPort)

// Make Request
requestData, err := s.parseSGLangRequest(r)

if err != nil {
if err := errorJSONInvalid(err, w); err != nil {
s.logger.Error(err, "failed to send error response to client")
}
return
}

// Validate prefill host
if prefillPodHostPort == "" {
err := fmt.Errorf("prefill host required for SGLang P/D disaggregation")
if err := errorJSONInvalid(err, w); err != nil {
s.logger.Error(err, "failed to send error response to client")
}
return
}

roomID := s.generateSGLangRoomID()

// Inject bootstrap info for both prefill and decode
bootstrapInfo := s.addSGLangBootstrapInfo(requestData, prefillPodHostPort, roomID)

body, err := json.Marshal(bootstrapInfo)
if err != nil {
if err := errorJSONInvalid(err, w); err != nil {
s.logger.Error(err, "failed to send error response to client")
}
return
}

newReq := r.Clone(r.Context())
newReq.Body = io.NopCloser(strings.NewReader(string(body)))
newReq.ContentLength = int64(len(body))
newReq.Header.Set("Content-Type", "application/json")

// Send concurrent prefill and decode requests
s.sendSGLangConcurrentRequests(w, newReq, prefillPodHostPort)
}

func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, prefillHost string) {
Req := r.Clone(r.Context())
Req.Body = r.Body
Req.ContentLength = r.ContentLength

// Send prefill request asynchronously
go func() {
prefillHandler, err := s.prefillerProxyHandler(prefillHost)
if err != nil {
s.logger.Error(err, "failed to get prefiller proxy handler", "prefill_host", prefillHost)
return
}
pw := &bufferedResponseWriter{}

prefillHandler.ServeHTTP(pw, Req)
s.logger.V(5).Info("prefill request completed", "status", pw.statusCode)
}()
// Send decode request synchronously
s.decoderProxy.ServeHTTP(w, Req)
}

func (s *Server) addSGLangBootstrapInfo(requestData map[string]interface{}, prefillHostPort string, roomID int64) map[string]interface{} {
modifiedRequest := make(map[string]interface{})
for k, v := range requestData {
modifiedRequest[k] = v
}

// Generate bootstrap host from prefill host
bootstrapHost, bootstrapPort := s.getBootstrapHost(prefillHostPort)

// Add bootstrap information
modifiedRequest[requestFieldBootstrapHost] = bootstrapHost
modifiedRequest[requestFieldBootstrapPort] = bootstrapPort
modifiedRequest[requestFieldBootstrapRoom] = roomID

s.logger.V(5).Info("bootstrap info added",
"bootstrap_host", bootstrapHost,
"bootstrap_port", bootstrapPort,
"bootstrap_room", roomID)

return modifiedRequest
}

func (s *Server) parseSGLangRequest(r *http.Request) (map[string]interface{}, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}

var requestData map[string]interface{}
if err := json.Unmarshal(body, &requestData); err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}

return requestData, nil
}

func (s *Server) generateSGLangRoomID() int64 {
return time.Now().UnixNano() + int64(rand.Intn(1000))
}

func (s *Server) getBootstrapHost(prefillHostPort string) (string, int) {
// Extract hostname from prefill host
parts := strings.Split(prefillHostPort, ":")
hostname := parts[0]
// Get bootstrap port from environment variable
bootstrapPort := 8998 // Default SGLang bootstrap port
if portStr := os.Getenv("SGLANG_BOOTSTRAP_PORT"); portStr != "" {
if port, err := strconv.Atoi(portStr); err == nil {
bootstrapPort = port
}
}
return hostname, bootstrapPort
}
12 changes: 12 additions & 0 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ const (
requestFieldStream = "stream"
requestFieldStreamOptions = "stream_options"

// SGLang bootstrap fields
requestFieldBootstrapHost = "bootstrap_host"
requestFieldBootstrapPort = "bootstrap_port"
requestFieldBootstrapRoom = "bootstrap_room"
requestFieldBootstrapRoomID = "bootstrap_room_id"
requestFieldRoomID = "room_id"

// ConnectorNIXLV1 enables the (now deprecated) P/D NIXL v1 protocol
ConnectorNIXLV1 = "nixl"

Expand All @@ -58,6 +65,9 @@ const (

// ConnectorLMCache enables (now deprecated) P/D LMCache protocol
ConnectorLMCache = "lmcache"

// ConnectorSGLang enables SGLang P/D disaggregation protocol
ConnectorSGLang = "sglang"
)

// Config represents the proxy server configuration
Expand Down Expand Up @@ -131,6 +141,8 @@ func NewProxy(port string, decodeURL *url.URL, config Config) (*Server, error) {
server.runConnectorProtocol = server.runLMCacheProtocol
case ConnectorNIXLV1:
server.runConnectorProtocol = server.runNIXLProtocolV1
case ConnectorSGLang:
server.runConnectorProtocol = server.runSGLangProtocol
case ConnectorNIXLV2:
fallthrough
default:
Expand Down