diff --git a/Dockerfile.dev b/Dockerfile.dev new file mode 100644 index 000000000..e3bbe93df --- /dev/null +++ b/Dockerfile.dev @@ -0,0 +1,105 @@ +# Build Stage: using Go 1.25 image +FROM quay.io/projectquay/golang:1.25 AS builder +ARG TARGETOS +ARG TARGETARCH + +# Install build tools +# The builder is based on UBI8, so we need epel-release-8. +RUN dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-8.noarch.rpm' && \ + dnf install -y gcc-c++ libstdc++ libstdc++-devel clang zeromq-devel pkgconfig python3.12-devel python3.12-pip git && \ + dnf clean all +# python3.12-devel needed for CGO compilation (Python headers and python3.12-config for linker flags) + +WORKDIR /workspace + +# Copy the Go Modules manifests +COPY llm-d-inference-scheduler/go.mod go.mod +COPY llm-d-inference-scheduler/go.sum go.sum + +# Copy the go source +COPY llm-d-inference-scheduler/cmd/ cmd/ +COPY llm-d-inference-scheduler/pkg/ pkg/ +COPY llm-d-inference-scheduler/test/ test/ + +# Copy local dependencies from sibling directories +COPY llm-d-kv-cache-manager/ /workspace/llm-d-kv-cache-manager/ +COPY gateway-api-inference-extension/ /workspace/gateway-api-inference-extension/ + +# Set up replace directives to use local checkouts +RUN go mod edit -replace github.com/llm-d/llm-d-kv-cache-manager=/workspace/llm-d-kv-cache-manager +RUN go mod edit -replace sigs.k8s.io/gateway-api-inference-extension=/workspace/gateway-api-inference-extension +RUN go mod tidy + +# HuggingFace tokenizer bindings +RUN mkdir -p lib +# Ensure that the RELEASE_VERSION matches the one used in the imported llm-d-kv-cache-manager version +ARG RELEASE_VERSION=v1.22.1 +RUN curl -L https://github.com/daulet/tokenizers/releases/download/${RELEASE_VERSION}/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib +RUN ranlib lib/*.a + +# Build +# GOARCH has no default value so that the binary is built for the platform of the host where the command +# was invoked. For example, if we call make image-build in a local env on an Apple Silicon (M1) machine, +# the docker BUILDPLATFORM arg will be linux/arm64, while for Apple x86 it will be linux/amd64. Therefore, +# by leaving it empty we can ensure that the container and the binary shipped in it have the same platform. +ENV CGO_ENABLED=1 +ENV GOOS=${TARGETOS:-linux} +ENV GOARCH=${TARGETARCH} +ENV PYTHON=python3.12 +ENV PYTHONPATH=/usr/lib64/python3.12/site-packages:/usr/lib/python3.12/site-packages + +ARG COMMIT_SHA=unknown +ARG BUILD_REF +RUN export CGO_CFLAGS="$(python3.12-config --cflags) -I/workspace/lib" && \ + export CGO_LDFLAGS="$(python3.12-config --ldflags --embed) -L/workspace/lib -ltokenizers -ldl -lm" && \ + go build -a -o bin/epp -ldflags="-extldflags '-L$(pwd)/lib' -X sigs.k8s.io/gateway-api-inference-extension/version.CommitSHA=${COMMIT_SHA} -X sigs.k8s.io/gateway-api-inference-extension/version.BuildRef=${BUILD_REF}" cmd/epp/main.go + +# Use ubi9 as a minimal base image to package the manager binary +# Refer to https://catalog.redhat.com/software/containers/ubi9/ubi-minimal/615bd9b4075b022acc111bf5 for more details
FROM registry.access.redhat.com/ubi9/ubi-minimal:latest +WORKDIR / +COPY --from=builder /workspace/bin/epp /app/epp + +# Install zeromq runtime library and Python runtime needed by the manager. +# The final image is UBI9, so we need epel-release-9.
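+# ubi-minimal ships only microdnf, so install the full dnf first for the EPEL repo setup that follows.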
+USER root +RUN microdnf install -y dnf && \ + dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm' && \ + dnf install -y zeromq python3.12 python3.12-libs python3.12-pip && \ + dnf clean all && \ + rm -rf /var/cache/dnf /var/lib/dnf && \ + ln -sf /usr/bin/python3.12 /usr/bin/python3 && \ + ln -sf /usr/bin/python3.12 /usr/bin/python +# Note: python3.12 package does not automatically create python3/python symlinks - they must be created manually + +# Install wrapper as a module in site-packages (from local checkout) +RUN mkdir -p /usr/local/lib/python3.12/site-packages/ +COPY llm-d-kv-cache-manager/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py /usr/local/lib/python3.12/site-packages/ + +# Python deps (no cache, single target) – filter out torch +ENV PIP_NO_CACHE_DIR=1 PIP_DISABLE_PIP_VERSION_CHECK=1 +COPY llm-d-kv-cache-manager/pkg/preprocessing/chat_completions/requirements.txt /tmp/requirements.txt +RUN sed '/^torch\b/d' /tmp/requirements.txt > /tmp/requirements.notorch.txt && \ + python3.12 -m pip install --no-cache-dir --upgrade pip setuptools wheel && \ + python3.12 -m pip install --no-cache-dir --target /usr/local/lib/python3.12/site-packages -r /tmp/requirements.notorch.txt && \ + python3.12 -m pip install --no-cache-dir --target /usr/local/lib/python3.12/site-packages PyYAML && \ + rm /tmp/requirements.txt /tmp/requirements.notorch.txt && \ + rm -rf /root/.cache/pip + +# Python env +ENV PYTHONPATH="/usr/local/lib/python3.12/site-packages:/usr/lib/python3.12/site-packages" +ENV PYTHON=python3.12 +ENV PATH=/usr/bin:/usr/local/bin:$PATH +ENV HF_HOME="/tmp/.cache" + +USER 65532:65532 + +# expose gRPC, health and metrics ports +EXPOSE 9002 +EXPOSE 9003 +EXPOSE 9090 + +# expose port for KV-Events ZMQ SUB socket +EXPOSE 5557 + +ENTRYPOINT ["/app/epp"] diff --git a/Dockerfile.sidecar b/Dockerfile.sidecar index 754b5346d..5f70cbeac 100644 --- a/Dockerfile.sidecar +++ b/Dockerfile.sidecar @@ -17,6 +17,7 @@ RUN go mod download COPY cmd/pd-sidecar/main.go cmd/cmd.go COPY pkg/sidecar pkg/sidecar COPY pkg/common pkg/common +COPY pkg/telemetry pkg/telemetry # Build # the GOARCH has not a default value to allow the binary be built according to the host where the command diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 1952fcf30..714009e46 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -27,20 +27,54 @@ package main import ( "os" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner" "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) func main() { + ctx := ctrl.SetupSignalHandler() + + // Initialize tracing before creating any spans + shutdownTracing, err := telemetry.InitTracing(ctx) + if err != nil { + // Log error but don't fail - tracing is optional + ctrl.Log.Error(err, "Failed to initialize tracing") + } + if shutdownTracing != nil { + defer func() { + if err := shutdownTracing(ctx); err != nil { + ctrl.Log.Error(err, "Failed to shutdown tracing") + } + }() + } + + // Add startup span to verify tracing is working + tracer := telemetry.Tracer() + ctx, span := tracer.Start(ctx, "llm_d.epp.startup") + span.SetAttributes( + attribute.String("component", "llm-d-inference-scheduler"), + attribute.String("operation", "startup"), + ) + defer span.End() + // Register llm-d-inference-scheduler 
plugins plugins.RegisterAllPlugins() + // Note: GIE built-in plugins are automatically registered by the runner + // when it processes configuration in runner.parsePluginsConfiguration() + if err := runner.NewRunner(). WithCustomCollectors(metrics.GetCollectors()...). - Run(ctrl.SetupSignalHandler()); err != nil { + Run(ctx); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "startup failed") + // os.Exit bypasses deferred calls, so end the span and flush the exporter explicitly + span.End() + if shutdownTracing != nil { + _ = shutdownTracing(ctx) + } os.Exit(1) } + span.SetStatus(codes.Ok, "") } diff --git a/cmd/pd-sidecar/main.go index cd086eb7d..fb7e96261 100644 --- a/cmd/pd-sidecar/main.go +++ b/cmd/pd-sidecar/main.go @@ -29,6 +29,9 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/pkg/sidecar/proxy" "github.com/llm-d/llm-d-inference-scheduler/pkg/sidecar/version" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" ) var ( @@ -70,6 +73,29 @@ func main() { ctx := ctrl.SetupSignalHandler() log.IntoContext(ctx, logger) + // Initialize tracing before creating any spans + shutdownTracing, err := telemetry.InitTracing(ctx) + if err != nil { + // Log error but don't fail - tracing is optional + logger.Error(err, "Failed to initialize tracing") + } + if shutdownTracing != nil { + defer func() { + if err := shutdownTracing(ctx); err != nil { + logger.Error(err, "Failed to shutdown tracing") + } + }() + } + + // Add startup span to verify tracing is working + tracer := telemetry.Tracer() + ctx, span := tracer.Start(ctx, "llm_d.pd_proxy.startup") + span.SetAttributes( + attribute.String("component", "llm-d-pd-proxy"), + attribute.String("operation", "startup"), + ) + defer span.End() + logger.Info("Proxy starting", "Built on", version.BuildRef, "From Git SHA", version.CommitSHA) // Validate connector @@ -108,6 +134,8 @@ func main() { targetURL, err := url.Parse(scheme + "://localhost:" + *vLLMPort) if err != nil { logger.Error(err, "failed to create targetURL") + span.RecordError(err) + span.SetStatus(codes.Error, "failed to create targetURL") return } @@ -121,6 +149,8 @@ func main() { } if err != nil { logger.Error(err, "failed to create TLS certificate") + span.RecordError(err) + span.SetStatus(codes.Error, "failed to create TLS certificate") return } cert = &tempCert @@ -139,11 +169,15 @@ func main() { validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *poolGroup, *inferencePoolNamespace, *inferencePoolName) if err != nil { logger.Error(err, "failed to create SSRF protection validator") + span.RecordError(err) + span.SetStatus(codes.Error, "failed to create SSRF protection validator") return } proxyServer := proxy.NewProxy(*port, targetURL, config) + span.SetStatus(codes.Ok, "") + if err := proxyServer.Start(ctx, cert, validator); err != nil { logger.Error(err, "failed to start proxy server") } diff --git a/go.mod index 2a8d836b9..755534fe2 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,10 @@ require ( github.com/openai/openai-go v1.12.0 github.com/prometheus/client_golang v1.23.2 github.com/stretchr/testify v1.11.1 + go.opentelemetry.io/otel v1.38.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 + go.opentelemetry.io/otel/sdk v1.38.0 + go.opentelemetry.io/otel/trace v1.38.0 golang.org/x/sync v0.18.0 google.golang.org/grpc v1.77.0 k8s.io/api v0.34.2 @@ -95,13 +99,9 @@ require ( github.com/x448/float16 v0.8.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect -
go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 // indirect go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect go.opentelemetry.io/proto/otlp v1.7.1 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/pkg/plugins/pre-request/pd_prerequest.go index beebbe46c..1b59dcfc9 100644 --- a/pkg/plugins/pre-request/pd_prerequest.go +++ b/pkg/plugins/pre-request/pd_prerequest.go @@ -7,11 +7,15 @@ import ( "fmt" "net" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) const ( @@ -67,17 +71,48 @@ func (p *PrefillHeaderHandler) WithName(name string) *PrefillHeaderHandler { } // PreRequest wires prefill SchedulerProfile result into a header to indicate prefill worker -func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { +func (p *PrefillHeaderHandler) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { + tracer := telemetry.Tracer() + _, span := tracer.Start(ctx, "llm_d.epp.prerequest.pd_disaggregation", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + + // Add component and request attributes + span.SetAttributes( + attribute.String("component", "llm-d-inference-scheduler"), + attribute.String("operation", "prefill_disaggregation"), + ) + + if request != nil && request.TargetModel != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", request.TargetModel)) + } + if request != nil && request.RequestId != "" { + span.SetAttributes(attribute.String("gen_ai.request.id", request.RequestId)) + } + if _, found := request.Headers[common.PrefillPodHeader]; found { request.Headers[common.PrefillPodHeader] = "" // clear header, if already set } prefillProfileRunResult, exists := schedulingResult.ProfileResults[p.prefillProfile] if !exists { + span.SetAttributes( + attribute.Bool("llm_d.epp.pd.disaggregation_enabled", false), + attribute.String("llm_d.epp.pd.reason", "no_prefill_profile_result"), + ) + span.SetStatus(codes.Ok, "") return // prefill profile failed to run or we chose not to run it, no-op in this case } targetPod := prefillProfileRunResult.TargetPods[0].GetPod() prefillHostPort := net.JoinHostPort(targetPod.Address, targetPod.Port) request.Headers[common.PrefillPodHeader] = prefillHostPort // in the form of host:port + + span.SetAttributes( + attribute.Bool("llm_d.epp.pd.disaggregation_enabled", true), + attribute.String("llm_d.epp.pd.prefill_pod_address", targetPod.Address), + attribute.String("llm_d.epp.pd.prefill_pod_port", targetPod.Port), + ) + span.SetStatus(codes.Ok, "") } diff --git a/pkg/plugins/profile/pd_profile_handler.go index a3fe3e75d..9a335f1e5 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++
b/pkg/plugins/profile/pd_profile_handler.go @@ -8,9 +8,13 @@ import ( "fmt" "net" "strconv" + "strings" "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" @@ -19,6 +23,7 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) const ( @@ -115,12 +120,60 @@ func (h *PdProfileHandler) WithName(name string) *PdProfileHandler { return h } +// extractPromptFromRequest extracts the prompt text from an LLMRequest based on request type. +func extractPromptFromRequest(request *types.LLMRequest) string { + if request == nil || request.Body == nil { + return "" + } + + // Handle completions request + if request.Body.Completions != nil { + return request.Body.Completions.Prompt + } + + // Handle chat completions request + if request.Body.ChatCompletions != nil { + var messages []string + for _, msg := range request.Body.ChatCompletions.Messages { + messages = append(messages, msg.Content.PlainText()) + } + return strings.Join(messages, " ") + } + + return "" +} + // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { + // Start tracing span for profile picking operation + tracer := telemetry.Tracer() + ctx, span := tracer.Start(ctx, "llm_d.epp.profile_handler.pick", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + + // Set initial attributes + span.SetAttributes( + attribute.Int("llm_d.profile_handler.total_profiles", len(profiles)), + attribute.Int("llm_d.profile_handler.executed_profiles", len(profileResults)), + ) + + if request != nil && request.TargetModel != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", request.TargetModel)) + } + if request != nil && request.RequestId != "" { + span.SetAttributes(attribute.String("gen_ai.request.id", request.RequestId)) + } + if _, executed := profileResults[h.decodeProfile]; !executed { // if decode profile was not executed yet, first let the scheduler run the decode profile + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "run_decode"), + attribute.String("llm_d.profile_handler.selected_profile", h.decodeProfile), + ) + span.SetStatus(codes.Ok, "") return map[string]*framework.SchedulerProfile{ h.decodeProfile: profiles[h.decodeProfile], } @@ -130,6 +183,11 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat // when a profile run fails its result value is nil. we need to check decode result before continuing to prefill // check if all configured profiles have been executed, or if decode failed, no need to run more profiles. 
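+	// Illustrative example (hypothetical numbers): with hashBlockSize=16, a 256-byte prompt, and hitPrefix=8 cached blocks, hitPercentagePrefix = (8*16)/256 = 0.5, leaving a non-cached suffix of 128 bytes; with pdThreshold=200 this falls below the threshold, so prefill is skipped and only the decode profile's result is used (see the threshold check below).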
if len(profiles) == len(profileResults) || profileResults[h.decodeProfile] == nil { + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "complete"), + attribute.Bool("llm_d.profile_handler.decode_failed", profileResults[h.decodeProfile] == nil), + ) + span.SetStatus(codes.Ok, "") return map[string]*framework.SchedulerProfile{} } @@ -137,9 +195,16 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat userInput, err := getUserInputBytes(request) if err != nil { log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input bytes") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil } + span.SetAttributes( + attribute.Int("llm_d.profile_handler.pd_threshold", h.pdThreshold), + attribute.Int("llm_d.profile_handler.user_input_bytes", len(userInput)), + ) + // if we're here that means decode profile ran successfully, and we have additional profile configured that didn't run yet, // which means PD is enabled (otherwise, prefill profile is not configured at all and this profile handler is not used). // inspect decode execution result to decide if prefill should run or not. @@ -154,17 +219,35 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat hitPercentagePrefix = float64(hitPrefix*h.hashBlockSize) / float64(len(userInput)) log.FromContext(ctx).V(logutil.DEBUG).Info("Computed hit percentage for prefix cache", "hitPercentage", hitPercentagePrefix, "promptLength", len(userInput)) + + span.SetAttributes( + attribute.Float64("llm_d.profile_handler.cache_hit_ratio", hitPercentagePrefix), + attribute.Int("llm_d.profile_handler.cache_hits", hitPrefix), + ) } - if (1.0-hitPercentagePrefix)*float64(len(userInput)) < float64(h.pdThreshold) { + nonCachedBytes := (1.0 - hitPercentagePrefix) * float64(len(userInput)) + span.SetAttributes(attribute.Float64("llm_d.profile_handler.non_cached_bytes", nonCachedBytes)) + + if nonCachedBytes < float64(h.pdThreshold) { log.FromContext(ctx).Info("Non-cached suffix is smaller than threshold, using decode profile only", "hitPercentage", hitPercentagePrefix) metrics.RecordPDDecision(metrics.DecisionTypeDecodeOnly) + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "decode_only"), + attribute.String("llm_d.profile_handler.reason", "below_threshold"), + ) + span.SetStatus(codes.Ok, "") return map[string]*framework.SchedulerProfile{} // do not run prefill } } metrics.RecordPDDecision(metrics.DecisionTypePrefillDecode) // run the prefill profile + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "prefill_decode"), + attribute.String("llm_d.profile_handler.selected_profile", h.prefillProfile), + ) + span.SetStatus(codes.Ok, "") return map[string]*framework.SchedulerProfile{ h.prefillProfile: profiles[h.prefillProfile], } diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index 9f6866c2c..e167105f0 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -6,15 +6,21 @@ import ( "errors" "fmt" "os" + "strings" "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache" "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" preprocessing "github.com/llm-d/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "sigs.k8s.io/controller-runtime/pkg/log" 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) const ( @@ -123,24 +129,75 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor return s } +// extractPromptFromRequest extracts the prompt text from an LLMRequest based on request type. +func extractPromptFromRequest(request *types.LLMRequest) string { + if request == nil || request.Body == nil { + return "" + } + + // Handle completions request + if request.Body.Completions != nil { + return request.Body.Completions.Prompt + } + + // Handle chat completions request + if request.Body.ChatCompletions != nil { + var messages []string + for _, msg := range request.Body.ChatCompletions.Messages { + messages = append(messages, msg.Content.PlainText()) + } + return strings.Join(messages, " ") + } + + return "" +} + // Score scores the provided pod based on the KVCache index state. // The returned scores are normalized to a range of 0-1. func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { + // Start tracing span for scoring operation + tracer := telemetry.Tracer() + ctx, span := tracer.Start(ctx, "llm_d.epp.scorer.prefix_cache", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + logger := log.FromContext(ctx).WithName(s.typedName.String()) debugLogger := logger.V(logutil.DEBUG) + // Set initial attributes + span.SetAttributes( + attribute.Int("llm_d.scorer.candidate_pods", len(pods)), + ) + + if request != nil && request.TargetModel != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", request.TargetModel)) + } + if request != nil && request.RequestId != "" { + span.SetAttributes(attribute.String("gen_ai.request.id", request.RequestId)) + } + if request == nil { debugLogger.Info("Request is nil, skipping scoring") + span.SetAttributes(attribute.String("llm_d.scorer.result", "skipped_nil_request")) + span.SetStatus(codes.Ok, "") return nil } scores, err := s.getScores(ctx, request) if err != nil { logger.Error(err, "Failed to get pod scores") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil } debugLogger.Info("Got pod scores", "scores", scores) + // Track scoring statistics + span.SetAttributes( + attribute.Int("llm_d.scorer.scores_computed", len(scores)), + ) + podToKey := func(pod types.Pod) (string, bool) { metricsPod := pod.GetPod() if metricsPod == nil { @@ -150,7 +207,30 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat return metricsPod.Address, true } - return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) + normalizedScores := indexedScoresToNormalizedScoredPods(pods, podToKey, scores) + + // Calculate score distribution for observability + if len(normalizedScores) > 0 { + maxScore := 0.0 + totalScore := 0.0 + for _, score := range normalizedScores { + if score > maxScore { + maxScore = score + } + totalScore += score + } + avgScore := totalScore / float64(len(normalizedScores)) + + span.SetAttributes( + attribute.Float64("llm_d.scorer.score.max", maxScore), + attribute.Float64("llm_d.scorer.score.avg", avgScore), + attribute.Int("llm_d.scorer.pods_scored", len(normalizedScores)), + ) + } + + 
span.SetAttributes(attribute.String("llm_d.scorer.result", "success")) + span.SetStatus(codes.Ok, "") + return normalizedScores } // getScores retrieves the pod scores from the KV-cache indexer diff --git a/pkg/sidecar/proxy/chat_completions.go b/pkg/sidecar/proxy/chat_completions.go index 5ab731a6e..eacc66c4f 100644 --- a/pkg/sidecar/proxy/chat_completions.go +++ b/pkg/sidecar/proxy/chat_completions.go @@ -17,10 +17,16 @@ limitations under the License. package proxy import ( + "context" "net/http" "strings" + "time" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var ( @@ -32,6 +38,22 @@ var ( ) func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) { + requestStart := time.Now() + tracer := telemetry.Tracer() + ctx, span := tracer.Start(r.Context(), "llm_d.pd_proxy.request", + trace.WithSpanKind(trace.SpanKindServer), + ) + defer span.End() + + // Update request context with span and start time + ctx = context.WithValue(ctx, "request_start_time", requestStart) + r = r.WithContext(ctx) + + span.SetAttributes( + attribute.String("llm_d.pd_proxy.connector", s.config.Connector), + attribute.String("llm_d.pd_proxy.request.path", r.URL.Path), + ) + var prefillHostPorts []string prefillHostPorts = r.Header.Values(common.PrefillPodHeader) @@ -54,8 +76,21 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) } } + // Add span for header parsing (always executed) + _, headerSpan := tracer.Start(ctx, "llm_d.pd_proxy.parse_headers") + headerSpan.SetAttributes( + attribute.Int("llm_d.pd_proxy.prefill_headers_count", numHosts), + attribute.Bool("llm_d.pd_proxy.prefiller_sampling_enabled", s.config.EnablePrefillerSampling), + ) + headerSpan.End() + if len(prefillHostPort) == 0 { s.logger.V(4).Info("skip disaggregated prefill") + span.SetAttributes( + attribute.Bool("llm_d.pd_proxy.disaggregation_enabled", false), + attribute.String("llm_d.pd_proxy.reason", "no_prefill_header"), + ) + span.SetStatus(codes.Ok, "") if !s.forwardDataParallel || !s.dataParallelHandler(w, r) { s.decoderProxy.ServeHTTP(w, r) @@ -63,6 +98,12 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) return } + span.SetAttributes( + attribute.Bool("llm_d.pd_proxy.disaggregation_enabled", true), + attribute.String("llm_d.pd_proxy.prefill_target", prefillHostPort), + attribute.Int("llm_d.pd_proxy.prefill_candidates", numHosts), + ) + // SSRF Protection: Check if the prefill target is allowed if !s.allowlistValidator.IsAllowed(prefillHostPort) { s.logger.Error(nil, "SSRF protection: prefill target not in allowlist", @@ -70,10 +111,16 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) "clientIP", r.RemoteAddr, "userAgent", r.Header.Get("User-Agent"), "requestPath", r.URL.Path) + span.SetAttributes( + attribute.String("llm_d.pd_proxy.error", "ssrf_protection_denied"), + attribute.String("llm_d.pd_proxy.denied_target", prefillHostPort), + ) + span.SetStatus(codes.Error, "SSRF protection: prefill target not in allowlist") http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden) return } s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort) s.runConnectorProtocol(w, r, prefillHostPort) + span.SetStatus(codes.Ok, "") } diff --git a/pkg/sidecar/proxy/connector_lmcache.go 
b/pkg/sidecar/proxy/connector_lmcache.go index f19412e83..c68f821bd 100644 --- a/pkg/sidecar/proxy/connector_lmcache.go +++ b/pkg/sidecar/proxy/connector_lmcache.go @@ -21,11 +21,20 @@ import ( "io" "net/http" "strings" + "time" + + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) { s.logger.Info("running LMCache protocol") + tracer := telemetry.Tracer() + ctx := r.Context() + // Read and parse request body defer r.Body.Close() //nolint:all original, err := io.ReadAll(r.Body) @@ -44,9 +53,18 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref return } + // Prefill Stage + ctx, prefillSpan := tracer.Start(ctx, "llm_d.pd_proxy.prefill", + trace.WithSpanKind(trace.SpanKindInternal), + ) + prefillSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.prefill_target", prefillPodHostPort), + attribute.String("llm_d.pd_proxy.connector", "lmcache"), + ) + prefillStart := time.Now() + // Create prefiller request. Set max_tokens to 1. - ctx := r.Context() preq := r.Clone(ctx) completionRequest[requestFieldMaxTokens] = 1 @@ -75,17 +93,84 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref pw := &bufferedResponseWriter{} prefillHandler.ServeHTTP(pw, preq) + prefillDuration := time.Since(prefillStart) + prefillSpan.SetAttributes( + attribute.Int("llm_d.pd_proxy.prefill.status_code", pw.statusCode), + attribute.Float64("llm_d.pd_proxy.prefill.duration_ms", float64(prefillDuration.Milliseconds())), + ) + if pw.statusCode < 200 || pw.statusCode >= 300 { s.logger.Error(err, "request failed", "code", pw.statusCode) + prefillSpan.SetStatus(codes.Error, "prefill request failed") + prefillSpan.End() w.WriteHeader(pw.statusCode) return } + prefillSpan.SetStatus(codes.Ok, "") + prefillSpan.End() + + // Decode Stage + ctx, decodeSpan := tracer.Start(ctx, "llm_d.pd_proxy.decode", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer decodeSpan.End() + + decodeSpan.SetAttributes(attribute.String("llm_d.pd_proxy.connector", "lmcache")) + decodeStart := time.Now() // Forward original request to local decoder + r = r.WithContext(ctx) r.Body = io.NopCloser(strings.NewReader(string(original))) - if !s.forwardDataParallel || !s.dataParallelHandler(w, r) { + dataParallelUsed := s.forwardDataParallel && s.dataParallelHandler(w, r) + decodeSpan.SetAttributes(attribute.Bool("llm_d.pd_proxy.decode.data_parallel", dataParallelUsed)) + + if !dataParallelUsed { s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host) + decodeSpan.SetAttributes(attribute.String("llm_d.pd_proxy.decode.target", s.decoderURL.Host)) s.decoderProxy.ServeHTTP(w, r) } + + decodeDuration := time.Since(decodeStart) + decodeSpan.SetAttributes(attribute.Float64("llm_d.pd_proxy.decode.duration_ms", float64(decodeDuration.Milliseconds()))) + decodeSpan.SetStatus(codes.Ok, "") + + // Calculate end-to-end P/D metrics and add to decode span + // These metrics represent the "true" TTFT and latency from the coordinator's perspective + // Note: After tracer.Start() above, ctx contains the decode span, so SpanFromContext returns it + if currentSpan := trace.SpanFromContext(ctx); currentSpan.SpanContext().IsValid() { + // Get request start time from context + var totalDuration time.Duration + var trueTTFT time.Duration + if requestStartValue := 
ctx.Value("request_start_time"); requestStartValue != nil { + if requestStart, ok := requestStartValue.(time.Time); ok { + totalDuration = time.Since(requestStart) + + // The "true TTFT" in P/D mode is the time until the decoder can start generating + // This includes: gateway routing + scheduling + prefill time + KV transfer coordination overhead + // The decode vLLM will report a low TTFT (since KV is already transferred), + // but this captures the real end-to-end TTFT from the client's perspective + // + // True TTFT = time from gateway request start to decode start + // This includes all coordinator overhead that vLLM-level metrics miss + trueTTFT = decodeStart.Sub(requestStart) + } + } + + // KV transfer overhead: time between prefill completion and decode start + kvTransferOverhead := decodeStart.Sub(prefillStart.Add(prefillDuration)) + + currentSpan.SetAttributes( + // End-to-end P/D timing metrics + attribute.Float64("llm_d.pd_proxy.total_duration_ms", float64(totalDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.true_ttft_ms", float64(trueTTFT.Milliseconds())), + + // Component breakdown + attribute.Float64("llm_d.pd_proxy.prefill_duration_ms", float64(prefillDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.decode_duration_ms", float64(decodeDuration.Milliseconds())), + + // Coordination overhead + attribute.Float64("llm_d.pd_proxy.kv_transfer_overhead_ms", float64(kvTransferOverhead.Milliseconds())), + ) + } } diff --git a/pkg/sidecar/proxy/connector_nixlv2.go b/pkg/sidecar/proxy/connector_nixlv2.go index 265072bbf..5641fe3c0 100644 --- a/pkg/sidecar/proxy/connector_nixlv2.go +++ b/pkg/sidecar/proxy/connector_nixlv2.go @@ -21,8 +21,13 @@ import ( "io" "net/http" "strings" + "time" "github.com/google/uuid" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) { @@ -57,9 +62,20 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi uuidStr := uuid.String() // Prefill Stage + tracer := telemetry.Tracer() + ctx := r.Context() + + ctx, prefillSpan := tracer.Start(ctx, "llm_d.pd_proxy.prefill", + trace.WithSpanKind(trace.SpanKindInternal), + ) + prefillSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.request_id", uuidStr), + attribute.String("llm_d.pd_proxy.prefill_target", prefillPodHostPort), + attribute.String("llm_d.pd_proxy.connector", "nixlv2"), + ) + prefillStart := time.Now() // 1. 
Prepare prefill request - ctx := r.Context() preq := r.Clone(ctx) preq.Header.Add(requestHeaderRequestID, uuidStr) @@ -107,11 +123,21 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi pw := &bufferedResponseWriter{} prefillHandler.ServeHTTP(pw, preq) + prefillDuration := time.Since(prefillStart) + prefillSpan.SetAttributes( + attribute.Int("llm_d.pd_proxy.prefill.status_code", pw.statusCode), + attribute.Float64("llm_d.pd_proxy.prefill.duration_ms", float64(prefillDuration.Milliseconds())), + ) + if pw.statusCode < 200 || pw.statusCode >= 300 { s.logger.Error(err, "request failed", "code", pw.statusCode) + prefillSpan.SetStatus(codes.Error, "prefill request failed") + prefillSpan.End() w.WriteHeader(pw.statusCode) return } + prefillSpan.SetStatus(codes.Ok, "") + prefillSpan.End() // Process response - extract p/d fields var prefillerResponse map[string]any @@ -133,15 +159,31 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi // Decode Stage + ctx, decodeSpan := tracer.Start(ctx, "llm_d.pd_proxy.decode", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer decodeSpan.End() + + decodeSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.request_id", uuidStr), + attribute.String("llm_d.pd_proxy.connector", "nixlv2"), + ) + decodeStart := time.Now() + // 1. Prepare decode request dreq := r.Clone(ctx) dreq.Header.Add(requestHeaderRequestID, uuidStr) delete(completionRequest, requestFieldStream) + streamingEnabled := false if streamOk { completionRequest[requestFieldStream] = streamValue + if streamBool, ok := streamValue.(bool); ok { + streamingEnabled = streamBool + } } + decodeSpan.SetAttributes(attribute.Bool("llm_d.pd_proxy.decode.streaming", streamingEnabled)) if streamOptionsOk { completionRequest[requestFieldStreamOptions] = streamOptionsValue } @@ -168,8 +210,64 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi // 2. Forward to local decoder. 
s.logger.V(5).Info("sending request to decoder", "body", string(dbody)) - if !s.forwardDataParallel || !s.dataParallelHandler(w, dreq) { + dataParallelUsed := s.forwardDataParallel && s.dataParallelHandler(w, dreq) + decodeSpan.SetAttributes(attribute.Bool("llm_d.pd_proxy.decode.data_parallel", dataParallelUsed)) + + if !dataParallelUsed { s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host) + decodeSpan.SetAttributes(attribute.String("llm_d.pd_proxy.decode.target", s.decoderURL.Host)) s.decoderProxy.ServeHTTP(w, dreq) } + + decodeDuration := time.Since(decodeStart) + decodeSpan.SetAttributes(attribute.Float64("llm_d.pd_proxy.decode.duration_ms", float64(decodeDuration.Milliseconds()))) + decodeSpan.SetStatus(codes.Ok, "") + + // Calculate end-to-end P/D metrics and add to decode span + // These metrics represent the "true" TTFT and latency from the coordinator's perspective + // Note: After tracer.Start() above, ctx contains the decode span, so SpanFromContext returns it + if currentSpan := trace.SpanFromContext(ctx); currentSpan.SpanContext().IsValid() { + // Get request start time from context + var totalDuration time.Duration + var trueTTFT time.Duration + if requestStartValue := ctx.Value("request_start_time"); requestStartValue != nil { + if requestStart, ok := requestStartValue.(time.Time); ok { + totalDuration = time.Since(requestStart) + + // The "true TTFT" in P/D mode is the time until the decoder can start generating + // This includes: gateway routing + scheduling + prefill time + KV transfer coordination overhead + // The decode vLLM will report a low TTFT (since KV is already transferred), + // but this captures the real end-to-end TTFT from the client's perspective + // + // True TTFT = time from gateway request start to decode start + // This includes all coordinator overhead that vLLM-level metrics miss + trueTTFT = decodeStart.Sub(requestStart) + } + } + + // KV transfer overhead: time between prefill vLLM completion and decode request start + // This captures the coordination overhead between prefill and decode stages + // Note: This is an approximation - ideally we'd measure from prefill vLLM completion + // to when the decode vLLM receives the first token, but that requires response parsing + kvTransferOverhead := decodeStart.Sub(prefillStart.Add(prefillDuration)) + + // For TPOT (Time Per Output Token), we would need to: + // 1. Parse streaming response to detect token boundaries + // 2. 
Calculate: (total_decode_time - decode_ttft) / (num_output_tokens - 1) + // This is complex and requires response intercepting, so we defer to trace analysis + + currentSpan.SetAttributes( + // End-to-end P/D timing metrics + // These are the metrics that should be used instead of per-instance vLLM metrics + attribute.Float64("llm_d.pd_proxy.total_duration_ms", float64(totalDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.true_ttft_ms", float64(trueTTFT.Milliseconds())), + + // Component breakdown for analysis + attribute.Float64("llm_d.pd_proxy.prefill_duration_ms", float64(prefillDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.decode_duration_ms", float64(decodeDuration.Milliseconds())), + + // Coordination overhead between prefill and decode + attribute.Float64("llm_d.pd_proxy.kv_transfer_overhead_ms", float64(kvTransferOverhead.Milliseconds())), + ) + } } diff --git a/pkg/sidecar/proxy/connector_sglang.go b/pkg/sidecar/proxy/connector_sglang.go index b02fb7231..4e03068ae 100644 --- a/pkg/sidecar/proxy/connector_sglang.go +++ b/pkg/sidecar/proxy/connector_sglang.go @@ -27,6 +27,11 @@ import ( "strconv" "strings" "time" + + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) var ( @@ -76,12 +81,28 @@ func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefi } func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, body []byte, prefillHost string) { + tracer := telemetry.Tracer() + ctx := r.Context() + + // Prefill Stage - async + ctx, prefillSpan := tracer.Start(ctx, "llm_d.pd_proxy.prefill", + trace.WithSpanKind(trace.SpanKindInternal), + ) + prefillSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.prefill_target", prefillHost), + attribute.String("llm_d.pd_proxy.connector", "sglang"), + attribute.Bool("llm_d.pd_proxy.prefill.async", true), + ) + prefillStart := time.Now() + // Create separate requests for prefill and decode prefillReq := cloneWithJSONBody(r, body) decodeReq := cloneWithJSONBody(r, body) prefillHandler, err := s.prefillerProxyHandler(prefillHost) if err != nil { + prefillSpan.SetStatus(codes.Error, "failed to create prefill handler") + prefillSpan.End() if err := errorBadGateway(err, w); err != nil { s.logger.Error(err, "failed to send error response to client") } @@ -90,13 +111,77 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req // Send prefill request asynchronously go func() { + defer prefillSpan.End() pw := &bufferedResponseWriter{} prefillHandler.ServeHTTP(pw, prefillReq) + prefillDuration := time.Since(prefillStart) + prefillSpan.SetAttributes( + attribute.Int("llm_d.pd_proxy.prefill.status_code", pw.statusCode), + attribute.Float64("llm_d.pd_proxy.prefill.duration_ms", float64(prefillDuration.Milliseconds())), + ) + if pw.statusCode >= 200 && pw.statusCode < 300 { + prefillSpan.SetStatus(codes.Ok, "") + } else { + prefillSpan.SetStatus(codes.Error, "prefill request failed") + } s.logger.V(5).Info("prefill request completed", "status", pw.statusCode) }() + // Decode Stage - sync + ctx, decodeSpan := tracer.Start(ctx, "llm_d.pd_proxy.decode", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer decodeSpan.End() + + decodeSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.connector", "sglang"), + attribute.Bool("llm_d.pd_proxy.decode.concurrent_with_prefill", true), + ) + decodeStart := time.Now() + // Send decode request 
synchronously + decodeReq = decodeReq.WithContext(ctx) s.decoderProxy.ServeHTTP(w, decodeReq) + + decodeDuration := time.Since(decodeStart) + decodeSpan.SetAttributes( + attribute.Float64("llm_d.pd_proxy.decode.duration_ms", float64(decodeDuration.Milliseconds())), + attribute.String("llm_d.pd_proxy.decode.target", s.decoderURL.Host), + ) + decodeSpan.SetStatus(codes.Ok, "") + + // Calculate end-to-end P/D metrics and add to decode span + // Note: SGLang runs prefill and decode concurrently, so timing is different from sequential P/D + // Note: After tracer.Start() above, ctx contains the decode span, so SpanFromContext returns it + if currentSpan := trace.SpanFromContext(ctx); currentSpan.SpanContext().IsValid() { + // Get request start time from context + var totalDuration time.Duration + var trueTTFT time.Duration + if requestStartValue := ctx.Value("request_start_time"); requestStartValue != nil { + if requestStart, ok := requestStartValue.(time.Time); ok { + totalDuration = time.Since(requestStart) + + // For SGLang, prefill and decode run concurrently, but True TTFT still needs to capture + // the full coordinator overhead from gateway start to when decode can begin generating. + // This includes: gateway routing + scheduling overhead + time to start decode request + // Note: In concurrent mode, this is different from sequential P/D where we wait for prefill + trueTTFT = decodeStart.Sub(requestStart) + } + } + + currentSpan.SetAttributes( + // End-to-end P/D timing metrics for concurrent P/D + attribute.Float64("llm_d.pd_proxy.total_duration_ms", float64(totalDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.true_ttft_ms", float64(trueTTFT.Milliseconds())), + + // Component breakdown (note: prefill runs concurrently) + attribute.Float64("llm_d.pd_proxy.decode_duration_ms", float64(decodeDuration.Milliseconds())), + + // Note: prefill_duration_ms is tracked in the async prefill span + // SGLang-specific: prefill and decode overlap in time + attribute.Bool("llm_d.pd_proxy.concurrent_pd", true), + ) + } } func cloneWithJSONBody(r *http.Request, body []byte) *http.Request { diff --git a/pkg/sidecar/proxy/proxy_helpers.go b/pkg/sidecar/proxy/proxy_helpers.go index 5f68bb4ea..43bd23535 100644 --- a/pkg/sidecar/proxy/proxy_helpers.go +++ b/pkg/sidecar/proxy/proxy_helpers.go @@ -10,6 +10,8 @@ import ( "net/url" "syscall" "time" + + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) // startHTTP starts the HTTP reverse proxy. @@ -27,8 +29,15 @@ func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error { } s.addr = ln.Addr() + // Wrap handler with OpenTelemetry middleware to extract trace context from incoming requests + handler := otelhttp.NewHandler(s.handler, "llm-d-pd-proxy", + otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { + return "llm_d.pd_proxy." + r.Method + " " + r.URL.Path + }), + ) + server := &http.Server{ - Handler: s.handler, + Handler: handler, // No ReadTimeout/WriteTimeout for LLM inference - can take hours for large contexts IdleTimeout: 300 * time.Second, // 5 minutes for keep-alive connections ReadHeaderTimeout: 30 * time.Second, // Reasonable for headers only diff --git a/pkg/telemetry/tracing.go b/pkg/telemetry/tracing.go new file mode 100644 index 000000000..58fa7732a --- /dev/null +++ b/pkg/telemetry/tracing.go @@ -0,0 +1,126 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package telemetry + +import ( + "context" + "fmt" + "os" + "strings" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.24.0" + "go.opentelemetry.io/otel/trace" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +const ( + defaultServiceName = "llm-d-inference-scheduler" +) + +var ( + // serviceName holds the service name for the tracer + // Set during InitTracing() from OTEL_SERVICE_NAME env var + serviceName = defaultServiceName +) + +// InitTracing initializes OpenTelemetry tracing with OTLP exporter. +// Configuration is done via environment variables: +// - OTEL_SERVICE_NAME: Service name for tracing (default: llm-d-inference-scheduler) +// - OTEL_EXPORTER_OTLP_ENDPOINT: OTLP collector endpoint (default: localhost:4317; any http:// or https:// prefix is stripped) +// +// The sampler is currently fixed in code to ParentBased(TraceIDRatioBased(0.1)), i.e. 10% sampling; +// OTEL_TRACES_SAMPLER and OTEL_TRACES_SAMPLER_ARG are not honored here because passing WithSampler +// overrides the SDK's environment-based sampler configuration. +func InitTracing(ctx context.Context) (func(context.Context) error, error) { + logger := log.FromContext(ctx) + + // Get service name from environment, fallback to default + svcName := os.Getenv("OTEL_SERVICE_NAME") + if svcName == "" { + svcName = defaultServiceName + } + // Store in package variable for Tracer() function + serviceName = svcName + + // Get OTLP endpoint from environment + endpoint := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT") + if endpoint == "" { + endpoint = "localhost:4317" + } + + // Strip http:// or https:// prefix if present + // otlptracegrpc.WithEndpoint() expects host:port only + endpoint = stripScheme(endpoint) + + logger.Info("Initializing OpenTelemetry tracing", "endpoint", endpoint, "service", svcName) + + // Create OTLP trace exporter + exporter, err := otlptracegrpc.New(ctx, + otlptracegrpc.WithEndpoint(endpoint), + otlptracegrpc.WithInsecure(), // Use WithTLSCredentials() in production + ) + if err != nil { + return nil, fmt.Errorf("failed to create OTLP trace exporter: %w", err) + } + + // Create resource with service name + res, err := resource.New(ctx, + resource.WithAttributes( + semconv.ServiceName(svcName), + ), + ) + if err != nil { + return nil, fmt.Errorf("failed to create resource: %w", err) + } + + // Create trace provider with parent-based sampling + tp := sdktrace.NewTracerProvider( + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(res), + sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(0.1))), // 10% sampling + ) + + // Set global trace provider + otel.SetTracerProvider(tp) + + // Set W3C trace context propagator + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + )) + + logger.Info("OpenTelemetry tracing initialized successfully") + + // Return shutdown function + return tp.Shutdown, nil +} + +// Tracer returns a tracer for the inference scheduler +func Tracer() trace.Tracer { + return otel.Tracer(serviceName) +} + +// stripScheme removes http://
or https:// prefix from endpoint URL +// OpenTelemetry gRPC exporter expects host:port format only +func stripScheme(endpoint string) string { + endpoint = strings.TrimPrefix(endpoint, "http://") + endpoint = strings.TrimPrefix(endpoint, "https://") + return endpoint +}
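
Usage sketch (illustrative, not part of the patch): a minimal consumer of the new pkg/telemetry API, mirroring the wiring in cmd/epp/main.go. The import path assumes this repository's module name; the span name and attribute values are illustrative.

package main

import (
	"context"
	"log"

	"go.opentelemetry.io/otel/attribute"

	"github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry"
)

func main() {
	ctx := context.Background()

	// Tracing is optional: on initialization failure the binary keeps running without traces.
	shutdown, err := telemetry.InitTracing(ctx)
	if err != nil {
		log.Printf("tracing disabled: %v", err)
	}
	if shutdown != nil {
		// Flush the batched OTLP exporter before the process exits.
		defer func() {
			if err := shutdown(ctx); err != nil {
				log.Printf("tracing shutdown: %v", err)
			}
		}()
	}

	// Spans created through the shared tracer are exported to the collector
	// configured via OTEL_EXPORTER_OTLP_ENDPOINT.
	_, span := telemetry.Tracer().Start(ctx, "llm_d.example.startup")
	span.SetAttributes(attribute.String("component", "example"))
	span.End()
}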