diff --git a/Dockerfile.batch b/Dockerfile.batch
new file mode 100644
index 000000000..d94fb5394
--- /dev/null
+++ b/Dockerfile.batch
@@ -0,0 +1,46 @@
+# Build Stage: using Go 1.24 image
+FROM quay.io/projectquay/golang:1.24 AS builder
+ARG TARGETOS
+ARG TARGETARCH
+ARG COMMIT_SHA=unknown
+ARG BUILD_REF
+
+WORKDIR /workspace
+# Copy the Go Modules manifests
+COPY go.mod go.mod
+COPY go.sum go.sum
+# cache deps before building and copying source so that we don't need to re-download as much
+# and so that source changes don't invalidate our downloaded layer
+RUN go mod download
+
+# Copy the go source
+COPY cmd/batch/ cmd/batch/
+COPY pkg/batch/ pkg/batch/
+COPY pkg/metrics/ pkg/metrics/
+COPY pkg/common pkg/common
+COPY pkg/sidecar/version pkg/sidecar/version
+
+# Build
+# GOARCH has no default value, which allows the binary to be built for the platform of the host
+# where the command was called. For example, if we call make image-build locally on an Apple
+# Silicon machine, the docker BUILDPLATFORM arg will be linux/arm64, while on Apple x86 it will
+# be linux/amd64. Leaving it empty ensures that the container and the binary shipped in it
+# target the same platform.
+ENV CGO_ENABLED=0
+ENV GOOS=${TARGETOS:-linux}
+ENV GOARCH=${TARGETARCH}
+RUN go build -a -o bin/batch \
+    -ldflags="-X github.com/llm-d/llm-d-inference-scheduler/pkg/batch/version.CommitSHA=${COMMIT_SHA} -X github.com/llm-d/llm-d-inference-scheduler/pkg/batch/version.BuildRef=${BUILD_REF}" \
+    cmd/batch/main.go
+
+FROM registry.access.redhat.com/ubi9/ubi-micro:latest
+WORKDIR /
+COPY --from=builder /workspace/bin/batch /app/batch
+USER 65532:65532
+
+# expose gRPC, health and metrics ports
+EXPOSE 9002
+EXPOSE 9003
+EXPOSE 9090
+
+ENTRYPOINT ["/app/batch"]
diff --git a/Makefile b/Makefile
index 4d5263739..4858b79e4 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,7 @@ TARGETOS ?= $(shell go env GOOS)
 TARGETARCH ?= $(shell go env GOARCH)
 PROJECT_NAME ?= llm-d-inference-scheduler
 SIDECAR_IMAGE_NAME ?= llm-d-routing-sidecar
+BATCH_IMAGE_NAME ?= llm-d-routing-batch
 VLLM_SIMULATOR_IMAGE_NAME ?= llm-d-inference-sim
 SIDECAR_NAME ?= pd-sidecar
 IMAGE_REGISTRY ?= ghcr.io/llm-d
@@ -19,6 +20,11 @@ SIDECAR_TAG ?= dev
 export SIDECAR_TAG
 SIDECAR_IMAGE_TAG_BASE ?= $(IMAGE_REGISTRY)/$(SIDECAR_IMAGE_NAME)
 export SIDECAR_IMAGE ?= $(SIDECAR_IMAGE_TAG_BASE):$(SIDECAR_TAG)
+BATCH_TAG ?= dev
+export BATCH_TAG
+BATCH_IMAGE_TAG_BASE ?= $(IMAGE_REGISTRY)/$(BATCH_IMAGE_NAME)
+export BATCH_IMAGE ?= $(BATCH_IMAGE_TAG_BASE):$(BATCH_TAG)
+
 NAMESPACE ?= hc4ai-operator
 VLLM_SIMULATOR_TAG ?= v0.6.1
 export VLLM_SIMULATOR_TAG
@@ -94,16 +100,24 @@ CGO_LDFLAGS := $(PYTHON_LDFLAGS) $(PYTHON_LIBS) '-L$(shell pwd)/lib' -ltokenizer
 # Internal variables for generic targets
 epp_IMAGE = $(EPP_IMAGE)
 sidecar_IMAGE = $(SIDECAR_IMAGE)
+batch_IMAGE = $(BATCH_IMAGE)
 epp_NAME = epp
 sidecar_NAME = $(SIDECAR_NAME)
+batch_NAME = batch
 epp_LDFLAGS = -ldflags="$(LDFLAGS)"
 sidecar_LDFLAGS =
+batch_LDFLAGS = -ldflags="$(LDFLAGS)"
 epp_CGO_CFLAGS = "${CGO_CFLAGS}"
 sidecar_CGO_CFLAGS =
+batch_CGO_CFLAGS = "${CGO_CFLAGS}"
 epp_CGO_LDFLAGS = "${CGO_LDFLAGS}"
 sidecar_CGO_LDFLAGS =
+batch_CGO_LDFLAGS = "${CGO_LDFLAGS}"
 epp_TEST_FILES = go list ./... | grep -v /test/ | grep -v ./pkg/sidecar/
 sidecar_TEST_FILES = go list ./pkg/sidecar/...
+batch_TEST_FILES = go list ./pkg/batch/...
+
+
 .PHONY: help
 help: ## Print help
@@ -142,7 +156,7 @@ format: ## Format Go source files
 test: test-unit test-e2e ## Run unit tests and e2e tests
 
 .PHONY: test-unit
-test-unit: test-unit-epp test-unit-sidecar
+test-unit: test-unit-epp test-unit-sidecar test-unit-batch
 
 .PHONY: test-unit-%
 test-unit-%: download-tokenizer install-dependencies ## Run unit tests
@@ -173,7 +187,7 @@ lint: check-golangci-lint check-typos ## Run lint
 ##@ Build
 
 .PHONY: build
-build: build-epp build-sidecar ## Build the project
+build: build-epp build-sidecar build-batch ## Build the project
 
 .PHONY: build-%
 build-%: check-go install-dependencies download-tokenizer ## Build the project
@@ -183,7 +197,7 @@ build-%: check-go install-dependencies download-tokenizer ## Build the project
 ##@ Container Build/Push
 
 .PHONY: image-build
-image-build: image-build-epp image-build-sidecar ## Build Docker image
+image-build: image-build-epp image-build-sidecar image-build-batch ## Build Docker image
 
 .PHONY: image-build-%
 image-build-%: check-container-tool ## Build Docker image using $(CONTAINER_RUNTIME)
@@ -197,7 +211,7 @@ image-build-%: check-container-tool ## Build Docker image
 		-t $($*_IMAGE) -f Dockerfile.$* .
 
 .PHONY: image-push
-image-push: image-push-epp image-push-sidecar ## Push container images to registry
+image-push: image-push-epp image-push-sidecar image-push-batch ## Push container images to registry
 
 .PHONY: image-push-%
 image-push-%: check-container-tool ## Push container image to registry
@@ -431,6 +445,7 @@ env-dev-kind: ## Run under kind ($(KIND_CLUSTER_NAME))
 		EPP_IMAGE=$(EPP_IMAGE) \
 		VLLM_SIMULATOR_IMAGE=${VLLM_SIMULATOR_IMAGE} \
 		SIDECAR_IMAGE=${SIDECAR_IMAGE} \
+		BATCH_IMAGE=${BATCH_IMAGE} \
 		./scripts/kind-dev-env.sh; \
 	fi
diff --git a/cmd/batch/main.go b/cmd/batch/main.go
new file mode 100644
index 000000000..9d9e4de84
--- /dev/null
+++ b/cmd/batch/main.go
@@ -0,0 +1,16 @@
+package main
+
+import (
+	"os"
+
+	batchrunner "github.com/llm-d/llm-d-inference-scheduler/cmd/batch/runner"
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/metrics"
+	ctrl "sigs.k8s.io/controller-runtime"
+)
+
+func main() {
+	if err := batchrunner.NewBatchRunner().WithCustomCollectors(metrics.GetBatchCollectors()...).Run(ctrl.SetupSignalHandler()); err != nil {
+		os.Exit(1)
+	}
+}
diff --git a/cmd/batch/runner/runner.go b/cmd/batch/runner/runner.go
new file mode 100644
index 000000000..a1c9122fb
--- /dev/null
+++ b/cmd/batch/runner/runner.go
@@ -0,0 +1,161 @@
+package batchrunner
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"net/http"
+
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/batch"
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/batch/redis"
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/sidecar/version"
+	"github.com/prometheus/client_golang/prometheus"
+	uberzap "go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+	"k8s.io/client-go/rest"
+	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/log/zap"
+	"sigs.k8s.io/controller-runtime/pkg/metrics/filters"
+	metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+// BatchRunner wires together flags, metrics and the batch workers.
+type BatchRunner struct {
+	customCollectors []prometheus.Collector
+}
+
+var (
+	setupLog     = ctrl.Log.WithName("setup")
+	logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity")
level verbosity") + concurrency = flag.Int("concurrency", 8, "number of concurrent workers") + endpoint = flag.String("endpoint", "http://localhost:30080/v1/completions", "inference endpoint") + metricsPort = flag.Int("metrics-port", runserver.DefaultMetricsPort, "The metrics port") + metricsEndpointAuth = flag.Bool("metrics-endpoint-auth", true, "Enables authentication and authorization of the metrics endpoint") + requestMergePolicy = flag.String("request-merge-policy", "random-robin", "The request merge policy to use. Supported policies: random-robin") + messageQueueImpl = flag.String("message-queue-impl", "redis-pubsub", "The message queue implementation to use. Supported implementations: redis-pubsub") +) + +func NewBatchRunner() *BatchRunner { + return &BatchRunner{} +} + +func (r *BatchRunner) Run(ctx context.Context) error { + opts := zap.Options{ + Development: true, + } + opts.BindFlags(flag.CommandLine) + flag.Parse() + initLogging(&opts) + + /*if *tracing { + err := common.InitTracing(ctx, setupLog) + if err != nil { + return err + } + }*/ + + ////////setupLog.Info("GIE build", "commit-sha", version.CommitSHA, "build-ref", version.BuildRef) + + // Validate flags + if err := validateFlags(); err != nil { + setupLog.Error(err, "Failed to validate flags") + return err + } + + // Print all flag values + flags := make(map[string]any) + flag.VisitAll(func(f *flag.Flag) { + flags[f.Name] = f.Value + }) + setupLog.Info("Flags processed", "flags", flags) + + // --- Get Kubernetes Config --- + cfg, err := ctrl.GetConfig() + if err != nil { + setupLog.Error(err, "Failed to get Kubernetes rest config") + return err + } + + metrics.Register(r.customCollectors...) + metrics.RecordInferenceExtensionInfo(version.CommitSHA, version.BuildRef) + // Register metrics handler. + // Metrics endpoint is enabled in 'config/default/kustomization.yaml'. The Metrics options configure the server. 
+	// More info:
+	// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.19.1/pkg/metrics/server
+	// - https://book.kubebuilder.io/reference/metrics.html
+	metricsServerOptions := metricsserver.Options{
+		BindAddress: fmt.Sprintf(":%d", *metricsPort),
+		FilterProvider: func() func(c *rest.Config, httpClient *http.Client) (metricsserver.Filter, error) {
+			if *metricsEndpointAuth {
+				return filters.WithAuthenticationAndAuthorization
+			}
+			return nil
+		}(),
+	}
+
+	httpClient := &http.Client{
+		// TODO: configure
+	}
+
+	// TODO: not sure about reusing the inference HTTP client for the metrics server.
+	msrv, err := metricsserver.NewServer(metricsServerOptions, cfg, httpClient)
+	if err != nil {
+		setupLog.Error(err, "Failed to create metrics server")
+		return err
+	}
+	go func() {
+		if err := msrv.Start(ctx); err != nil {
+			setupLog.Error(err, "Metrics server terminated")
+		}
+	}()
+
+	var policy batch.RequestMergePolicy
+	switch *requestMergePolicy {
+	case "random-robin":
+		policy = batch.NewRandomRobinPolicy()
+	default:
+		err := fmt.Errorf("unknown request merge policy: %q", *requestMergePolicy)
+		setupLog.Error(err, "Failed to configure request merge policy")
+		return err
+	}
+
+	var impl batch.Flow
+	switch *messageQueueImpl {
+	case "redis-pubsub":
+		impl = redis.NewRedisMQFlow()
+	default:
+		err := fmt.Errorf("unknown message queue implementation: %q", *messageQueueImpl)
+		setupLog.Error(err, "Failed to configure message queue implementation")
+		return err
+	}
+
+	requestChannel := policy.MergeRequestChannels(impl.RequestChannels()).Channel
+	for w := 1; w <= *concurrency; w++ {
+		go batch.Worker(ctx, *endpoint, httpClient, requestChannel, impl.RetryChannel(), impl.ResultChannel())
+	}
+
+	impl.Start(ctx)
+	<-ctx.Done()
+	return nil
+}
+
+// TODO: possibly a duplicate of the EPP runner's logging setup; consider sharing.
+func initLogging(opts *zap.Options) {
+	// Unless -zap-log-level is explicitly set, use -v
+	useV := true
+	flag.Visit(func(f *flag.Flag) {
+		if f.Name == "zap-log-level" {
+			useV = false
+		}
+	})
+	if useV {
+		// See https://pkg.go.dev/sigs.k8s.io/controller-runtime/pkg/log/zap#Options.Level
+		lvl := -1 * (*logVerbosity)
+		opts.Level = uberzap.NewAtomicLevelAt(zapcore.Level(int8(lvl)))
+	}
+
+	logger := zap.New(zap.UseFlagOptions(opts), zap.RawZapOpts(uberzap.AddCaller()))
+	ctrl.SetLogger(logger)
+}
+
+// WithCustomCollectors registers additional Prometheus collectors with the runner.
+func (r *BatchRunner) WithCustomCollectors(collectors ...prometheus.Collector) *BatchRunner {
+	r.customCollectors = collectors
+	return r
+}
+
+func validateFlags() error {
+	// TODO: validate flag values (e.g., concurrency > 0, endpoint is a valid URL).
+	return nil
+}
diff --git a/cmd/epp/main.go b/cmd/epp/main.go
index 1952fcf30..37ae6baa0 100644
--- a/cmd/epp/main.go
+++ b/cmd/epp/main.go
@@ -39,7 +39,7 @@ func main() {
 	plugins.RegisterAllPlugins()
 
 	if err := runner.NewRunner().
-		WithCustomCollectors(metrics.GetCollectors()...).
+		WithCustomCollectors(metrics.GetEPPCollectors()...).
 		Run(ctrl.SetupSignalHandler()); err != nil {
 		os.Exit(1)
 	}
diff --git a/deploy/components/inference-gateway/deployments.yaml b/deploy/components/inference-gateway/deployments.yaml
index cc56789a3..1dc7341eb 100644
--- a/deploy/components/inference-gateway/deployments.yaml
+++ b/deploy/components/inference-gateway/deployments.yaml
@@ -63,3 +63,29 @@ spec:
             name: epp-config
         - name: cache
           emptyDir: {}
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: ${BATCH_NAME}
+  labels:
+    app: ${BATCH_NAME}
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: ${BATCH_NAME}
+  template:
+    metadata:
+      labels:
+        app: ${BATCH_NAME}
+    spec:
+      serviceAccountName: ${BATCH_NAME}
+      terminationGracePeriodSeconds: 130
+      containers:
+      - name: batch
+        image: ${BATCH_IMAGE}
+        imagePullPolicy: IfNotPresent
+        args:
+        - --redis.addr
+        - "redis-service:6379"
\ No newline at end of file
diff --git a/deploy/components/inference-gateway/service-accounts.yaml b/deploy/components/inference-gateway/service-accounts.yaml
index a92013a0a..abc74b054 100644
--- a/deploy/components/inference-gateway/service-accounts.yaml
+++ b/deploy/components/inference-gateway/service-accounts.yaml
@@ -2,3 +2,8 @@ apiVersion: v1
 kind: ServiceAccount
 metadata:
   name: ${EPP_NAME}
+---
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+  name: ${BATCH_NAME}
\ No newline at end of file
diff --git a/docs/images/batch_processor_redis_architecture.png b/docs/images/batch_processor_redis_architecture.png
new file mode 100644
index 000000000..7b96766f9
Binary files /dev/null and b/docs/images/batch_processor_redis_architecture.png differ
diff --git a/go.mod b/go.mod
index 9ed37c188..b8a04ca7e 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
+	github.com/alicebob/miniredis/v2 v2.35.0
 	github.com/onsi/gomega v1.38.2
 	github.com/openai/openai-go v1.12.0
 	github.com/prometheus/client_golang v1.23.2
 	github.com/stretchr/testify v1.11.1
 	golang.org/x/sync v0.18.0
 	google.golang.org/grpc v1.76.0
@@ -93,6 +94,7 @@ require (
 	github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
 	github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
 	github.com/x448/float16 v0.8.4 // indirect
+	github.com/yuin/gopher-lua v1.1.1 // indirect
 	go.opentelemetry.io/auto/sdk v1.1.0 // indirect
 	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
 	go.opentelemetry.io/otel v1.38.0 // indirect
diff --git a/pkg/batch/GUIDE.md b/pkg/batch/GUIDE.md
new file mode 100644
index 000000000..a8125f882
--- /dev/null
+++ b/pkg/batch/GUIDE.md
@@ -0,0 +1,7 @@
+# Batch Processor - User Guide
+
+The batch processor supports workflows whose requests are latency tolerant, i.e., with SLOs measured in minutes or hours rather than seconds.
+
+The batch processor pulls requests from a message queue (or from several queues, merged according to a policy), sends them to the Inference Gateway (IGW), and retries when necessary (e.g., when a request was shed under load).
+
+![Batch Processor - Redis architecture](/docs/images/batch_processor_redis_architecture.png "BP - Redis")
\ No newline at end of file
diff --git a/pkg/batch/README.md b/pkg/batch/README.md
new file mode 100644
index 000000000..19e21c0ac
--- /dev/null
+++ b/pkg/batch/README.md
@@ -0,0 +1,61 @@
+# Batch Processor
+
+## Overview
+The batch processor (BP) provides an asynchronous workflow for inference requests with relaxed, variable SLOs.
+
+## Architecture
+
+An underlying implementation should provide persistent messaging that adheres to the interface defined in [api.go](api.go), as illustrated by the sketch below.
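+
+For illustration only, a minimal hypothetical in-memory `Flow` might look like the following sketch (imports of `context`, `time`, and this `batch` package omitted); unlike real implementations, it provides no persistence, so messages are lost on restart:
+
+```go
+// inMemoryFlow is a toy, non-persistent Flow used only to illustrate the interface.
+type inMemoryFlow struct {
+	requests chan batch.RequestMessage
+	retries  chan batch.RetryMessage
+	results  chan batch.ResultMessage
+}
+
+// Start re-enqueues retried messages after their backoff delay.
+func (f *inMemoryFlow) Start(ctx context.Context) {
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case m := <-f.retries:
+				delay := time.Duration(m.BackoffDurationSeconds * float64(time.Second))
+				time.AfterFunc(delay, func() { f.requests <- m.RequestMessage })
+			}
+		}
+	}()
+}
+
+func (f *inMemoryFlow) RequestChannels() []batch.RequestChannel {
+	return []batch.RequestChannel{{Channel: f.requests, Metadata: map[string]any{}}}
+}
+func (f *inMemoryFlow) RetryChannel() chan batch.RetryMessage   { return f.retries }
+func (f *inMemoryFlow) ResultChannel() chan batch.ResultMessage { return f.results }
+```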
+
+A pluggable request merge policy merges multiple request channels into a single request channel on which the batch workers listen.
+
+An example of such a policy is the [Random Robin Policy](random_robin_policy.go).
+
+Each [Batch Processor worker](worker.go) is responsible for pulling requests from the merged request channel, submitting them to the IGW, and applying retry logic when needed.
+
+### Requests
+
+Request messages should have the following format:
+```json
+{
+  "id" : "unique identifier for result mapping",
+  "deadline" : "deadline in Unix seconds",
+  "payload" : {regular inference payload}
+}
+```
+
+Example:
+```json
+{
+  "id" : "19933123533434",
+  "deadline" : "1764045130",
+  "payload": {"model":"food-review","prompt":"hi", "max_tokens":10,"temperature":0}
+}
+```
+
+### Results
+
+Messages on the results channel have the following structure:
+
+```json
+{
+  "id" : "id mapped to the request",
+  "payload" : {/*inference payload*/}
+}
+```
+
+On failure, `payload` instead carries the failure reason, e.g. `{"error": "deadline exceeded"}`.
+
+## Implementations
+
+### Redis
+
+An example implementation based on Redis is provided which behaves as follows:
+
+- A Redis Pub/Sub channel serves as the request queue.
+- A Redis Sorted Set implements the retry exponential backoff.
+- A Redis Pub/Sub channel serves as the result queue.
diff --git a/pkg/batch/api.go b/pkg/batch/api.go
new file mode 100644
index 000000000..d6a34448a
--- /dev/null
+++ b/pkg/batch/api.go
@@ -0,0 +1,43 @@
+package batch
+
+import "context"
+
+// Flow abstracts the message-queue implementation that feeds the batch workers.
+type Flow interface {
+	// Start starts processing requests.
+	Start(ctx context.Context)
+
+	// RequestChannels returns the channels for requests. The implementation is responsible for populating them.
+	RequestChannels() []RequestChannel
+	// RetryChannel returns the channel that accepts messages to be retried, along with their backoff delay.
+	RetryChannel() chan RetryMessage
+	// ResultChannel returns the channel for storing the results.
+	ResultChannel() chan ResultMessage
+}
+
+// TODO: how to handle retries here?
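+//
+// Illustrative composition of a Flow, a merge policy, and workers (a sketch
+// mirroring cmd/batch/runner; the names are assumed from this package):
+//
+//	merged := policy.MergeRequestChannels(flow.RequestChannels())
+//	for i := 0; i < concurrency; i++ {
+//		go batch.Worker(ctx, endpoint, httpClient, merged.Channel,
+//			flow.RetryChannel(), flow.ResultChannel())
+//	}
+//	flow.Start(ctx)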
+// RequestMergePolicy merges multiple request channels into a single one.
+type RequestMergePolicy interface {
+	MergeRequestChannels(channels []RequestChannel) RequestChannel
+}
+
+// RequestMessage is a single inference request pulled from a message queue.
+type RequestMessage struct {
+	Id              string         `json:"id"`
+	RetryCount      int            `json:"retry_count,omitempty"`
+	DeadlineUnixSec string         `json:"deadline"`
+	Payload         map[string]any `json:"payload"`
+}
+
+// TODO: decide about metadata
+
+// RequestChannel couples a request channel with implementation-defined metadata.
+type RequestChannel struct {
+	Channel  chan RequestMessage
+	Metadata map[string]any
+}
+
+// RetryMessage is a request scheduled for another attempt after a backoff delay.
+type RetryMessage struct {
+	RequestMessage
+	BackoffDurationSeconds float64
+}
+
+// ResultMessage carries the inference result (or error) for a request.
+type ResultMessage struct {
+	Id      string         `json:"id"`
+	Payload map[string]any `json:"payload"`
+}
diff --git a/pkg/batch/random_robin_policy.go b/pkg/batch/random_robin_policy.go
new file mode 100644
index 000000000..3c7fd1b45
--- /dev/null
+++ b/pkg/batch/random_robin_policy.go
@@ -0,0 +1,47 @@
+package batch
+
+import "reflect"
+
+// NewRandomRobinPolicy returns a RequestMergePolicy that receives from all
+// channels with uniform random selection (via reflect.Select).
+func NewRandomRobinPolicy() RequestMergePolicy {
+	return &RandomRobinPolicy{}
+}
+
+// RandomRobinPolicy merges request channels by selecting randomly among the ready ones.
+type RandomRobinPolicy struct{}
+
+// MergeRequestChannels fans the given channels into a single merged channel.
+// Closed inputs are removed; the merged channel is closed once all inputs are closed.
+func (r *RandomRobinPolicy) MergeRequestChannels(channels []RequestChannel) RequestChannel {
+	mergedChannel := make(chan RequestMessage)
+
+	cases := make([]reflect.SelectCase, len(channels))
+	for i, ch := range channels {
+		cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch.Channel)}
+	}
+
+	go func() {
+		for {
+			chosen, val, ok := reflect.Select(cases)
+			if !ok {
+				// One of the channels is closed; remove it.
+				newCases := make([]reflect.SelectCase, 0, len(cases)-1)
+				for i, c := range cases {
+					if i != chosen {
+						newCases = append(newCases, c)
+					}
+				}
+				cases = newCases
+				if len(cases) == 0 {
+					close(mergedChannel)
+					break
+				}
+			} else {
+				mergedChannel <- val.Interface().(RequestMessage)
+			}
+		}
+	}()
+
+	return RequestChannel{
+		Channel:  mergedChannel,
+		Metadata: map[string]any{},
+	}
+}
diff --git a/pkg/batch/random_robin_policy_test.go b/pkg/batch/random_robin_policy_test.go
new file mode 100644
index 000000000..3964eac61
--- /dev/null
+++ b/pkg/batch/random_robin_policy_test.go
@@ -0,0 +1,41 @@
+package batch
+
+import (
+	"testing"
+)
+
+func TestProcessAllChannels(t *testing.T) {
+	msgsPerChannel := 5
+	channels := []RequestChannel{
+		{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
+		{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
+		{Channel: make(chan RequestMessage, msgsPerChannel), Metadata: map[string]any{}},
+	}
+	policy := NewRandomRobinPolicy()
+
+	// Send messages to each channel
+	for i, ch := range channels {
+		for range msgsPerChannel {
+			ch.Channel <- RequestMessage{Id: string(rune('A' + i))}
+		}
+	}
+	mergedChannel := policy.MergeRequestChannels(channels).Channel
+	close(channels[0].Channel)
+	close(channels[1].Channel)
+	close(channels[2].Channel)
+
+	counts := map[string]int{}
+	totalMessages := msgsPerChannel * 3
+	for range totalMessages {
+		msg := <-mergedChannel
+		counts[msg.Id]++
+	}
+
+	for i := range 3 {
+		id := string(rune('A' + i))
+		if counts[id] != msgsPerChannel {
+			t.Errorf("Expected %d messages from channel %s, got %d", msgsPerChannel, id, counts[id])
+		}
+	}
+}
diff --git a/pkg/batch/redis/redisimpl.go b/pkg/batch/redis/redisimpl.go
new file mode 100644
index 000000000..5e3829e57
--- /dev/null
+++ b/pkg/batch/redis/redisimpl.go
@@ -0,0 +1,204 @@
+package redis
+
+import (
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/batch"
+	"github.com/redis/go-redis/v9"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +var ( + redisAddr = flag.String("redis.addr", "localhost:6379", "address of the Redis server") + + // TODO: support multiple request queues with metadata (for policy) + requestQueueName = flag.String("redis.request-queue-name", "batch-queue", "name of the Redis queue for request messages") + retryQueueName = flag.String("redis.retry-queue-name", "batch-sortedset-retry", "name of the Redis sorted set for retry messages") + resultQueueName = flag.String("redis.result-queue-name", "batch-queue-result", "name of the Redis queue for result messages") +) + +// TODO: think about what to do if Redis is down +type RedisMQFlow struct { + rdb *redis.Client + requestChannel chan batch.RequestMessage + retryChannel chan batch.RetryMessage + resultChannel chan batch.ResultMessage +} + +func NewRedisMQFlow() *RedisMQFlow { + rdb := redis.NewClient(&redis.Options{ + Addr: *redisAddr, + + // TODO: check specific version of go-redis. might require higher version. + // Explicitly disable maintenance notifications + // This prevents the client from sending CLIENT MAINT_NOTIFICATIONS ON + // import "github.com/redis/go-redis/v9/maintnotifications" + // MaintNotificationsConfig: &maintnotifications.Config{ + // Mode: maintnotifications.ModeDisabled, + // }, + }) + return &RedisMQFlow{ + rdb: rdb, + requestChannel: make(chan batch.RequestMessage), + retryChannel: make(chan batch.RetryMessage), + resultChannel: make(chan batch.ResultMessage), + } +} + +func (r *RedisMQFlow) Start(ctx context.Context) { + go requestWorker(ctx, r.rdb, r.requestChannel, *requestQueueName) + + go addMsgToRetryWorker(ctx, r.rdb, r.retryChannel, *retryQueueName) + + go retryWorker(ctx, r.rdb, r.requestChannel) + + go resultWorker(ctx, r.rdb, r.resultChannel, *resultQueueName) +} +func (r *RedisMQFlow) RequestChannels() []batch.RequestChannel { + return []batch.RequestChannel{{Channel: r.requestChannel, Metadata: map[string]any{}}} +} + +func (r *RedisMQFlow) RetryChannel() chan batch.RetryMessage { + return r.retryChannel +} + +func (r *RedisMQFlow) ResultChannel() chan batch.ResultMessage { + return r.resultChannel +} + +// Listening on the results channel and responsible for writing results into Redis. +func resultWorker(ctx context.Context, rdb *redis.Client, resultChannel chan batch.ResultMessage, resultsQueueName string) { + logger := log.FromContext(ctx) + for { + select { + case <-ctx.Done(): + return + + case msg := <-resultChannel: + bytes, err := json.Marshal(msg) + var msgStr string + if err != nil { + msgStr = fmt.Sprintf(`{"id" : "%s", "error": "%s"}`, msg.Id, "Failed to marshal result to string") + } else { + msgStr = string(bytes) + } + err = publishRedis(ctx, rdb, resultsQueueName, msgStr) + if err != nil { + // Not going to retry here. Just log the error. + logger.V(logutil.DEFAULT).Error(err, "Failed to publish result message to Redis") + } + } + } +} + +// pulls from Redis Queue and put in the request channel +func requestWorker(ctx context.Context, rdb *redis.Client, msgChannel chan batch.RequestMessage, queueName string) { + sub := rdb.Subscribe(ctx, queueName) + defer sub.Close() + + // redis.WithChannelSize(100) -- TODO: consider exposing to config + ch := sub.Channel() + for { + select { + case <-ctx.Done(): + return + + case rmsg := <-ch: + var msg batch.RequestMessage + + err := json.Unmarshal([]byte(rmsg.Payload), &msg) + if err != nil { + // TODO: log failed to unmarshal message. 
+
+// Subscribes to the Redis request channel and forwards incoming messages to the request channel.
+func requestWorker(ctx context.Context, rdb *redis.Client, msgChannel chan batch.RequestMessage, queueName string) {
+	logger := log.FromContext(ctx)
+	sub := rdb.Subscribe(ctx, queueName)
+	defer sub.Close()
+
+	// redis.WithChannelSize(100) -- TODO: consider exposing to config
+	ch := sub.Channel()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+
+		case rmsg := <-ch:
+			var msg batch.RequestMessage
+			err := json.Unmarshal([]byte(rmsg.Payload), &msg)
+			if err != nil {
+				logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal request message")
+				continue // skip this message
+			}
+			msgChannel <- msg
+		}
+	}
+}
+
+// Puts messages from the retry channel into a Redis sorted set, scored by their retry time.
+func addMsgToRetryWorker(ctx context.Context, rdb *redis.Client, retryChannel chan batch.RetryMessage, sortedSetName string) {
+	logger := log.FromContext(ctx)
+	for {
+		select {
+		case <-ctx.Done():
+			return
+
+		case msg := <-retryChannel:
+			score := float64(time.Now().Unix()) + msg.BackoffDurationSeconds
+			bytes, err := json.Marshal(msg.RequestMessage)
+			if err != nil {
+				logger.V(logutil.DEFAULT).Error(err, "Failed to marshal message for retry in Redis")
+				continue // skip this message
+			}
+			err = rdb.ZAdd(ctx, sortedSetName, redis.Z{
+				Score:  score,
+				Member: string(bytes),
+			}).Err()
+			if err != nil {
+				// TODO: decide whether to re-enqueue on ZAdd failure.
+				logger.V(logutil.DEFAULT).Error(err, "Failed to add message for retry in Redis")
+			}
+		}
+	}
+}
+
+// Polls the sorted set every second and forwards messages whose retry time has
+// arrived back to the request channel.
+func retryWorker(ctx context.Context, rdb *redis.Client, msgChannel chan batch.RequestMessage) {
+	logger := log.FromContext(ctx)
+	for {
+		select {
+		case <-ctx.Done():
+			return
+
+		default:
+			currentTimeSec := float64(time.Now().Unix())
+			results, err := rdb.ZRangeByScore(ctx, *retryQueueName, &redis.ZRangeBy{
+				Min: "0",
+				Max: strconv.FormatFloat(currentTimeSec, 'f', -1, 64),
+			}).Result()
+			if err != nil {
+				logger.V(logutil.DEFAULT).Error(err, "Failed to read the retry sorted set")
+				time.Sleep(time.Second)
+				continue
+			}
+			for _, msg := range results {
+				var message batch.RequestMessage
+				err := json.Unmarshal([]byte(msg), &message)
+				if err != nil {
+					// Drop malformed members so they are not re-read forever.
+					logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal retry message; dropping it")
+					if err := rdb.ZRem(ctx, *retryQueueName, msg).Err(); err != nil {
+						logger.V(logutil.DEFAULT).Error(err, "Failed to remove malformed retry message")
+					}
+					continue
+				}
+				err = rdb.ZRem(ctx, *retryQueueName, msg).Err()
+				if err != nil {
+					logger.V(logutil.DEFAULT).Error(err, "Failed to remove retry message from the sorted set")
+				}
+				// TODO: Publish to request channel or directly to request queue in Redis?
+				msgChannel <- message
+			}
+			time.Sleep(time.Second)
+		}
+	}
+}
+
+func publishRedis(ctx context.Context, rdb *redis.Client, channelId, msg string) error {
+	logger := log.FromContext(ctx)
+	err := rdb.Publish(ctx, channelId, msg).Err()
+	if err != nil {
+		logger.V(logutil.DEFAULT).Error(err, "Error publishing message")
+		return err
+	}
+	return nil
+}
diff --git a/pkg/batch/worker.go b/pkg/batch/worker.go
new file mode 100644
index 000000000..7f1b25e02
--- /dev/null
+++ b/pkg/batch/worker.go
@@ -0,0 +1,154 @@
+package batch
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"math"
+	"math/rand"
+	"net/http"
+	"strconv"
+	"time"
+
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/metrics"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+var baseDelaySeconds = 2
+
+// Worker pulls requests from requestChannel, submits them to the inference
+// endpoint, and routes outcomes to the retry or result channels.
+func Worker(ctx context.Context, endpoint string, httpClient *http.Client, requestChannel chan RequestMessage,
+	retryChannel chan RetryMessage, resultChannel chan ResultMessage) {
+
+	logger := log.FromContext(ctx)
+	for {
+		select {
+		case <-ctx.Done():
+			logger.V(logutil.DEFAULT).Info("Worker finishing.")
+			return
+		case msg := <-requestChannel:
+			if msg.RetryCount == 0 {
+				// Only count the first attempt as a new request.
+				metrics.BatchReqs.Inc()
+			}
+			payloadBytes := parseAndValidateRequest(resultChannel, msg)
+			if payloadBytes == nil {
+				continue
+			}
+
+			sendInferenceRequest := func() {
+				logger.V(logutil.DEBUG).Info("Sending inference request.")
+				request, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(payloadBytes))
+				if err != nil {
+					metrics.FailedReqs.Inc()
+					resultChannel <- CreateErrorResultMessage(msg.Id, fmt.Sprintf("Failed to create request to inference: %s", err.Error()))
+					return
+				}
+				request.Header.Set("Content-Type", "application/json")
+				// TODO: the inference objective is hardcoded; make it configurable.
+				request.Header.Set("x-gateway-inference-objective", "food-review-2")
+
+				result, err := httpClient.Do(request)
+				if err != nil {
+					metrics.FailedReqs.Inc()
+					resultChannel <- CreateErrorResultMessage(msg.Id, fmt.Sprintf("Failed to send request to inference: %s", err.Error()))
+					return
+				}
+				defer result.Body.Close()
+				// Retrying on too many requests or any server-side error.
+				if result.StatusCode == 429 || (result.StatusCode >= 500 && result.StatusCode < 600) {
+					if result.StatusCode == 429 {
+						metrics.SheddedRequests.Inc()
+					}
+					retryMessage(msg, retryChannel, resultChannel)
+				} else {
+					payloadBytes, err := io.ReadAll(result.Body)
+					if err != nil {
+						// Retrying on IO-read error as well.
+						retryMessage(msg, retryChannel, resultChannel)
+					} else {
+						var resultPayload map[string]any
+						err := json.Unmarshal(payloadBytes, &resultPayload)
+						if err != nil {
+							// Not retrying on unmarshalling error.
+							metrics.FailedReqs.Inc()
+							resultChannel <- CreateErrorResultMessage(msg.Id, fmt.Sprintf("Failed to unmarshal inference result payload: %v", err))
+							return
+						}
+						metrics.SuccessfulReqs.Inc()
+						resultChannel <- ResultMessage{
+							Id:      msg.Id,
+							Payload: resultPayload,
+						}
+					}
+				}
+			}
+			sendInferenceRequest()
+		}
+	}
+}
+
+// parseAndValidateRequest checks the deadline and marshals the payload,
+// reporting failures on the result channel. It returns nil if the message
+// should be skipped.
+func parseAndValidateRequest(resultChannel chan ResultMessage, msg RequestMessage) []byte {
+	deadline, err := strconv.ParseInt(msg.DeadlineUnixSec, 10, 64)
+	if err != nil {
+		metrics.FailedReqs.Inc()
+		resultChannel <- CreateErrorResultMessage(msg.Id, "Failed to parse deadline, should be in Unix seconds.")
+		return nil
+	}
+
+	if deadline < time.Now().Unix() {
+		metrics.ExceededDeadlineReqs.Inc()
+		resultChannel <- CreateDeadlineExceededResultMessage(msg.Id)
+		return nil
+	}
+
+	payloadBytes, err := json.Marshal(msg.Payload)
+	if err != nil {
+		metrics.FailedReqs.Inc()
+		resultChannel <- CreateErrorResultMessage(msg.Id, fmt.Sprintf("Failed to marshal message's payload: %s", err.Error()))
+		return nil
+	}
+	return payloadBytes
+}
+
+// If the deadline has not passed, schedules the message for another attempt.
+func retryMessage(msg RequestMessage, retryChannel chan RetryMessage, resultChannel chan ResultMessage) {
+	deadline, err := strconv.ParseInt(msg.DeadlineUnixSec, 10, 64)
+	if err != nil { // Should not happen: the deadline was already parsed earlier; kept as a safety net.
+		resultChannel <- CreateErrorResultMessage(msg.Id, "Failed to parse deadline. Should be in Unix time")
Should be in Unix time") + return + } + secondsToDeadline := deadline - time.Now().Unix() + if secondsToDeadline < 0 { + metrics.ExceededDeadlineReqs.Inc() + resultChannel <- CreateDeadlineExceededResultMessage(msg.Id) + } else { + msg.RetryCount++ + backoffDurationSeconds := math.Min( + float64(baseDelaySeconds)*(math.Pow(2, float64(msg.RetryCount))), + float64(secondsToDeadline)) + + jitter := rand.Float64() - 0.5 + finalDuration := backoffDurationSeconds + jitter + if finalDuration < 0 { + finalDuration = 0 + } + metrics.Retries.Inc() + retryChannel <- RetryMessage{ + RequestMessage: msg, + BackoffDurationSeconds: finalDuration, + } + + } + +} +func CreateErrorResultMessage(id string, errMsg string) ResultMessage { + return ResultMessage{ + Id: id, + Payload: map[string]any{"error": errMsg}, + } +} + +func CreateDeadlineExceededResultMessage(id string) ResultMessage { + return CreateErrorResultMessage(id, "deadline exceeded") +} diff --git a/pkg/batch/worker_test.go b/pkg/batch/worker_test.go new file mode 100644 index 000000000..fd6769bae --- /dev/null +++ b/pkg/batch/worker_test.go @@ -0,0 +1,144 @@ +package batch + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" +) + +func TestRetryMessage_deadlinePassed(t *testing.T) { + retryChannel := make(chan RetryMessage, 1) + resultChannel := make(chan ResultMessage, 1) + msg := RequestMessage{ + Id: "123", + RetryCount: 0, + DeadlineUnixSec: fmt.Sprintf("%d", time.Now().Add(time.Second*-10).Unix()), + } + retryMessage(msg, retryChannel, resultChannel) + if len(retryChannel) > 0 { + t.Errorf("Message that its deadline passed should not be retried. Got a message in the retry channel") + return + } + if len(resultChannel) != 1 { + t.Errorf("Expected one message in the result channel") + return + + } + result := <-resultChannel + if result.Payload["error"] != "deadline exceeded" { + t.Errorf("Expected error to be: 'deadline exceeded', got: %s", result.Payload["error"]) + } + +} + +func TestRetryMessage_retry(t *testing.T) { + retryChannel := make(chan RetryMessage, 1) + resultChannel := make(chan ResultMessage, 1) + msg := RequestMessage{ + Id: "123", + RetryCount: 0, + DeadlineUnixSec: fmt.Sprintf("%d", time.Now().Add(time.Second*10).Unix()), + } + retryMessage(msg, retryChannel, resultChannel) + if len(resultChannel) > 0 { + t.Errorf("Should not have any messages in the result channel") + return + } + if len(retryChannel) != 1 { + t.Errorf("Expected one message in the retry channel") + return + } + retryMsg := <-retryChannel + if retryMsg.RetryCount != 1 { + t.Errorf("Expected retry count to be 1, got %d", msg.RetryCount) + } + +} + +// RoundTripFunc is a type that implements http.RoundTripper +type RoundTripFunc func(req *http.Request) (*http.Response, error) + +// RoundTrip executes a single HTTP transaction, obtaining the Response for a given Request. +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// NewTestClient returns an *http.Client with its Transport replaced by a custom RoundTripper. 
+func NewTestClient(fn RoundTripFunc) *http.Client {
+	return &http.Client{
+		Transport: RoundTripFunc(fn),
+	}
+}
+
+func TestSheddedRequest(t *testing.T) {
+	msgId := "123"
+	httpclient := NewTestClient(func(req *http.Request) (*http.Response, error) {
+		return &http.Response{
+			StatusCode: http.StatusInternalServerError,
+			// A non-nil body is required: the worker always closes (and may read) it.
+			Body:   io.NopCloser(strings.NewReader("")),
+			Header: make(http.Header),
+		}, nil
+	})
+	requestChannel := make(chan RequestMessage, 1)
+	retryChannel := make(chan RetryMessage, 1)
+	resultChannel := make(chan ResultMessage, 1)
+	ctx := context.Background()
+
+	go Worker(ctx, "http://localhost:30080/v1/completions", httpclient, requestChannel, retryChannel, resultChannel)
+	deadline := time.Now().Add(time.Second * 100).Unix()
+
+	requestChannel <- RequestMessage{
+		Id:              msgId,
+		RetryCount:      0,
+		DeadlineUnixSec: fmt.Sprintf("%d", deadline),
+		Payload:         map[string]any{"model": "food-review", "prompt": "hi", "max_tokens": 10, "temperature": 0},
+	}
+
+	select {
+	case r := <-retryChannel:
+		if r.Id != msgId {
+			t.Errorf("Expected retry message id to be %s, got %s", msgId, r.Id)
+		}
+	case <-resultChannel:
+		t.Errorf("Should not get result from a 5xx response")
+	}
+}
+
+func TestSuccessfulRequest(t *testing.T) {
+	msgId := "123"
+	httpclient := NewTestClient(func(req *http.Request) (*http.Response, error) {
+		return &http.Response{
+			StatusCode: http.StatusOK,
+			// The worker reads and unmarshals the body on success, so it must be valid JSON.
+			Body:   io.NopCloser(strings.NewReader(`{"choices":[]}`)),
+			Header: make(http.Header),
+		}, nil
+	})
+	requestChannel := make(chan RequestMessage, 1)
+	retryChannel := make(chan RetryMessage, 1)
+	resultChannel := make(chan ResultMessage, 1)
+	ctx := context.Background()
+
+	go Worker(ctx, "http://localhost:30080/v1/completions", httpclient, requestChannel, retryChannel, resultChannel)
+
+	deadline := time.Now().Add(time.Second * 100).Unix()
+
+	requestChannel <- RequestMessage{
+		Id:              msgId,
+		RetryCount:      0,
+		DeadlineUnixSec: fmt.Sprintf("%d", deadline),
+		Payload:         map[string]any{"model": "food-review", "prompt": "hi", "max_tokens": 10, "temperature": 0},
+	}
+
+	select {
+	case <-retryChannel:
+		t.Errorf("Should not get a retry from a 200 response")
+	case r := <-resultChannel:
+		if r.Id != msgId {
+			t.Errorf("Expected result message id to be %s, got %s", msgId, r.Id)
+		}
+	}
+}
diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go
index 03aadebdb..f300678ff 100644
--- a/pkg/metrics/metrics.go
+++ b/pkg/metrics/metrics.go
@@ -27,15 +27,47 @@ var (
 		},
 		[]string{"decision_type"}, // "decode-only" or "prefill-decode"
 	)
+	Retries = prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: SchedulerSubsystem,
+		Name:      "batch_request_retries_total",
+		Help:      "Total number of batch request retries.",
+	})
+	BatchReqs = prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: SchedulerSubsystem,
+		Name:      "batch_request_total",
+		Help:      "Total number of batch requests.",
+	})
+	ExceededDeadlineReqs = prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: SchedulerSubsystem,
+		Name:      "batch_exceeded_deadline_requests_total",
+		Help:      "Total number of batch requests that exceeded their deadline.",
+	})
+	FailedReqs = prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: SchedulerSubsystem,
+		Name:      "batch_failed_requests_total",
+		Help:      "Total number of batch requests that failed.",
+	})
+	SuccessfulReqs = prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: SchedulerSubsystem,
+		Name:      "batch_successful_requests_total",
+		Help:      "Total number of batch requests that succeeded.",
+	})
+	SheddedRequests = prometheus.NewCounter(prometheus.CounterOpts{
+		Subsystem: SchedulerSubsystem,
Name: "batch_shedded_requests_total", + Help: "Total number of batch requests that were shedded.", + }) ) // GetCollectors returns all custom collectors for the llm-d-inference-scheduler. -func GetCollectors() []prometheus.Collector { +func GetEPPCollectors() []prometheus.Collector { return []prometheus.Collector{ SchedulerPDDecisionCount, } } +// GetCollectors returns all custom collectors for the batch processor. +func GetBatchCollectors() []prometheus.Collector { + return []prometheus.Collector{ + Retries, BatchReqs, ExceededDeadlineReqs, FailedReqs, SuccessfulReqs, SheddedRequests, + } +} + // RecordPDDecision records the type of P/D disaggregation decision made. func RecordPDDecision(decisionType string) { SchedulerPDDecisionCount.WithLabelValues(decisionType).Inc() diff --git a/scripts/kind-dev-env.sh b/scripts/kind-dev-env.sh index 204045b18..4a4e23376 100755 --- a/scripts/kind-dev-env.sh +++ b/scripts/kind-dev-env.sh @@ -36,6 +36,13 @@ export EPP_TAG="${EPP_TAG:-dev}" EPP_IMAGE="${EPP_IMAGE:-${IMAGE_REGISTRY}/llm-d-inference-scheduler:${EPP_TAG}}" export EPP_IMAGE +# Set a default BATCH_TAG if not provided +export BATCH_TAG="${BATCH_TAG:-dev}" + +# Set a default BATCH_IMAGE if not provided +BATCH_IMAGE="${BATCH_IMAGE:-${IMAGE_REGISTRY}/llm-d-inference-scheduler:${BATCH_TAG}}" +export BATCH_IMAGE + # Set the model name to deploy export MODEL_NAME="${MODEL_NAME:-food-review}" # Extract model family (e.g., "meta-llama" from "meta-llama/Llama-3.1-8B-Instruct") @@ -48,6 +55,9 @@ export MODEL_NAME_SAFE=$(echo "${MODEL_ID}" | tr '[:upper:]' '[:lower:]' | tr ' # Set the endpoint-picker to deploy export EPP_NAME="${EPP_NAME:-${MODEL_NAME_SAFE}-endpoint-picker}" +# Set the batch to deploy +export BATCH_NAME="${BATCH_NAME:-${MODEL_NAME_SAFE}-batch}" + # Set the default routing side car image tag export SIDECAR_TAG="${SIDECAR_TAG:-dev}" @@ -203,6 +213,13 @@ else kind --name ${CLUSTER_NAME} load docker-image ${SIDECAR_IMAGE} fi +# Load the batch image into the cluster +if [ "${CONTAINER_RUNTIME}" == "podman" ]; then + podman save ${BATCH_IMAGE} -o /dev/stdout | kind --name ${CLUSTER_NAME} load image-archive /dev/stdin +else + kind --name ${CLUSTER_NAME} load docker-image ${BATCH_IMAGE} +fi + # ------------------------------------------------------------------------------ # CRD Deployment (Gateway API + GIE) # ------------------------------------------------------------------------------ @@ -236,7 +253,7 @@ envsubst '$PRIMARY_PORT' < ${EPP_CONFIG} > ${TEMP_FILE} kubectl --context ${KUBE_CONTEXT} create configmap epp-config --from-file=epp-config.yaml=${TEMP_FILE} kustomize build --enable-helm ${KUSTOMIZE_DIR} \ - | envsubst '${POOL_NAME} ${MODEL_NAME} ${MODEL_NAME_SAFE} ${EPP_NAME} ${EPP_IMAGE} ${VLLM_SIMULATOR_IMAGE} \ + | envsubst '${POOL_NAME} ${MODEL_NAME} ${MODEL_NAME_SAFE} ${EPP_NAME} ${EPP_IMAGE} ${BATCH_NAME} ${BATCH_IMAGE} ${VLLM_SIMULATOR_IMAGE} \ ${PD_ENABLED} ${KV_CACHE_ENABLED} ${SIDECAR_IMAGE} ${TARGET_PORTS} \ ${VLLM_REPLICA_COUNT} ${VLLM_REPLICA_COUNT_P} ${VLLM_REPLICA_COUNT_D} ${VLLM_DATA_PARALLEL_SIZE}' \ | kubectl --context ${KUBE_CONTEXT} apply -f - diff --git a/test/integration/redisimpl_test.go b/test/integration/redisimpl_test.go new file mode 100644 index 000000000..582eba92a --- /dev/null +++ b/test/integration/redisimpl_test.go @@ -0,0 +1,58 @@ +package integration_test + +import ( + "context" + "flag" + "strconv" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/llm-d/llm-d-inference-scheduler/pkg/batch" + 
"github.com/llm-d/llm-d-inference-scheduler/pkg/batch/redis" +) + +func TestRedisImpl(t *testing.T) { + s := miniredis.RunT(t) + rAddr := s.Host() + ":" + s.Port() + + ctx := context.Background() + flag.Set("redis.addr", rAddr) + flow := redis.NewRedisMQFlow() + flow.Start(ctx) + + flow.RetryChannel() <- batch.RetryMessage{ + RequestMessage: batch.RequestMessage{ + Id: "test-id", + DeadlineUnixSec: strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10), + Payload: map[string]any{"model": "food-review", "prompt": "hi", "max_tokens": 10, "temperature": 0}, + }, + BackoffDurationSeconds: 2, + } + totalReqCount := 0 + for _, value := range flow.RequestChannels() { + totalReqCount += len(value.Channel) + } + + if totalReqCount > 0 { + t.Errorf("Expected no messages in request channels yet") + return + } + if len(flow.ResultChannel()) > 0 { + t.Errorf("Expected no messages in result channel yet") + return + } + time.Sleep(3 * time.Second) + + mergedChannel := batch.NewRandomRobinPolicy().MergeRequestChannels(flow.RequestChannels()) + + select { + case req := <-mergedChannel.Channel: + if req.Id != "test-id" { + t.Errorf("Expected message id to be test-id, got %s", req.Id) + } + case <-time.After(2 * time.Second): + t.Errorf("Expected message in request channel after backoff") + } + +}