diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1af3a13b4..77c8242ba 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -11,13 +11,13 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955

      - name: Verify vendor/ is not present
        run: stat vendor && exit 1 || exit 0

      - name: Set up Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5
        with:
          go-version: 1.24.2
          cache: true
diff --git a/.github/workflows/cli-build.yml b/.github/workflows/cli-build.yml
index 26ca41738..0bc8b01d3 100644
--- a/.github/workflows/cli-build.yml
+++ b/.github/workflows/cli-build.yml
@@ -25,8 +25,8 @@ jobs:
       id-token: write
       contents: read
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
+      - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5
        with:
          go-version-file: cmd/cli/go.mod
          cache: true
@@ -35,7 +35,7 @@ jobs:
         working-directory: cmd/cli
         run: |
           make release VERSION=${{ github.sha }}
-      - uses: actions/upload-artifact@v4
+      - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02
        with:
          name: dist
          path: |
diff --git a/.github/workflows/cli-validate.yml b/.github/workflows/cli-validate.yml
index f2098b35e..717afb550 100644
--- a/.github/workflows/cli-validate.yml
+++ b/.github/workflows/cli-validate.yml
@@ -31,11 +31,11 @@ jobs:
     steps:
       -
         name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
      -
        name: List targets
        id: generate
-        uses: docker/bake-action/subaction/list-targets@v6
+        uses: docker/bake-action/subaction/list-targets@3acf805d94d93a86cce4ca44798a76464a75b88c
        with:
          files: ./cmd/cli/docker-bake.hcl
          target: validate
@@ -51,7 +51,7 @@ jobs:
     steps:
       -
         name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955
      -
        name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
diff --git a/.github/workflows/dmr-daily-check.yml b/.github/workflows/dmr-daily-check.yml
index e0541251e..4345ee0d7 100644
--- a/.github/workflows/dmr-daily-check.yml
+++ b/.github/workflows/dmr-daily-check.yml
@@ -22,7 +22,7 @@ jobs:
     steps:
       - name: Set up Docker
-        uses: docker/setup-docker-action@v4
+        uses: docker/setup-docker-action@3fb92d6d9c634363128c8cce4bc3b2826526370a

      - name: Install docker-model-plugin
        run: |
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 3bf042e6a..4111703cd 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -19,16 +19,21 @@ on:
         required: false
         type: string
         default: "latest"
+      vllmVersion:
+        description: 'vLLM version'
+        required: false
+        type: string
+        default: "0.11.0"

 jobs:
   test:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955

      - name: Set up Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5
        with:
          go-version: 1.24.2
          cache: true
@@ -41,7 +46,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repo
-        uses: actions/checkout@v4
+        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955

      - name: Format tags
        id: tags
@@ -59,15 +64,21 @@ jobs:
             echo "docker/model-runner:latest-cuda" >> "$GITHUB_OUTPUT"
           fi
           echo 'EOF' >> "$GITHUB_OUTPUT"
+          echo "vllm-cuda<<EOF" >> "$GITHUB_OUTPUT"
+          echo "docker/model-runner:${{ inputs.releaseTag }}-vllm-cuda" >> "$GITHUB_OUTPUT"
+          if [ "${{ inputs.pushLatest }}" == "true" ]; then
+            echo "docker/model-runner:latest-vllm-cuda" >> "$GITHUB_OUTPUT"
+          fi
+          echo 'EOF' >> "$GITHUB_OUTPUT"

      - name: Log in to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
        with:
          username: "docker"
          password: ${{ secrets.ORG_ACCESS_TOKEN }}

      - name: Set up Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435
        with:
          version: "lab:latest"
          driver: cloud
@@ -75,9 +86,10 @@ jobs:
           install: true

      - name: Build CPU image
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25
        with:
          file: Dockerfile
+          target: final-llamacpp
          platforms: linux/amd64, linux/arm64
          build-args: |
            "LLAMA_SERVER_VERSION=${{ inputs.llamaServerVersion }}"
@@ -87,9 +99,10 @@ jobs:
           tags: ${{ steps.tags.outputs.cpu }}

      - name: Build CUDA image
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25
        with:
          file: Dockerfile
+          target: final-llamacpp
          platforms: linux/amd64, linux/arm64
          build-args: |
            "LLAMA_SERVER_VERSION=${{ inputs.llamaServerVersion }}"
@@ -99,3 +112,19 @@ jobs:
           sbom: true
           provenance: mode=max
           tags: ${{ steps.tags.outputs.cuda }}
+
+      - name: Build vLLM CUDA image
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25
+        with:
+          file: Dockerfile
+          target: final-vllm
+          platforms: linux/amd64
+          build-args: |
+            "LLAMA_SERVER_VERSION=${{ inputs.llamaServerVersion }}"
+            "LLAMA_SERVER_VARIANT=cuda"
+            "BASE_IMAGE=nvidia/cuda:12.9.0-runtime-ubuntu24.04"
+            "VLLM_VERSION=${{ inputs.vllmVersion }}"
+          push: true
+          sbom: true
+          provenance: mode=max
+          tags: ${{ steps.tags.outputs.vllm-cuda }}
diff --git a/Dockerfile b/Dockerfile
index c4378eb9c..e71412318 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -35,7 +35,7 @@ RUN --mount=type=cache,target=/go/pkg/mod \
 FROM docker/docker-model-backend-llamacpp:${LLAMA_SERVER_VERSION}-${LLAMA_SERVER_VARIANT} AS llama-server

 # --- Final image ---
-FROM docker.io/${BASE_IMAGE} AS final
+FROM docker.io/${BASE_IMAGE} AS llamacpp

 ARG LLAMA_SERVER_VARIANT

@@ -55,9 +55,6 @@ RUN mkdir -p /var/run/model-runner /app/bin /models && \
     chown -R modelrunner:modelrunner /var/run/model-runner /app /models && \
     chmod -R 755 /models

-# Copy the built binary from builder
-COPY --from=builder /app/model-runner /app/model-runner
-
 # Copy the llama.cpp binary from the llama-server stage
 ARG LLAMA_BINARY_PATH
 COPY --from=llama-server ${LLAMA_BINARY_PATH}/ /app/.
@@ -77,3 +74,31 @@ ENV LD_LIBRARY_PATH=/app/lib
 LABEL com.docker.desktop.service="model-runner"

 ENTRYPOINT ["/app/model-runner"]
+
+# --- vLLM variant ---
+FROM llamacpp AS vllm
+
+ARG VLLM_VERSION
+
+USER root
+
+RUN apt update && apt install -y python3 python3-venv python3-dev curl ca-certificates build-essential && rm -rf /var/lib/apt/lists/*
+
+RUN mkdir -p /opt/vllm-env && chown -R modelrunner:modelrunner /opt/vllm-env
+
+USER modelrunner
+
+# Install uv and vLLM as modelrunner user
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
+    && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/vllm-env \
+    && ~/.local/bin/uv pip install --python /opt/vllm-env/bin/python "vllm==${VLLM_VERSION}"
+
+RUN /opt/vllm-env/bin/python -c "import vllm; print(vllm.__version__)" > /opt/vllm-env/version
+
+FROM llamacpp AS final-llamacpp
+# Copy the built binary from builder
+COPY --from=builder /app/model-runner /app/model-runner
+
+FROM vllm AS final-vllm
+# Copy the built binary from builder
+COPY --from=builder /app/model-runner /app/model-runner
diff --git a/cmd/cli/commands/backend.go b/cmd/cli/commands/backend.go
index c13a74f93..a6cbeb5ae 100644
--- a/cmd/cli/commands/backend.go
+++ b/cmd/cli/commands/backend.go
@@ -13,6 +13,7 @@ import (
 var ValidBackends = map[string]bool{
     "llama.cpp": true,
     "openai":    true,
+    "vllm":      true,
 }

 // validateBackend checks if the provided backend is valid
diff --git a/cmd/cli/docs/reference/docker_model_list.yaml b/cmd/cli/docs/reference/docker_model_list.yaml
index fe6951682..778cd4de5 100644
--- a/cmd/cli/docs/reference/docker_model_list.yaml
+++ b/cmd/cli/docs/reference/docker_model_list.yaml
@@ -8,7 +8,7 @@ plink: docker_model.yaml
 options:
     - option: backend
       value_type: string
-      description: Specify the backend to use (llama.cpp, openai)
+      description: Specify the backend to use (llama.cpp, openai, vllm)
      deprecated: false
      hidden: true
      experimental: false
diff --git a/cmd/cli/docs/reference/docker_model_run.yaml b/cmd/cli/docs/reference/docker_model_run.yaml
index 6edbcbb56..13ad91ff3 100644
--- a/cmd/cli/docs/reference/docker_model_run.yaml
+++ b/cmd/cli/docs/reference/docker_model_run.yaml
@@ -12,7 +12,7 @@ plink: docker_model.yaml
 options:
     - option: backend
       value_type: string
-      description: Specify the backend to use (llama.cpp, openai)
+      description: Specify the backend to use (llama.cpp, openai, vllm)
      deprecated: false
      hidden: true
      experimental: false
diff --git a/main.go b/main.go
index 1c64b8681..b373c296a 100644
--- a/main.go
+++ b/main.go
@@ -14,6 +14,7 @@ import (
     "github.com/docker/model-runner/pkg/gpuinfo"
     "github.com/docker/model-runner/pkg/inference"
     "github.com/docker/model-runner/pkg/inference/backends/llamacpp"
+    "github.com/docker/model-runner/pkg/inference/backends/vllm"
     "github.com/docker/model-runner/pkg/inference/config"
     "github.com/docker/model-runner/pkg/inference/memory"
     "github.com/docker/model-runner/pkg/inference/models"
@@ -119,9 +120,19 @@ func main() {

     memEstimator.SetDefaultBackend(llamaCppBackend)

+    vllmBackend, err := vllm.New(
+        log,
+        modelManager,
+        log.WithFields(logrus.Fields{"component": "vllm"}),
+        nil,
+    )
+    if err != nil {
+        log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err)
+    }
+
     scheduler := scheduling.NewScheduler(
         log,
-        map[string]inference.Backend{llamacpp.Name: llamaCppBackend},
+        map[string]inference.Backend{llamacpp.Name: llamaCppBackend, vllm.Name: vllmBackend},
        llamaCppBackend,
        modelManager,
        http.DefaultClient,
diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go
index 8b51c9010..eb987e54d 100644
--- a/pkg/distribution/distribution/client.go
+++ b/pkg/distribution/distribution/client.go
@@ -6,8 +6,9 @@ import (
     "fmt"
     "io"
     "net/http"
+    "slices"

-    "github.com/docker/model-runner/pkg/distribution/internal/utils"
+    "github.com/docker/model-runner/pkg/internal/utils"
     "github.com/sirupsen/logrus"

     "github.com/docker/model-runner/pkg/distribution/internal/progress"
@@ -15,6 +16,7 @@ import (
     "github.com/docker/model-runner/pkg/distribution/registry"
     "github.com/docker/model-runner/pkg/distribution/tarball"
     "github.com/docker/model-runner/pkg/distribution/types"
+    "github.com/docker/model-runner/pkg/inference/platform"
 )

 // Client provides model distribution functionality
@@ -408,6 +410,13 @@ func (c *Client) GetBundle(ref string) (types.ModelBundle, error) {
     return c.store.BundleForModel(ref)
 }

+func GetSupportedFormats() []types.Format {
+    if platform.SupportsVLLM() {
+        return []types.Format{types.FormatGGUF, types.FormatSafetensors}
+    }
+    return []types.Format{types.FormatGGUF}
+}
+
 func checkCompat(image types.ModelArtifact) error {
     manifest, err := image.Manifest()
     if err != nil {
@@ -423,7 +432,7 @@ func checkCompat(image types.ModelArtifact) error {
         return fmt.Errorf("reading model config: %w", err)
     }

-    if config.Format == types.FormatSafetensors {
+    if !slices.Contains(GetSupportedFormats(), config.Format) {
        return ErrUnsupportedFormat
    }
diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go
index c3a982c89..12fc7d663 100644
--- a/pkg/distribution/distribution/client_test.go
+++ b/pkg/distribution/distribution/client_test.go
@@ -26,6 +26,7 @@ import (
     "github.com/docker/model-runner/pkg/distribution/internal/progress"
     "github.com/docker/model-runner/pkg/distribution/internal/safetensors"
     mdregistry "github.com/docker/model-runner/pkg/distribution/registry"
+    "github.com/docker/model-runner/pkg/inference/platform"
 )

 var (
@@ -418,7 +419,7 @@ func TestClientPullModel(t *testing.T) {
         }
     })

-    t.Run("pull safetensors model returns error", func(t *testing.T) {
+    t.Run("pull safetensors model returns error on unsupported platforms", func(t *testing.T) {
        // Create temp directory for the safetensors file
        tempDir, err := os.MkdirTemp("", "safetensors-test-*")
        if err != nil {
@@ -461,10 +462,18 @@ func TestClientPullModel(t *testing.T) {
             t.Fatalf("Failed to create test client: %v", err)
         }

-        // Try to pull the safetensors model - should fail with ErrUnsupportedFormat
+        // Try to pull the safetensors model
         err = testClient.PullModel(context.Background(), tag, nil)
-        if !errors.Is(err, ErrUnsupportedFormat) {
-            t.Fatalf("Expected ErrUnsupportedFormat, got: %v", err)
+        if platform.SupportsVLLM() {
+            // On Linux, safetensors should be supported
+            if err != nil {
+                t.Fatalf("Expected no error on Linux, got: %v", err)
+            }
+        } else {
+            // On non-Linux, should fail with ErrUnsupportedFormat
+            if !errors.Is(err, ErrUnsupportedFormat) {
+                t.Fatalf("Expected ErrUnsupportedFormat on non-Linux platforms, got: %v", err)
+            }
        }
    })
diff --git a/pkg/distribution/internal/utils/utils.go b/pkg/distribution/internal/utils/utils.go
index 188df771a..74025471e 100644
--- a/pkg/distribution/internal/utils/utils.go
+++ b/pkg/distribution/internal/utils/utils.go
@@ -1,94 +1,9 @@
 package utils

 import (
-    "fmt"
     "io"
-    "net/http"
-    "net/url"
-    "os"
-    "strings"
-    "unicode"
 )

-// FormatBytes converts bytes to a human-readable string with appropriate unit
-func FormatBytes(bytes int) string {
-    size := float64(bytes)
-    var unit string
-    switch {
-    case size >= 1<<30:
-        size /= 1 << 30
-        unit = "GB"
-    case size >= 1<<20:
-        size /= 1 << 20
-        unit = "MB"
-    case size >= 1<<10:
-        size /= 1 << 10
-        unit = "KB"
-    default:
-        unit = "bytes"
-    }
-    return fmt.Sprintf("%.2f %s", size, unit)
-}
-
-// ShowProgress displays a progress bar for data transfer operations
-func ShowProgress(operation string, progressChan chan int64, totalSize int64) {
-    for bytesComplete := range progressChan {
-        if totalSize > 0 {
-            mbComplete := float64(bytesComplete) / (1024 * 1024)
-            mbTotal := float64(totalSize) / (1024 * 1024)
-            fmt.Printf("\r%s: %.2f MB / %.2f MB", operation, mbComplete, mbTotal)
-        } else {
-            mb := float64(bytesComplete) / (1024 * 1024)
-            fmt.Printf("\r%s: %.2f MB", operation, mb)
-        }
-    }
-    fmt.Println() // Move to new line after progress
-}
-
-// ReadContent reads content from a local file or URL
-func ReadContent(source string) ([]byte, error) {
-    // Check if the source is a URL
-    if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") {
-        // Parse the URL
-        _, err := url.Parse(source)
-        if err != nil {
-            return nil, fmt.Errorf("invalid URL: %v", err)
-        }
-
-        // Make HTTP request
-        resp, err := http.Get(source)
-        if err != nil {
-            return nil, fmt.Errorf("failed to download file: %v", err)
-        }
-        defer resp.Body.Close()
-
-        if resp.StatusCode != http.StatusOK {
-            return nil, fmt.Errorf("failed to download file: HTTP status %d", resp.StatusCode)
-        }
-
-        // Create progress reader
-        contentLength := resp.ContentLength
-        progressChan := make(chan int64, 100)
-
-        // Start progress reporting goroutine
-        go ShowProgress("Downloading", progressChan, contentLength)
-
-        // Create a wrapper reader to track progress
-        progressReader := &ProgressReader{
-            Reader:       resp.Body,
-            ProgressChan: progressChan,
-        }
-
-        // Read the content
-        content, err := io.ReadAll(progressReader)
-        close(progressChan)
-        return content, err
-    }
-
-    // If not a URL, treat as local file path
-    return os.ReadFile(source)
-}
-
 // ProgressReader wraps an io.Reader to track reading progress
 type ProgressReader struct {
     Reader       io.Reader
@@ -104,48 +19,3 @@ func (pr *ProgressReader) Read(p []byte) (int, error) {
     }
     return n, err
 }
-
-// SanitizeForLog sanitizes a string for safe logging by removing or escaping
-// control characters that could cause log injection attacks.
-// TODO: Consider migrating to structured logging which
-// handles sanitization automatically through field encoding.
-func SanitizeForLog(s string) string {
-    if s == "" {
-        return ""
-    }
-
-    var result strings.Builder
-    result.Grow(len(s))
-
-    for _, r := range s {
-        switch {
-        // Replace newlines and carriage returns with escaped versions.
-        case r == '\n':
-            result.WriteString("\\n")
-        case r == '\r':
-            result.WriteString("\\r")
-        case r == '\t':
-            result.WriteString("\\t")
-        // Remove other control characters (0x00-0x1F, 0x7F).
-        case unicode.IsControl(r):
-            // Skip control characters or replace with placeholder.
-            result.WriteString("?")
-        // Escape backslashes to prevent escape sequence injection.
-        case r == '\\':
-            result.WriteString("\\\\")
-        // Keep printable characters.
-        case unicode.IsPrint(r):
-            result.WriteRune(r)
-        default:
-            // Replace non-printable characters with placeholder.
-            result.WriteString("?")
-        }
-    }
-
-    const maxLength = 100
-    if result.Len() > maxLength {
-        return result.String()[:maxLength] + "...[truncated]"
-    }
-
-    return result.String()
-}
diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go
index 20ddda2a9..50258aa17 100644
--- a/pkg/inference/backend.go
+++ b/pkg/inference/backend.go
@@ -84,7 +84,7 @@ type Backend interface {
     // to be loaded. Backends should not load multiple models at once and should
     // instead load only the specified model. Backends should still respond to
     // OpenAI API requests for other models with a 421 error code.
-    Run(ctx context.Context, socket, model string, mode BackendMode, config *BackendConfiguration) error
+    Run(ctx context.Context, socket, model string, modelRef string, mode BackendMode, config *BackendConfiguration) error
    // Status returns a description of the backend's state.
    Status() string
    // GetDiskUsage returns the disk usage of the backend.
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index 2e5030201..7fa513899 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -24,6 +24,7 @@ import (
     "github.com/docker/model-runner/pkg/inference"
     "github.com/docker/model-runner/pkg/inference/config"
     "github.com/docker/model-runner/pkg/inference/models"
+    "github.com/docker/model-runner/pkg/internal/utils"
     "github.com/docker/model-runner/pkg/logging"
     "github.com/docker/model-runner/pkg/sandbox"
     "github.com/docker/model-runner/pkg/tailbuffer"
@@ -133,7 +134,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 }

 // Run implements inference.Backend.Run.
-func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
+func (l *llamaCpp) Run(ctx context.Context, socket, model string, _ string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
    bundle, err := l.modelManager.GetBundle(model)
    if err != nil {
        return fmt.Errorf("failed to get model: %w", err)
@@ -154,7 +155,12 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
         return fmt.Errorf("failed to get args for llama.cpp: %w", err)
     }

-    l.log.Infof("llamaCppArgs: %v", args)
+    // Sanitize args for safe logging
+    sanitizedArgs := make([]string, len(args))
+    for i, arg := range args {
+        sanitizedArgs[i] = utils.SanitizeForLog(arg)
+    }
+    l.log.Infof("llamaCppArgs: %v", sanitizedArgs)
    tailBuf := tailbuffer.NewTailBuffer(1024)
    serverLogStream := l.serverLog.Writer()
    out := io.MultiWriter(serverLogStream, tailBuf)
diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go
index 38e32136b..df080a3aa 100644
--- a/pkg/inference/backends/mlx/mlx.go
+++ b/pkg/inference/backends/mlx/mlx.go
@@ -49,7 +49,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
 }

 // Run implements inference.Backend.Run.
-func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
+func (m *mlx) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
    // TODO: Implement.
m.log.Warn("MLX backend is not yet supported") return errors.New("not implemented") diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go index d78caa587..00ef5e75b 100644 --- a/pkg/inference/backends/vllm/vllm.go +++ b/pkg/inference/backends/vllm/vllm.go @@ -3,16 +3,30 @@ package vllm import ( "context" "errors" + "fmt" + "io" + "io/fs" "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "github.com/docker/model-runner/pkg/diskusage" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/inference/platform" + "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" + "github.com/docker/model-runner/pkg/sandbox" + "github.com/docker/model-runner/pkg/tailbuffer" ) const ( // Name is the backend name. - Name = "vllm" + Name = "vllm" + vllmDir = "/opt/vllm-env/bin" ) // vLLM is the vLLM-based backend implementation. @@ -21,13 +35,27 @@ type vLLM struct { log logging.Logger // modelManager is the shared model manager. modelManager *models.Manager + // serverLog is the logger to use for the vLLM server process. + serverLog logging.Logger + // config is the configuration for the vLLM backend. + config *Config + // status is the state in which the vLLM backend is in. + status string } // New creates a new vLLM-based backend. -func New(log logging.Logger, modelManager *models.Manager) (inference.Backend, error) { +func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config) (inference.Backend, error) { + // If no config is provided, use the default configuration + if conf == nil { + conf = NewDefaultVLLMConfig() + } + return &vLLM{ log: log, modelManager: modelManager, + serverLog: serverLog, + config: conf, + status: "not installed", }, nil } @@ -36,33 +64,154 @@ func (v *vLLM) Name() string { return Name } -// UsesExternalModelManagement implements -// inference.Backend.UsesExternalModelManagement. -func (l *vLLM) UsesExternalModelManagement() bool { +func (v *vLLM) UsesExternalModelManagement() bool { return false } -// Install implements inference.Backend.Install. -func (v *vLLM) Install(ctx context.Context, httpClient *http.Client) error { - // TODO: Implement. - return errors.New("not implemented") +func (v *vLLM) Install(_ context.Context, _ *http.Client) error { + if !platform.SupportsVLLM() { + return errors.New("not implemented") + } + + vllmBinaryPath := v.binaryPath() + if _, err := os.Stat(vllmBinaryPath); err != nil { + if errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("vLLM binary not found at %s", vllmBinaryPath) + } + return fmt.Errorf("failed to check vLLM binary: %w", err) + } + + // Read vLLM version from file (created in Dockerfile via `print(vllm.__version__)`). + versionPath := filepath.Join(filepath.Dir(vllmDir), "version") + versionBytes, err := os.ReadFile(versionPath) + if err != nil { + v.log.Warnf("could not get vllm version: %v", err) + v.status = "running vllm version: unknown" + } else { + v.status = fmt.Sprintf("running vllm version: %s", strings.TrimSpace(string(versionBytes))) + } + + return nil } -// Run implements inference.Backend.Run. -func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error { - // TODO: Implement. 
- v.log.Warn("vLLM backend is not yet supported") - return errors.New("not implemented") +func (v *vLLM) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { + if !platform.SupportsVLLM() { + v.log.Warn("vLLM backend is not yet supported") + return errors.New("not implemented") + } + + bundle, err := v.modelManager.GetBundle(model) + if err != nil { + return fmt.Errorf("failed to get model: %w", err) + } + + if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) { + v.log.Warnf("failed to remove socket file %s: %v\n", socket, err) + v.log.Warnln("vLLM may not be able to start") + } + + // Get arguments from config + args, err := v.config.GetArgs(bundle, socket, mode, backendConfig) + if err != nil { + return fmt.Errorf("failed to get vLLM arguments: %w", err) + } + + // Add served model name + args = append(args, "--served-model-name", model, modelRef) + + // Sanitize args for safe logging + sanitizedArgs := make([]string, len(args)) + for i, arg := range args { + sanitizedArgs[i] = utils.SanitizeForLog(arg) + } + v.log.Infof("vLLM args: %v", sanitizedArgs) + tailBuf := tailbuffer.NewTailBuffer(1024) + serverLogStream := v.serverLog.Writer() + out := io.MultiWriter(serverLogStream, tailBuf) + vllmSandbox, err := sandbox.Create( + ctx, + "", + func(command *exec.Cmd) { + command.Cancel = func() error { + if runtime.GOOS == "windows" { + return command.Process.Kill() + } + return command.Process.Signal(os.Interrupt) + } + command.Stdout = serverLogStream + command.Stderr = out + }, + vllmDir, + v.binaryPath(), + args..., + ) + if err != nil { + return fmt.Errorf("unable to start vLLM: %w", err) + } + defer vllmSandbox.Close() + + vllmErrors := make(chan error, 1) + go func() { + vllmErr := vllmSandbox.Command().Wait() + serverLogStream.Close() + + errOutput := new(strings.Builder) + if _, err := io.Copy(errOutput, tailBuf); err != nil { + v.log.Warnf("failed to read server output tail: %v", err) + } + + if len(errOutput.String()) != 0 { + vllmErr = fmt.Errorf("vLLM exit status: %w\nwith output: %s", vllmErr, errOutput.String()) + } else { + vllmErr = fmt.Errorf("vLLM exit status: %w", vllmErr) + } + + vllmErrors <- vllmErr + close(vllmErrors) + if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) { + v.log.Warnf("failed to remove socket file %s on exit: %v\n", socket, err) + } + }() + defer func() { + <-vllmErrors + }() + + select { + case <-ctx.Done(): + return nil + case vllmErr := <-vllmErrors: + select { + case <-ctx.Done(): + return nil + default: + } + return fmt.Errorf("vLLM terminated unexpectedly: %w", vllmErr) + } } func (v *vLLM) Status() string { - return "not running" + return v.status } func (v *vLLM) GetDiskUsage() (int64, error) { - return 0, nil + size, err := diskusage.Size(vllmDir) + if err != nil { + return 0, fmt.Errorf("error while getting store size: %v", err) + } + return size, nil +} + +func (v *vLLM) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) { + if !platform.SupportsVLLM() { + return inference.RequiredMemory{}, errors.New("not implemented") + } + + return inference.RequiredMemory{ + RAM: 1, + VRAM: 1, + }, nil } -func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) { - return inference.RequiredMemory{}, errors.New("not implemented") +func (v *vLLM) 
binaryPath() string { + return filepath.Join(vllmDir, "vllm") } diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go new file mode 100644 index 000000000..741126d80 --- /dev/null +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -0,0 +1,82 @@ +package vllm + +import ( + "fmt" + "path/filepath" + "strconv" + + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference" +) + +// Config is the configuration for the vLLM backend. +type Config struct { + // Args are the base arguments that are always included. + Args []string +} + +// NewDefaultVLLMConfig creates a new VLLMConfig with default values. +func NewDefaultVLLMConfig() *Config { + return &Config{ + Args: []string{}, + } +} + +// GetArgs implements BackendConfig.GetArgs. +func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) { + // Start with the arguments from VLLMConfig + args := append([]string{}, c.Args...) + + // Add the serve command and model path (use directory for safetensors) + safetensorsPath := bundle.SafetensorsPath() + if safetensorsPath == "" { + return nil, fmt.Errorf("safetensors path required by vLLM backend") + } + modelPath := filepath.Dir(safetensorsPath) + // vLLM expects the directory containing the safetensors files + args = append(args, "serve", modelPath) + + // Add socket arguments + args = append(args, "--uds", socket) + + // Add mode-specific arguments + switch mode { + case inference.BackendModeCompletion: + // Default mode for vLLM + case inference.BackendModeEmbedding: + // vLLM doesn't have a specific embedding flag like llama.cpp + // Embedding models are detected automatically + default: + return nil, fmt.Errorf("unsupported backend mode %q", mode) + } + + // Add max-model-len if specified in model config or backend config + if maxLen := GetMaxModelLen(bundle.RuntimeConfig(), config); maxLen != nil { + args = append(args, "--max-model-len", strconv.FormatUint(*maxLen, 10)) + } + // If nil, vLLM will automatically derive from the model config + + // Add arguments from backend config + if config != nil { + args = append(args, config.RuntimeFlags...) + } + + return args, nil +} + +// GetMaxModelLen returns the max model length (context size) from model config or backend config. +// Model config takes precedence over backend config. +// Returns nil if neither is specified (vLLM will auto-derive from model). 
+func GetMaxModelLen(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *uint64 {
+    // Model config takes precedence
+    if modelCfg.ContextSize != nil {
+        return modelCfg.ContextSize
+    }
+    // else use backend config
+    if backendCfg != nil && backendCfg.ContextSize > 0 {
+        val := uint64(backendCfg.ContextSize)
+        return &val
+    }
+    // Return nil to let vLLM auto-derive from model config
+    return nil
+}
diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go
new file mode 100644
index 000000000..e6bd598b5
--- /dev/null
+++ b/pkg/inference/backends/vllm/vllm_config_test.go
@@ -0,0 +1,209 @@
+package vllm
+
+import (
+    "testing"
+
+    "github.com/docker/model-runner/pkg/distribution/types"
+    "github.com/docker/model-runner/pkg/inference"
+)
+
+type mockModelBundle struct {
+    safetensorsPath string
+    runtimeConfig   types.Config
+}
+
+func (m *mockModelBundle) GGUFPath() string {
+    return ""
+}
+
+func (m *mockModelBundle) SafetensorsPath() string {
+    return m.safetensorsPath
+}
+
+func (m *mockModelBundle) ChatTemplatePath() string {
+    return ""
+}
+
+func (m *mockModelBundle) MMPROJPath() string {
+    return ""
+}
+
+func (m *mockModelBundle) RuntimeConfig() types.Config {
+    return m.runtimeConfig
+}
+
+func (m *mockModelBundle) RootDir() string {
+    return "/path/to/bundle"
+}
+
+func TestGetArgs(t *testing.T) {
+    tests := []struct {
+        name        string
+        config      *inference.BackendConfiguration
+        bundle      *mockModelBundle
+        expected    []string
+        expectError bool
+    }{
+        {
+            name: "empty safetensors path should error",
+            bundle: &mockModelBundle{
+                safetensorsPath: "",
+            },
+            config:      nil,
+            expected:    nil,
+            expectError: true,
+        },
+        {
+            name: "basic args without context size",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: nil,
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+            },
+        },
+        {
+            name: "with backend context size",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                ContextSize: 8192,
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--max-model-len",
+                "8192",
+            },
+        },
+        {
+            name: "with runtime flags",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+            },
+            config: &inference.BackendConfiguration{
+                RuntimeFlags: []string{"--gpu-memory-utilization", "0.9"},
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--gpu-memory-utilization",
+                "0.9",
+            },
+        },
+        {
+            name: "with model context size (takes precedence)",
+            bundle: &mockModelBundle{
+                safetensorsPath: "/path/to/model",
+                runtimeConfig: types.Config{
+                    ContextSize: ptrUint64(16384),
+                },
+            },
+            config: &inference.BackendConfiguration{
+                ContextSize: 8192,
+            },
+            expected: []string{
+                "serve",
+                "/path/to",
+                "--uds",
+                "/tmp/socket",
+                "--max-model-len",
+                "16384",
+            },
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            config := NewDefaultVLLMConfig()
+            args, err := config.GetArgs(tt.bundle, "/tmp/socket", inference.BackendModeCompletion, tt.config)
+
+            if tt.expectError {
+                if err == nil {
+                    t.Fatalf("expected error but got none")
+                }
+                return
+            }
+
+            if err != nil {
+                t.Fatalf("unexpected error: %v", err)
+            }
+
+            if len(args) != len(tt.expected) {
+                t.Fatalf("expected %d args, got %d\nexpected: %v\ngot: %v", len(tt.expected), len(args), tt.expected, args)
+            }
+
+            for i, arg := range args {
+                if arg != tt.expected[i] {
+                    t.Errorf("arg[%d]: expected %q, got %q", i, tt.expected[i], arg)
+                }
+            }
+        })
+    }
+}
+
+func TestGetMaxModelLen(t *testing.T) {
+    tests := []struct {
+        name          string
+        modelCfg      types.Config
+        backendCfg    *inference.BackendConfiguration
+        expectedValue *uint64
+    }{
+        {
+            name:          "no config",
+            modelCfg:      types.Config{},
+            backendCfg:    nil,
+            expectedValue: nil,
+        },
+        {
+            name:     "backend config only",
+            modelCfg: types.Config{},
+            backendCfg: &inference.BackendConfiguration{
+                ContextSize: 4096,
+            },
+            expectedValue: ptrUint64(4096),
+        },
+        {
+            name: "model config only",
+            modelCfg: types.Config{
+                ContextSize: ptrUint64(8192),
+            },
+            backendCfg:    nil,
+            expectedValue: ptrUint64(8192),
+        },
+        {
+            name: "model config takes precedence",
+            modelCfg: types.Config{
+                ContextSize: ptrUint64(16384),
+            },
+            backendCfg: &inference.BackendConfiguration{
+                ContextSize: 4096,
+            },
+            expectedValue: ptrUint64(16384),
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := GetMaxModelLen(tt.modelCfg, tt.backendCfg)
+            if (result == nil) != (tt.expectedValue == nil) {
+                t.Errorf("expected nil=%v, got nil=%v", tt.expectedValue == nil, result == nil)
+            } else if result != nil && *result != *tt.expectedValue {
+                t.Errorf("expected %d, got %d", *tt.expectedValue, *result)
+            }
+        })
+    }
+}
+
+func ptrUint64(v uint64) *uint64 {
+    return &v
+}
diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go
new file mode 100644
index 000000000..74c87e7be
--- /dev/null
+++ b/pkg/inference/platform/platform.go
@@ -0,0 +1,8 @@
+package platform
+
+import "runtime"
+
+// SupportsVLLM returns true if vLLM is supported on the current platform.
+func SupportsVLLM() bool {
+    return runtime.GOOS == "linux"
+}
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
index 3116f04be..bac7dd09b 100644
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -469,7 +469,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
             return l.slots[existing.slot], nil
         }
     }
-    
+
     if runtime.GOOS == "windows" {
        // On Windows, we can use up to half of the total system RAM as shared GPU memory,
        // limited by the currently available RAM.
@@ -504,7 +504,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
     }

     // Create the runner.
l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode) - runner, err := run(l.log, backend, modelID, mode, slot, runnerConfig, l.openAIRecorder) + runner, err := run(l.log, backend, modelID, modelRef, mode, slot, runnerConfig, l.openAIRecorder) if err != nil { l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v", backendName, modelID, mode, err, diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go index 985fea48e..212df1623 100644 --- a/pkg/inference/scheduling/runner.go +++ b/pkg/inference/scheduling/runner.go @@ -14,6 +14,7 @@ import ( "time" "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/metrics" ) @@ -74,7 +75,8 @@ type runner struct { func run( log logging.Logger, backend inference.Backend, - model string, + modelID string, + modelRef string, mode inference.BackendMode, slot int, runnerConfig *inference.BackendConfiguration, @@ -136,7 +138,7 @@ func run( r := &runner{ log: log, backend: backend, - model: model, + model: modelID, mode: mode, cancel: runCancel, done: runDone, @@ -175,13 +177,13 @@ func run( } } - r.openAIRecorder.SetConfigForModel(model, runnerConfig) + r.openAIRecorder.SetConfigForModel(modelID, runnerConfig) // Start the backend run loop. go func() { - if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil { + if err := backend.Run(runCtx, socket, modelID, modelRef, mode, runnerConfig); err != nil { log.Warnf("Backend %s running model %s exited with error: %v", - backend.Name(), model, err, + backend.Name(), utils.SanitizeForLog(modelRef), err, ) r.err = err } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 3d3bdfd22..e3b87b7f2 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -13,9 +13,12 @@ import ( "time" "github.com/docker/model-runner/pkg/distribution/distribution" + "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends/vllm" "github.com/docker/model-runner/pkg/inference/memory" "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/metrics" "github.com/docker/model-runner/pkg/middleware" @@ -181,26 +184,6 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request return } - // Wait for the corresponding backend installation to complete or fail. We - // don't allow any requests to be scheduled for a backend until it has - // completed installation. - if err := s.installer.wait(r.Context(), backend.Name()); err != nil { - if errors.Is(err, ErrBackendNotFound) { - http.Error(w, err.Error(), http.StatusNotFound) - } else if errors.Is(err, errInstallerNotStarted) { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - } else if errors.Is(err, context.Canceled) { - // This could be due to the client aborting the request (in which - // case this response will be ignored) or the inference service - // shutting down (since that will also cancel the request context). - // Either way, provide a response, even if it's ignored. 
- http.Error(w, "service unavailable", http.StatusServiceUnavailable) - } else { - http.Error(w, fmt.Errorf("backend installation failed: %w", err).Error(), http.StatusServiceUnavailable) - } - return - } - // Determine the backend operation mode. backendMode, ok := backendModeForRequest(r.URL.Path) if !ok { @@ -232,6 +215,42 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request } // Non-blocking call to track the model usage. s.tracker.TrackModel(model, r.UserAgent(), "inference/"+backendMode.String()) + + // Automatically identify models for vLLM. + config, err := model.Config() + if err != nil { + s.log.Warnln("failed to fetch model config:", err) + } else { + if config.Format == types.FormatSafetensors { + if vllmBackend, ok := s.backends[vllm.Name]; ok { + backend = vllmBackend + } else { + s.log.Warnf("Model %s is in safetensors format but vLLM backend is not available. "+ + "Backend %s may not support this format and could fail at runtime.", + utils.SanitizeForLog(request.Model), backend.Name()) + } + } + } + } + + // Wait for the corresponding backend installation to complete or fail. We + // don't allow any requests to be scheduled for a backend until it has + // completed installation. + if err := s.installer.wait(r.Context(), backend.Name()); err != nil { + if errors.Is(err, ErrBackendNotFound) { + http.Error(w, err.Error(), http.StatusNotFound) + } else if errors.Is(err, errInstallerNotStarted) { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + } else if errors.Is(err, context.Canceled) { + // This could be due to the client aborting the request (in which + // case this response will be ignored) or the inference service + // shutting down (since that will also cancel the request context). + // Either way, provide a response, even if it's ignored. + http.Error(w, "service unavailable", http.StatusServiceUnavailable) + } else { + http.Error(w, fmt.Errorf("backend installation failed: %w", err).Error(), http.StatusServiceUnavailable) + } + return } modelID := s.modelManager.ResolveModelID(request.Model) diff --git a/pkg/internal/utils/log.go b/pkg/internal/utils/log.go new file mode 100644 index 000000000..bb63f2220 --- /dev/null +++ b/pkg/internal/utils/log.go @@ -0,0 +1,51 @@ +package utils + +import ( + "strings" + "unicode" +) + +// SanitizeForLog sanitizes a string for safe logging by removing or escaping +// control characters that could cause log injection attacks. +// TODO: Consider migrating to structured logging which +// handles sanitization automatically through field encoding. +func SanitizeForLog(s string) string { + if s == "" { + return "" + } + + var result strings.Builder + result.Grow(len(s)) + + for _, r := range s { + switch { + // Replace newlines and carriage returns with escaped versions. + case r == '\n': + result.WriteString("\\n") + case r == '\r': + result.WriteString("\\r") + case r == '\t': + result.WriteString("\\t") + // Remove other control characters (0x00-0x1F, 0x7F). + case unicode.IsControl(r): + // Skip control characters or replace with placeholder. + result.WriteString("?") + // Escape backslashes to prevent escape sequence injection. + case r == '\\': + result.WriteString("\\\\") + // Keep printable characters. + case unicode.IsPrint(r): + result.WriteRune(r) + default: + // Replace non-printable characters with placeholder. 
+            result.WriteString("?")
+        }
+    }
+
+    const maxLength = 100
+    if result.Len() > maxLength {
+        return result.String()[:maxLength] + "...[truncated]"
+    }
+
+    return result.String()
+}