diff --git a/Dockerfile b/Dockerfile
index 7d71f3dc1..d303f1250 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -147,6 +147,44 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
     && ~/.local/bin/uv pip install --python /opt/sglang-env/bin/python "sglang==${SGLANG_VERSION}"
 RUN /opt/sglang-env/bin/python -c "import sglang; print(sglang.__version__)" > /opt/sglang-env/version
+
+# --- Diffusers variant ---
+FROM llamacpp AS diffusers
+
+ARG DIFFUSERS_VERSION=0.31.0
+ARG TORCH_VERSION=2.5.1
+
+USER root
+
+RUN apt-get update && apt-get install -y \
+    python3 python3-venv python3-dev \
+    curl ca-certificates build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN mkdir -p /opt/diffusers-env && chown -R modelrunner:modelrunner /opt/diffusers-env
+
+USER modelrunner
+
+# Install uv and diffusers as modelrunner user
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
+    && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/diffusers-env \
+    && ~/.local/bin/uv pip install --python /opt/diffusers-env/bin/python \
+    "diffusers==${DIFFUSERS_VERSION}" \
+    "torch==${TORCH_VERSION}" \
+    "transformers" \
+    "accelerate" \
+    "safetensors" \
+    "fastapi" \
+    "uvicorn[standard]" \
+    "pillow"
+
+# Determine the Python site-packages directory dynamically and copy the Python server code into it
+RUN PYTHON_SITE_PACKAGES=$(/opt/diffusers-env/bin/python -c "import site; print(site.getsitepackages()[0])") && \
+    mkdir -p "$PYTHON_SITE_PACKAGES" && \
+    cp -r python/diffusers_server "$PYTHON_SITE_PACKAGES/"
+
+RUN /opt/diffusers-env/bin/python -c "import diffusers; print(diffusers.__version__)" > /opt/diffusers-env/version
+
 FROM llamacpp AS final-llamacpp
 # Copy the built binary from builder
 COPY --from=builder /app/model-runner /app/model-runner
@@ -158,3 +196,7 @@ COPY --from=builder /app/model-runner /app/model-runner
 FROM sglang AS final-sglang
 # Copy the built binary from builder-sglang (without vLLM)
 COPY --from=builder-sglang /app/model-runner /app/model-runner
+
+FROM diffusers AS final-diffusers
+# Copy the built binary from builder
+COPY --from=builder /app/model-runner /app/model-runner
diff --git a/backends_vllm_stub.go b/backends_vllm_stub.go
index dceb094a7..70eed732b 100644
--- a/backends_vllm_stub.go
+++ b/backends_vllm_stub.go
@@ -9,7 +9,7 @@ import (
 )
 
 func initVLLMBackend(log *logrus.Logger, modelManager *models.Manager) (inference.Backend, error) {
-	return nil, nil
+	return nil, nil // VLLM backend is disabled
 }
 
 func registerVLLMBackend(backends map[string]inference.Backend, backend inference.Backend) {
diff --git a/main.go b/main.go
index 36fdeb8ca..f75f107ce 100644
--- a/main.go
+++ b/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"fmt"
 	"net"
 	"net/http"
 	"os"
@@ -13,6 +14,7 @@ import (
 
 	"github.com/docker/model-runner/pkg/anthropic"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/backends/diffusers"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/mlx"
 	"github.com/docker/model-runner/pkg/inference/backends/sglang"
@@ -25,20 +27,140 @@ import (
 	"github.com/docker/model-runner/pkg/ollama"
 	"github.com/docker/model-runner/pkg/responses"
 	"github.com/docker/model-runner/pkg/routing"
+	"github.com/mattn/go-shellwords"
 	"github.com/sirupsen/logrus"
 )
 
+const (
+	defaultSocketName      = "model-runner.sock"
+	defaultModelsPath      = ".docker/models"
+	defaultLlamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
+	socketFileMode         = 0o600
+	defaultDirectoryMode
= 0o755 +) + var log = logrus.New() -func main() { - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() +func initializeBackends(log *logrus.Logger, modelManager *models.Manager, llamaServerPath string, llamaCppConfig config.BackendConfig) (map[string]inference.Backend, inference.Backend, error) { + // Initialize llama.cpp backend + llamaCppBackend, err := initLlamaCppBackend(log, modelManager, llamaServerPath, llamaCppConfig) + if err != nil { + return nil, nil, err + } + + // Initialize VLLM backend + vllmBackend, err := initVLLMBackend(log, modelManager) + if err != nil { + return nil, nil, fmt.Errorf("unable to initialize %s backend: %w", vllm.Name, err) + } + // Initialize other backends with explicit error handling + mlxBackend, err := initMlxBackend(log, modelManager) + if err != nil { + return nil, nil, fmt.Errorf("unable to initialize %s backend: %w", mlx.Name, err) + } + + sglangBackend, err := initSglangBackend(log, modelManager) + if err != nil { + return nil, nil, fmt.Errorf("unable to initialize %s backend: %w", sglang.Name, err) + } + + diffusersBackend, err := initDiffusersBackend(log, modelManager) + if err != nil { + return nil, nil, fmt.Errorf("unable to initialize %s backend: %w", diffusers.Name, err) + } + + backends := map[string]inference.Backend{ + llamacpp.Name: llamaCppBackend, + mlx.Name: mlxBackend, + sglang.Name: sglangBackend, + diffusers.Name: diffusersBackend, + } + + // Only register VLLM backend if it was properly initialized (not nil) + if vllmBackend != nil { + registerVLLMBackend(backends, vllmBackend) + } + + return backends, llamaCppBackend, nil +} + +func initLlamaCppBackend(log *logrus.Logger, modelManager *models.Manager, llamaServerPath string, llamaCppConfig config.BackendConfig) (inference.Backend, error) { + backend, err := llamacpp.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": llamacpp.Name}), + llamaServerPath, + getLlamaCppUpdateDir(log), + llamaCppConfig, + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize %s backend: %w", llamacpp.Name, err) + } + return backend, nil +} + +func initMlxBackend(log *logrus.Logger, modelManager *models.Manager) (inference.Backend, error) { + backend, err := mlx.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": mlx.Name}), + nil, + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize %s backend: %w", mlx.Name, err) + } + return backend, nil +} + +func initSglangBackend(log *logrus.Logger, modelManager *models.Manager) (inference.Backend, error) { + backend, err := sglang.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": sglang.Name}), + nil, + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize %s backend: %w", sglang.Name, err) + } + return backend, nil +} + +func initDiffusersBackend(log *logrus.Logger, modelManager *models.Manager) (inference.Backend, error) { + backend, err := diffusers.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": diffusers.Name}), + nil, + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize %s backend: %w", diffusers.Name, err) + } + return backend, nil +} + +func getLlamaCppUpdateDir(log *logrus.Logger) string { + wd, err := os.Getwd() + if err != nil { + log.Errorf("Failed to get working directory, using current directory: %v", err) + wd = "." 
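+		// Fall back to a relative path so the update directory can still be created below.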
+ } + d := filepath.Join(wd, "updated-inference", "bin") + if err := os.MkdirAll(d, defaultDirectoryMode); err != nil { + log.Errorf("Failed to create directory %s: %v", d, err) + } + return d +} + +func getSocketName() string { sockName := os.Getenv("MODEL_RUNNER_SOCK") if sockName == "" { - sockName = "model-runner.sock" + sockName = defaultSocketName } + return sockName +} +func getModelPath() string { userHomeDir, err := os.UserHomeDir() if err != nil { log.Fatalf("Failed to get user home directory: %v", err) @@ -46,9 +168,20 @@ func main() { modelPath := os.Getenv("MODELS_PATH") if modelPath == "" { - modelPath = filepath.Join(userHomeDir, ".docker", "models") + modelPath = filepath.Join(userHomeDir, defaultModelsPath) } + return modelPath +} + +func getLlamaServerPath() string { + llamaServerPath := os.Getenv("LLAMA_SERVER_PATH") + if llamaServerPath == "" { + llamaServerPath = defaultLlamaServerPath + } + return llamaServerPath +} +func configureLlamaCpp() error { _, disableServerUpdate := os.LookupEnv("DISABLE_SERVER_UPDATE") if disableServerUpdate { llamacpp.ShouldUpdateServerLock.Lock() @@ -60,92 +193,107 @@ func main() { if ok { llamacpp.SetDesiredServerVersion(desiredServerVersion) } + return nil +} - llamaServerPath := os.Getenv("LLAMA_SERVER_PATH") - if llamaServerPath == "" { - llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin" - } - +func createProxyTransport() *http.Transport { // Create a proxy-aware HTTP transport - // Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment + // Use a safe type assertion with fallback var baseTransport *http.Transport if t, ok := http.DefaultTransport.(*http.Transport); ok { baseTransport = t.Clone() } else { - baseTransport = &http.Transport{} - } - baseTransport.Proxy = http.ProxyFromEnvironment - - clientConfig := models.ClientConfig{ - StoreRootPath: modelPath, - Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), - Transport: baseTransport, + // Fallback to a default transport if type assertion fails + baseTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + } } - modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig) - modelHandler := models.NewHTTPHandler( - log, - modelManager, - nil, - ) - log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath) + return baseTransport +} - // Create llama.cpp configuration from environment variables - llamaCppConfig := createLlamaCppConfigFromEnv() +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() - llamaCppBackend, err := llamacpp.New( - log, - modelManager, - log.WithFields(logrus.Fields{"component": llamacpp.Name}), - llamaServerPath, - func() string { - wd, _ := os.Getwd() - d := filepath.Join(wd, "updated-inference", "bin") - _ = os.MkdirAll(d, 0o755) - return d - }(), - llamaCppConfig, - ) + // Initialize configuration and services + config, err := initializeAppConfig() if err != nil { - log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err) + log.Fatalf("Failed to initialize app config: %v", err) } - vllmBackend, err := initVLLMBackend(log, modelManager) + // Initialize backends + backends, llamaCppBackend, err := initializeBackends(log, config.modelManager, config.llamaServerPath, config.llamaCppConfig) if err != nil { - log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err) + log.Fatalf("Failed to initialize backends: %v", err) } - 
mlxBackend, err := mlx.New( - log, - modelManager, - log.WithFields(logrus.Fields{"component": mlx.Name}), - nil, - ) - if err != nil { - log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err) - } + // Create scheduler + scheduler := createScheduler(config, backends, llamaCppBackend) - sglangBackend, err := sglang.New( - log, - modelManager, - log.WithFields(logrus.Fields{"component": sglang.Name}), - nil, - ) - if err != nil { - log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err) + // Setup HTTP handlers + router := setupHTTPHandlers(config, scheduler) + + // Start server + server, serverErrors := startServer(router, config.sockName) + + // Start scheduler + schedulerErrors := make(chan error, 1) + go func() { + schedulerErrors <- scheduler.Run(ctx) + }() + + // Wait for shutdown + waitForShutdown(ctx, server, serverErrors, schedulerErrors) + log.Infoln("Docker Model Runner stopped") +} + +// AppConfig holds the application configuration +type AppConfig struct { + sockName string + modelPath string + llamaServerPath string + modelManager *models.Manager + llamaCppConfig config.BackendConfig +} + +// initializeAppConfig initializes the application configuration +func initializeAppConfig() (*AppConfig, error) { + sockName := getSocketName() + modelPath := getModelPath() + llamaServerPath := getLlamaServerPath() + + if err := configureLlamaCpp(); err != nil { + return nil, fmt.Errorf("failed to configure llama.cpp: %w", err) } - backends := map[string]inference.Backend{ - llamacpp.Name: llamaCppBackend, - mlx.Name: mlxBackend, - sglang.Name: sglangBackend, + baseTransport := createProxyTransport() + clientConfig := models.ClientConfig{ + StoreRootPath: modelPath, + Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), + Transport: baseTransport, } - registerVLLMBackend(backends, vllmBackend) + modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig) + log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath) + + // Create llama.cpp configuration from environment variables + llamaCppConfig := createLlamaCppConfigFromEnv() - scheduler := scheduling.NewScheduler( + return &AppConfig{ + sockName: sockName, + modelPath: modelPath, + llamaServerPath: llamaServerPath, + modelManager: modelManager, + llamaCppConfig: llamaCppConfig, + }, nil +} + +// createScheduler creates a new scheduler instance +func createScheduler(config *AppConfig, backends map[string]inference.Backend, llamaCppBackend inference.Backend) *scheduling.Scheduler { + return scheduling.NewScheduler( log, backends, llamaCppBackend, - modelManager, + config.modelManager, http.DefaultClient, metrics.NewTracker( http.DefaultClient, @@ -154,6 +302,15 @@ func main() { false, ), ) +} + +// setupHTTPHandlers sets up all HTTP handlers for the application +func setupHTTPHandlers(config *AppConfig, scheduler *scheduling.Scheduler) *routing.NormalizedServeMux { + modelHandler := models.NewHTTPHandler( + log, + config.modelManager, + nil, + ) // Create the HTTP handler for the scheduler schedulerHTTP := scheduling.NewHTTPHandler(scheduler, modelHandler, nil) @@ -183,11 +340,11 @@ func main() { router.Handle("/score", aliasHandler) // Add Ollama API compatibility layer (only register with trailing slash to catch sub-paths) - ollamaHandler := ollama.NewHTTPHandler(log, scheduler, schedulerHTTP, nil, modelManager) + ollamaHandler := ollama.NewHTTPHandler(log, scheduler, schedulerHTTP, nil, config.modelManager) router.Handle(ollama.APIPrefix+"/", ollamaHandler) 
// Add Anthropic Messages API compatibility layer - anthropicHandler := anthropic.NewHandler(log, schedulerHTTP, nil, modelManager) + anthropicHandler := anthropic.NewHandler(log, schedulerHTTP, nil, config.modelManager) router.Handle(anthropic.APIPrefix+"/", anthropicHandler) // Register root handler LAST - it will only catch exact "/" requests that don't match other patterns @@ -213,6 +370,11 @@ func main() { log.Info("Metrics endpoint disabled") } + return router +} + +// startServer starts the HTTP server and returns the server instance and error channel +func startServer(router *routing.NormalizedServeMux, sockName string) (*http.Server, chan error) { server := &http.Server{ Handler: router, ReadHeaderTimeout: 10 * time.Second, @@ -240,16 +402,20 @@ func main() { if err != nil { log.Fatalf("Failed to listen on socket: %v", err) } + // Set appropriate permissions on the socket file to restrict access + if err := os.Chmod(sockName, socketFileMode); err != nil { + log.Errorf("Failed to set socket file permissions: %v", err) + } go func() { serverErrors <- server.Serve(ln) }() } - schedulerErrors := make(chan error, 1) - go func() { - schedulerErrors <- scheduler.Run(ctx) - }() + return server, serverErrors +} +// waitForShutdown waits for shutdown signals and handles cleanup +func waitForShutdown(ctx context.Context, server *http.Server, serverErrors chan error, schedulerErrors chan error) { select { case err := <-serverErrors: if err != nil { @@ -266,7 +432,6 @@ func main() { log.Errorf("Scheduler error: %v", err) } } - log.Infoln("Docker Model Runner stopped") } // createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables @@ -279,15 +444,20 @@ func createLlamaCppConfigFromEnv() config.BackendConfig { return nil // nil will cause the backend to use its default configuration } - // Split the string by spaces, respecting quoted arguments - args := splitArgs(argsStr) + // Split the string by spaces, respecting quoted arguments using shellwords + args, err := shellwords.Parse(argsStr) + if err != nil { + log.Errorf("Failed to parse LLAMA_ARGS: %v. Using default configuration.", err) + return nil + } // Check for disallowed arguments disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"} for _, arg := range args { for _, disallowed := range disallowedArgs { - if arg == disallowed { - log.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed) + if isDisallowedArg(arg, disallowed) { + log.Errorf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner. 
Using default configuration.", disallowed) + return nil } } } @@ -298,29 +468,7 @@ func createLlamaCppConfigFromEnv() config.BackendConfig { } } -// splitArgs splits a string into arguments, respecting quoted arguments -func splitArgs(s string) []string { - var args []string - var currentArg strings.Builder - inQuotes := false - - for _, r := range s { - switch { - case r == '"' || r == '\'': - inQuotes = !inQuotes - case r == ' ' && !inQuotes: - if currentArg.Len() > 0 { - args = append(args, currentArg.String()) - currentArg.Reset() - } - default: - currentArg.WriteRune(r) - } - } - - if currentArg.Len() > 0 { - args = append(args, currentArg.String()) - } - - return args +// isDisallowedArg checks if an argument matches a disallowed argument pattern +func isDisallowedArg(arg, disallowed string) bool { + return arg == disallowed || (strings.HasPrefix(arg, disallowed) && len(arg) > len(disallowed) && arg[len(disallowed)] == '=') } diff --git a/main_test.go b/main_test.go index d4e993dc2..ffdd4d42b 100644 --- a/main_test.go +++ b/main_test.go @@ -29,11 +29,21 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) { llamaArgs: "--model test.gguf", wantErr: true, }, + { + name: "disallowed model arg with equals", + llamaArgs: "--model=test.gguf", + wantErr: true, + }, { name: "disallowed host arg", llamaArgs: "--host localhost:8080", wantErr: true, }, + { + name: "disallowed host arg with equals", + llamaArgs: "--host=localhost:8080", + wantErr: true, + }, { name: "disallowed embeddings arg", llamaArgs: "--embeddings", @@ -44,6 +54,11 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) { llamaArgs: "--mmproj test.mmproj", wantErr: true, }, + { + name: "disallowed mmproj arg with equals", + llamaArgs: "--mmproj=test.mmproj", + wantErr: true, + }, { name: "multiple disallowed args", llamaArgs: "--model test.gguf --host localhost:8080", @@ -78,14 +93,18 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) { config := createLlamaCppConfigFromEnv() + // With the new error handling, we don't exit on error, just log it + // So we expect exitCode to always be 0 (no fatal exit) + if exitCode != 0 { + t.Errorf("Expected exit code 0, got %d", exitCode) + } + if tt.wantErr { - if exitCode != 1 { - t.Errorf("Expected exit code 1, got %d", exitCode) + // For error cases, we now return nil config instead of exiting + if config != nil { + t.Error("Expected nil config for error cases") } } else { - if exitCode != 0 { - t.Errorf("Expected exit code 0, got %d", exitCode) - } if tt.llamaArgs == "" { if config != nil { t.Error("Expected nil config for empty args") diff --git a/pkg/diskusage/diskusage.go b/pkg/diskusage/diskusage.go index e2ba3ff1a..09b018c8b 100644 --- a/pkg/diskusage/diskusage.go +++ b/pkg/diskusage/diskusage.go @@ -5,6 +5,7 @@ import ( "path/filepath" ) +// Size calculates the total size of files in the given directory path. 
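+// It walks the directory tree rooted at path and sums the sizes of the files it visits.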
func Size(path string) (int64, error) { var size int64 err := filepath.WalkDir(path, func(_ string, d fs.DirEntry, err error) error { diff --git a/pkg/distribution/huggingface/client.go b/pkg/distribution/huggingface/client.go index 9dc5a64e9..03da1222d 100644 --- a/pkg/distribution/huggingface/client.go +++ b/pkg/distribution/huggingface/client.go @@ -81,7 +81,8 @@ func (c *Client) ListFiles(ctx context.Context, repo, revision string) ([]RepoFi } // HuggingFace API endpoint for listing files - url := fmt.Sprintf("%s/api/models/%s/tree/%s", c.baseURL, repo, revision) + // Use recursive=true to get files from subdirectories (needed for diffusers models) + url := fmt.Sprintf("%s/api/models/%s/tree/%s?recursive=true", c.baseURL, repo, revision) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { @@ -94,7 +95,9 @@ func (c *Client) ListFiles(ctx context.Context, repo, revision string) ([]RepoFi if err != nil { return nil, fmt.Errorf("list files: %w", err) } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() if err := c.checkResponse(resp, repo); err != nil { return nil, err @@ -131,7 +134,7 @@ func (c *Client) DownloadFile(ctx context.Context, repo, revision, filename stri } if err := c.checkResponse(resp, repo); err != nil { - resp.Body.Close() + _ = resp.Body.Close() return nil, 0, err } diff --git a/pkg/distribution/huggingface/downloader.go b/pkg/distribution/huggingface/downloader.go index 0bb537319..632c8e837 100644 --- a/pkg/distribution/huggingface/downloader.go +++ b/pkg/distribution/huggingface/downloader.go @@ -155,7 +155,9 @@ func (d *Downloader) downloadFileWithProgress(ctx context.Context, file RepoFile if err != nil { return "", err } - defer reader.Close() + defer func() { + _ = reader.Close() + }() // Create local file f, err := os.Create(localPath) diff --git a/pkg/distribution/huggingface/model.go b/pkg/distribution/huggingface/model.go index fc3c07b27..b28841e94 100644 --- a/pkg/distribution/huggingface/model.go +++ b/pkg/distribution/huggingface/model.go @@ -1,10 +1,12 @@ package huggingface import ( + "archive/tar" "context" "fmt" "io" "log" + "os" "path/filepath" "sort" "strings" @@ -28,15 +30,25 @@ func BuildModel(ctx context.Context, client *Client, repo, revision string, temp return nil, fmt.Errorf("list files: %w", err) } - // Step 2: Filter to model files (safetensors + configs) - safetensorsFiles, configFiles := FilterModelFiles(files) + // Step 2: Detect model type and filter files accordingly + modelType := DetectModelType(files) - if len(safetensorsFiles) == 0 { - return nil, fmt.Errorf("no safetensors files found in repository %s", repo) + var modelFiles, configFiles []RepoFile + if modelType == ModelTypeDiffusers { + modelFiles, configFiles = FilterDiffusersFiles(files) + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "Detected diffusers model", 0, 0, 0, "") + } + } else { + modelFiles, configFiles = FilterModelFiles(files) + } + + if len(modelFiles) == 0 { + return nil, fmt.Errorf("no model files found in repository %s", repo) } // Combine all files to download - allFiles := append(safetensorsFiles, configFiles...) + allFiles := append(modelFiles, configFiles...) 
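+	// allFiles is only used for progress totals and downloading; modelFiles and
+	// configFiles are passed separately to the artifact builders below.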
 	if progressWriter != nil {
 		totalSize := TotalSize(allFiles)
@@ -57,7 +69,12 @@
 		_ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "")
 	}
 
-	model, err := buildModelFromFiles(result.LocalPaths, safetensorsFiles, configFiles, tempDir)
+	var model types.ModelArtifact
+	if modelType == ModelTypeDiffusers {
+		model, err = buildDiffusersModel(result.LocalPaths, modelFiles, configFiles, tempDir)
+	} else {
+		model, err = buildModelFromFiles(result.LocalPaths, modelFiles, configFiles, tempDir)
+	}
 	if err != nil {
 		return nil, fmt.Errorf("build model: %w", err)
 	}
@@ -119,6 +136,128 @@ func buildModelFromFiles(localPaths map[string]string, safetensorsFiles, configF
 	return b.Model(), nil
 }
 
+// buildDiffusersModel constructs an OCI model artifact for a diffusers model
+func buildDiffusersModel(localPaths map[string]string, modelFiles, configFiles []RepoFile, tempDir string) (types.ModelArtifact, error) {
+	// For diffusers models, we create a tar archive preserving the directory structure
+	allFiles := append(modelFiles, configFiles...)
+
+	// Create a tar archive of all files
+	archivePath, err := createDiffusersArchive(localPaths, allFiles, tempDir)
+	if err != nil {
+		return nil, fmt.Errorf("create diffusers archive: %w", err)
+	}
+
+	// A weight file is still needed to seed the base model: prefer
+	// safetensors, falling back to .bin files if none are present.
+	var safetensorsPaths, binPaths []string
+	for _, f := range modelFiles {
+		lowerFilename := strings.ToLower(f.Filename())
+		localPath, ok := localPaths[f.Path]
+		if !ok {
+			continue
+		}
+		if strings.HasSuffix(lowerFilename, ".safetensors") {
+			safetensorsPaths = append(safetensorsPaths, localPath)
+		} else if strings.HasSuffix(lowerFilename, ".bin") {
+			binPaths = append(binPaths, localPath)
+		}
+	}
+
+	if len(safetensorsPaths) == 0 {
+		safetensorsPaths = binPaths
+	}
+
+	if len(safetensorsPaths) == 0 {
+		return nil, fmt.Errorf("no model weight files found")
+	}
+
+	sort.Strings(safetensorsPaths)
+
+	// Create builder from the first weight file to establish base
+	// Note: We use the first file to establish the base, but the full archive will contain all files
+	// The builder.FromSafetensors function handles both .safetensors and .bin files
+	b, err := builder.FromSafetensors([]string{safetensorsPaths[0]})
+	if err != nil {
+		return nil, fmt.Errorf("create builder: %w", err)
+	}
+
+	// Add the diffusers archive as a directory tar layer
+	// This will be extracted preserving the full directory structure
+	b, err = b.WithDirTar(archivePath)
+	if err != nil {
+		return nil, fmt.Errorf("add diffusers archive: %w", err)
+	}
+
+	return b.Model(), nil
+}
+
+// createDiffusersArchive creates a tar archive of diffusers model files preserving directory structure
+func createDiffusersArchive(localPaths map[string]string, files []RepoFile, tempDir string) (string, error) {
+	archivePath := filepath.Join(tempDir, "diffusers-model.tar")
+
+	f, err := os.Create(archivePath)
+	if err != nil {
+		return "", fmt.Errorf("create archive file: %w", err)
+	}
+	defer func() {
+		_ = f.Close()
+	}()
+
+	tw := tar.NewWriter(f)
+	defer func() {
+		_ = tw.Close()
+	}()
+
+	for _, file := range files {
+		localPath, ok := localPaths[file.Path]
+		if !ok {
+			log.Printf("Warning: skipping file %s (not downloaded)", file.Path)
+			continue
+		}
+
+		// Add file to archive with its original path (preserving directory structure)
+		if err := addFileToTar(tw, localPath, file.Path); err != nil {
+			return "", fmt.Errorf("add file %s to
archive: %w", file.Path, err) + } + } + + return archivePath, nil +} + +// addFileToTar adds a file to a tar archive with the specified archive path +func addFileToTar(tw *tar.Writer, sourcePath, archivePath string) error { + // Get file info + info, err := os.Stat(sourcePath) + if err != nil { + return fmt.Errorf("stat file: %w", err) + } + + // Create tar header + header, err := tar.FileInfoHeader(info, "") + if err != nil { + return fmt.Errorf("create tar header: %w", err) + } + + // Use the archive path (with forward slashes for tar) + header.Name = filepath.ToSlash(archivePath) + + // Write header + if err := tw.WriteHeader(header); err != nil { + return fmt.Errorf("write tar header: %w", err) + } + + // If it's a file (not directory), write contents + if !info.IsDir() { + file, err := os.Open(sourcePath) + if err != nil { + return fmt.Errorf("open file: %w", err) + } + defer file.Close() + + if _, err := io.Copy(tw, file); err != nil { + return fmt.Errorf("copy file contents: %w", err) + } + } + + return nil +} + // createConfigArchive creates a tar archive of config files in the specified tempDir func createConfigArchive(localPaths map[string]string, configFiles []RepoFile, tempDir string) (string, error) { // Collect config file paths (excluding chat templates which are added separately) diff --git a/pkg/distribution/huggingface/repository.go b/pkg/distribution/huggingface/repository.go index 79a6890e7..a3fd8cede 100644 --- a/pkg/distribution/huggingface/repository.go +++ b/pkg/distribution/huggingface/repository.go @@ -36,6 +36,16 @@ func (f *RepoFile) Filename() string { return path.Base(f.Path) } +// ModelType represents the type of model (LLM vs diffusers) +type ModelType int + +const ( + // ModelTypeLLM is a standard LLM model with safetensors at root + ModelTypeLLM ModelType = iota + // ModelTypeDiffusers is a diffusers model with model_index.json + ModelTypeDiffusers +) + // fileType represents the type of file for model packaging type fileType int @@ -46,12 +56,19 @@ const ( fileTypeSafetensors // fileTypeConfig is a configuration file (json, txt, etc.) fileTypeConfig + // fileTypeDiffusersIndex is the model_index.json file for diffusers models + fileTypeDiffusersIndex ) // classifyFile determines the file type based on filename func classifyFile(filename string) fileType { lower := strings.ToLower(filename) + // Check for diffusers model_index.json + if lower == "model_index.json" { + return fileTypeDiffusersIndex + } + // Check for safetensors files if strings.HasSuffix(lower, ".safetensors") { return fileTypeSafetensors @@ -87,6 +104,8 @@ func FilterModelFiles(files []RepoFile) (safetensors []RepoFile, configs []RepoF safetensors = append(safetensors, f) case fileTypeConfig: configs = append(configs, f) + case fileTypeDiffusersIndex: + // Skip diffusers index files here since they're handled separately case fileTypeUnknown: // Skip unknown file types } @@ -112,3 +131,50 @@ func isSafetensorsModel(files []RepoFile) bool { } return false } + +// DetectModelType determines if the repository contains an LLM or diffusers model +func DetectModelType(files []RepoFile) ModelType { + for _, f := range files { + if f.Type == "file" && f.Path == "model_index.json" { + return ModelTypeDiffusers + } + } + return ModelTypeLLM +} + +// IsDiffusersModel checks if the repository is a diffusers model +func IsDiffusersModel(files []RepoFile) bool { + return DetectModelType(files) == ModelTypeDiffusers +} + +// FilterDiffusersFiles filters repository files for a diffusers model. 
+// For diffusers models, we need to download:
+// - model_index.json
+// - All *.safetensors and *.bin files (including in subdirectories)
+// - All config.json files (in root and subdirectories)
+// - scheduler_config.json, preprocessor_config.json, etc.
+func FilterDiffusersFiles(files []RepoFile) (modelFiles []RepoFile, configFiles []RepoFile) {
+	for _, f := range files {
+		if f.Type != "file" {
+			continue
+		}
+
+		lower := strings.ToLower(f.Filename())
+
+		// Include model weight files
+		if strings.HasSuffix(lower, ".safetensors") || strings.HasSuffix(lower, ".bin") {
+			modelFiles = append(modelFiles, f)
+			continue
+		}
+
+		// Include config files
+		if strings.HasSuffix(lower, ".json") ||
+			strings.HasSuffix(lower, ".txt") ||
+			strings.HasSuffix(lower, ".yaml") ||
+			strings.HasSuffix(lower, ".yml") {
+			configFiles = append(configFiles, f)
+			continue
+		}
+	}
+	return modelFiles, configFiles
+}
diff --git a/pkg/distribution/internal/partial/layer.go b/pkg/distribution/internal/partial/layer.go
index 0884bc19d..cd2ad7913 100644
--- a/pkg/distribution/internal/partial/layer.go
+++ b/pkg/distribution/internal/partial/layer.go
@@ -13,11 +13,13 @@ import (
 
 var _ v1.Layer = &Layer{}
 
+// Layer represents a layer in a model distribution.
 type Layer struct {
 	Path string
 	v1.Descriptor
 }
 
+// NewLayer creates a new layer from a file path and media type.
 func NewLayer(path string, mt ggcrtypes.MediaType) (*Layer, error) {
 	f, err := os.Open(path)
 	if err != nil {
@@ -70,26 +72,32 @@ func NewLayer(path string, mt ggcrtypes.MediaType) (*Layer, error) {
 	}, err
 }
 
+// Digest returns the layer's digest.
 func (l Layer) Digest() (v1.Hash, error) {
 	return l.DiffID()
 }
 
+// DiffID returns the layer's diff ID.
 func (l Layer) DiffID() (v1.Hash, error) {
 	return l.Descriptor.Digest, nil
 }
 
+// Compressed returns a reader for the layer contents; model layers are stored uncompressed.
 func (l Layer) Compressed() (io.ReadCloser, error) {
 	return l.Uncompressed()
 }
 
+// Uncompressed returns a reader for the uncompressed layer contents.
 func (l Layer) Uncompressed() (io.ReadCloser, error) {
 	return os.Open(l.Path)
 }
 
+// Size returns the size of the layer.
 func (l Layer) Size() (int64, error) {
 	return l.Descriptor.Size, nil
 }
 
+// MediaType returns the media type of the layer.
 func (l Layer) MediaType() (ggcrtypes.MediaType, error) {
 	return l.Descriptor.MediaType, nil
 }
diff --git a/pkg/distribution/internal/partial/model.go b/pkg/distribution/internal/partial/model.go
index 38610a04a..457b47bee 100644
--- a/pkg/distribution/internal/partial/model.go
+++ b/pkg/distribution/internal/partial/model.go
@@ -19,30 +19,37 @@ type BaseModel struct {
 
 var _ types.ModelArtifact = &BaseModel{}
 
+// Layers returns the layers of the model.
 func (m *BaseModel) Layers() ([]v1.Layer, error) {
 	return m.LayerList, nil
 }
 
+// Size returns the total size of the model.
 func (m *BaseModel) Size() (int64, error) {
 	return partial.Size(m)
 }
 
+// ConfigName returns the hash of the model's config file.
 func (m *BaseModel) ConfigName() (v1.Hash, error) {
 	return partial.ConfigName(m)
 }
 
+// ConfigFile is invalid for model artifacts and always returns an error; use Config instead.
 func (m *BaseModel) ConfigFile() (*v1.ConfigFile, error) {
 	return nil, fmt.Errorf("invalid for model")
 }
 
+// Digest returns the digest of the model.
 func (m *BaseModel) Digest() (v1.Hash, error) {
 	return partial.Digest(m)
 }
 
+// Manifest returns the manifest of the model.
 func (m *BaseModel) Manifest() (*v1.Manifest, error) {
 	return ManifestForLayers(m)
 }
 
+// LayerByDigest returns the layer with the given digest.
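+// It returns an error if no layer with the given digest exists.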
func (m *BaseModel) LayerByDigest(hash v1.Hash) (v1.Layer, error) { for _, l := range m.LayerList { d, err := l.Digest() @@ -56,6 +63,7 @@ func (m *BaseModel) LayerByDigest(hash v1.Hash) (v1.Layer, error) { return nil, fmt.Errorf("layer not found") } +// LayerByDiffID returns the layer with the given diff ID. func (m *BaseModel) LayerByDiffID(hash v1.Hash) (v1.Layer, error) { for _, l := range m.LayerList { d, err := l.DiffID() @@ -69,14 +77,17 @@ func (m *BaseModel) LayerByDiffID(hash v1.Hash) (v1.Layer, error) { return nil, fmt.Errorf("layer not found") } +// RawManifest returns the raw manifest of the model. func (m *BaseModel) RawManifest() ([]byte, error) { return partial.RawManifest(m) } +// RawConfigFile returns the raw config file of the model. func (m *BaseModel) RawConfigFile() ([]byte, error) { return json.Marshal(m.ModelConfigFile) } +// MediaType returns the media type of the model. func (m *BaseModel) MediaType() (ggcr.MediaType, error) { manifest, err := m.Manifest() if err != nil { @@ -85,14 +96,17 @@ func (m *BaseModel) MediaType() (ggcr.MediaType, error) { return manifest.MediaType, nil } +// ID returns the ID of the model. func (m *BaseModel) ID() (string, error) { return ID(m) } +// Config returns the configuration of the model. func (m *BaseModel) Config() (types.ModelConfig, error) { return Config(m) } +// Descriptor returns the descriptor of the model. func (m *BaseModel) Descriptor() (types.Descriptor, error) { return Descriptor(m) } diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index 9ba9c232f..45e1e298f 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -11,6 +11,7 @@ import ( ggcr "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" ) +// WithRawConfigFile is an interface for getting raw config file data. type WithRawConfigFile interface { // RawConfigFile returns the serialized bytes of this model's config file. RawConfigFile() ([]byte, error) @@ -70,6 +71,7 @@ type WithRawManifest interface { RawManifest() ([]byte, error) } +// ID returns the ID of the model. func ID(i WithRawManifest) (string, error) { digest, err := partial.Digest(i) if err != nil { @@ -78,15 +80,18 @@ func ID(i WithRawManifest) (string, error) { return digest.String(), nil } +// WithLayers is an interface for accessing model layers. type WithLayers interface { WithRawConfigFile Layers() ([]v1.Layer, error) } +// GGUFPaths returns the paths of GGUF layers in the model. func GGUFPaths(i WithLayers) ([]string, error) { return layerPathsByMediaType(i, types.MediaTypeGGUF) } +// MMPROJPath returns the path of the multimodal projector layer in the model. func MMPROJPath(i WithLayers) (string, error) { paths, err := layerPathsByMediaType(i, types.MediaTypeMultimodalProjector) if err != nil { @@ -102,6 +107,7 @@ func MMPROJPath(i WithLayers) (string, error) { return paths[0], err } +// ChatTemplatePath returns the path of the chat template layer in the model. func ChatTemplatePath(i WithLayers) (string, error) { paths, err := layerPathsByMediaType(i, types.MediaTypeChatTemplate) if err != nil { @@ -117,10 +123,12 @@ func ChatTemplatePath(i WithLayers) (string, error) { return paths[0], err } +// SafetensorsPaths returns the paths of safetensors layers in the model. 
func SafetensorsPaths(i WithLayers) ([]string, error) { return layerPathsByMediaType(i, types.MediaTypeSafetensors) } +// ConfigArchivePath returns the path of the VLLM config archive layer in the model. func ConfigArchivePath(i WithLayers) (string, error) { paths, err := layerPathsByMediaType(i, types.MediaTypeVLLMConfigArchive) if err != nil { @@ -184,6 +192,7 @@ func matchesMediaType(layerMT, targetMT ggcr.MediaType) bool { } } +// ManifestForLayers creates a manifest for the given layers. func ManifestForLayers(i WithLayers) (*v1.Manifest, error) { cfgLayer, err := partial.ConfigLayer(i) if err != nil { diff --git a/pkg/distribution/internal/progress/reporter.go b/pkg/distribution/internal/progress/reporter.go index db038aa08..7079cb5ff 100644 --- a/pkg/distribution/internal/progress/reporter.go +++ b/pkg/distribution/internal/progress/reporter.go @@ -16,6 +16,7 @@ const UpdateInterval = 100 * time.Millisecond // before sending a progress update const MinBytesForUpdate = 1024 * 1024 // 1MB +// Layer represents progress information for a single layer. type Layer struct { ID string // Layer ID Size uint64 // Layer size @@ -31,6 +32,7 @@ type Message struct { Layer Layer `json:"layer"` // Current layer information } +// Reporter tracks and reports progress for image operations. type Reporter struct { progress chan v1.Update done chan struct{} @@ -43,14 +45,17 @@ type Reporter struct { type progressF func(update v1.Update) string +// PullMsg formats a message for pull operations. func PullMsg(update v1.Update) string { return fmt.Sprintf("Downloaded: %.2f MB", float64(update.Complete)/1024/1024) } +// PushMsg formats a message for push operations. func PushMsg(update v1.Update) string { return fmt.Sprintf("Uploaded: %.2f MB", float64(update.Complete)/1024/1024) } +// NewProgressReporter creates a new progress reporter. func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer v1.Layer) *Reporter { return &Reporter{ out: w, diff --git a/pkg/distribution/internal/store/blobs.go b/pkg/distribution/internal/store/blobs.go index 7e85ec554..fda720a96 100644 --- a/pkg/distribution/internal/store/blobs.go +++ b/pkg/distribution/internal/store/blobs.go @@ -107,7 +107,9 @@ func (s *LocalStore) writeLayer(layer blob, updates chan<- v1.Update) (bool, v1. if err != nil { return false, v1.Hash{}, fmt.Errorf("get blob contents: %w", err) } - defer lr.Close() + defer func() { + _ = lr.Close() + }() // Wrap the reader with progress reporting, accounting for already downloaded bytes var r io.Reader @@ -155,7 +157,7 @@ func (s *LocalStore) WriteBlob(diffID v1.Hash, r io.Reader) error { } computedHash, _, err := v1.SHA256(existingFile) - existingFile.Close() + _ = existingFile.Close() if err == nil && computedHash.String() == diffID.String() { // File is already complete, just rename it @@ -190,7 +192,7 @@ func (s *LocalStore) WriteBlob(diffID v1.Hash, r io.Reader) error { return fmt.Errorf("copy blob %q to store: %w", diffID.String(), err) } - f.Close() // Rename will fail on Windows if the file is still open. + _ = f.Close() // Rename will fail on Windows if the file is still open. 
// For resumed downloads, verify the complete file's hash before finalizing // (For new downloads, the stream was already verified during download) @@ -199,7 +201,9 @@ func (s *LocalStore) WriteBlob(diffID v1.Hash, r io.Reader) error { if err != nil { return fmt.Errorf("open completed file for verification: %w", err) } - defer completeFile.Close() + defer func() { + _ = completeFile.Close() + }() computedHash, _, err := v1.SHA256(completeFile) if err != nil { diff --git a/pkg/distribution/internal/store/manifests.go b/pkg/distribution/internal/store/manifests.go index 178166fea..4c2274562 100644 --- a/pkg/distribution/internal/store/manifests.go +++ b/pkg/distribution/internal/store/manifests.go @@ -92,12 +92,12 @@ func writeFile(path string, data []byte) error { } if _, err := tmp.Write(data); err != nil { - tmp.Close() + _ = tmp.Close() cleanup() return fmt.Errorf("write temporary file %q: %w", tmpName, err) } if err := tmp.Sync(); err != nil { - tmp.Close() + _ = tmp.Close() cleanup() return fmt.Errorf("sync temporary file %q: %w", tmpName, err) } diff --git a/pkg/distribution/internal/utils/utils.go b/pkg/distribution/internal/utils/utils.go index 74025471e..e879c5a56 100644 --- a/pkg/distribution/internal/utils/utils.go +++ b/pkg/distribution/internal/utils/utils.go @@ -1,3 +1,4 @@ +// Package utils provides utility functions for the distribution package. package utils import ( diff --git a/pkg/distribution/packaging/dirtar.go b/pkg/distribution/packaging/dirtar.go index 82c5aa2cf..e5f89817b 100644 --- a/pkg/distribution/packaging/dirtar.go +++ b/pkg/distribution/packaging/dirtar.go @@ -34,7 +34,7 @@ func CreateDirectoryTarArchive(dirPath string) (string, error) { shouldKeepTempFile := false defer func() { if !shouldKeepTempFile { - os.Remove(tmpPath) + _ = os.Remove(tmpPath) } }() @@ -96,14 +96,14 @@ func CreateDirectoryTarArchive(dirPath string) (string, error) { }) if err != nil { - tw.Close() - tmpFile.Close() + _ = tw.Close() + _ = tmpFile.Close() return "", fmt.Errorf("walk directory: %w", err) } // Close tar writer if err := tw.Close(); err != nil { - tmpFile.Close() + _ = tmpFile.Close() return "", fmt.Errorf("close tar writer: %w", err) } @@ -147,7 +147,7 @@ func (p *DirTarProcessor) Process() ([]string, func(), error) { // Return cleanup function cleanup := func() { for _, tempFile := range p.tempFiles { - os.Remove(tempFile) + _ = os.Remove(tempFile) } } diff --git a/pkg/distribution/registry/client.go b/pkg/distribution/registry/client.go index 20b7c0a93..ca02e5fd5 100644 --- a/pkg/distribution/registry/client.go +++ b/pkg/distribution/registry/client.go @@ -168,8 +168,8 @@ func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) { } return fmt.Sprintf("%s://%s/v2/%s/blobs/%s", - ref.Context().Registry.Scheme(), - ref.Context().Registry.RegistryStr(), + ref.Context().Scheme(), + ref.Context().RegistryStr(), ref.Context().RepositoryStr(), digest.String()), nil } diff --git a/pkg/distribution/registry/client_test.go b/pkg/distribution/registry/client_test.go index b1c64ce6d..ad5de387c 100644 --- a/pkg/distribution/registry/client_test.go +++ b/pkg/distribution/registry/client_test.go @@ -35,8 +35,8 @@ func TestGetDefaultRegistryOptions_NoEnvVars(t *testing.T) { } // Verify it uses HTTPS (secure by default) - if ref.Context().Registry.Scheme() != "https" { - t.Errorf("Expected scheme to be 'https', got '%s'", ref.Context().Registry.Scheme()) + if ref.Context().Scheme() != "https" { + t.Errorf("Expected scheme to be 'https', got '%s'", ref.Context().Scheme()) 
} } @@ -64,8 +64,8 @@ func TestGetDefaultRegistryOptions_OnlyDefaultRegistry(t *testing.T) { } // Verify it's not insecure (should use https) - if ref.Context().Registry.Scheme() != "https" { - t.Errorf("Expected scheme to be 'https', got '%s'", ref.Context().Registry.Scheme()) + if ref.Context().Scheme() != "https" { + t.Errorf("Expected scheme to be 'https', got '%s'", ref.Context().Scheme()) } } @@ -89,8 +89,8 @@ func TestGetDefaultRegistryOptions_OnlyInsecureRegistry(t *testing.T) { } // Insecure registries should use http - if ref.Context().Registry.Scheme() != "http" { - t.Errorf("Expected scheme to be 'http', got '%s'", ref.Context().Registry.Scheme()) + if ref.Context().Scheme() != "http" { + t.Errorf("Expected scheme to be 'http', got '%s'", ref.Context().Scheme()) } } @@ -119,8 +119,8 @@ func TestGetDefaultRegistryOptions_BothEnvVars(t *testing.T) { } // Check insecure is applied (http scheme) - if ref.Context().Registry.Scheme() != "http" { - t.Errorf("Expected scheme to be 'http', got '%s'", ref.Context().Registry.Scheme()) + if ref.Context().Scheme() != "http" { + t.Errorf("Expected scheme to be 'http', got '%s'", ref.Context().Scheme()) } } diff --git a/pkg/distribution/tarball/file.go b/pkg/distribution/tarball/file.go index 4c66fc059..2c3d639ff 100644 --- a/pkg/distribution/tarball/file.go +++ b/pkg/distribution/tarball/file.go @@ -27,7 +27,9 @@ func (t *FileTarget) Write(ctx context.Context, mdl types.ModelArtifact, pw io.W if err != nil { return fmt.Errorf("create file for archive: %w", err) } - defer f.Close() + defer func() { + _ = f.Close() + }() target, err := NewTarget(f) if err != nil { return fmt.Errorf("create target: %w", err) diff --git a/pkg/distribution/tarball/reader.go b/pkg/distribution/tarball/reader.go index 91a83fb19..ee6015b0c 100644 --- a/pkg/distribution/tarball/reader.go +++ b/pkg/distribution/tarball/reader.go @@ -12,6 +12,7 @@ import ( v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" ) +// Reader reads a tarball containing model artifacts. type Reader struct { tr *tar.Reader rawManifest []byte @@ -19,15 +20,18 @@ type Reader struct { done bool } +// Blob represents a blob in a tarball. type Blob struct { diffID v1.Hash rc io.ReadCloser } +// DiffID returns the diff ID of the blob. func (b Blob) DiffID() (v1.Hash, error) { return b.diffID, nil } +// Uncompressed returns an uncompressed reader for the blob. func (b Blob) Uncompressed() (io.ReadCloser, error) { return b.rc, nil } @@ -81,6 +85,7 @@ func (r *Reader) Read(p []byte) (n int, err error) { return r.tr.Read(p) } +// Manifest returns the manifest from the tarball. func (r *Reader) Manifest() ([]byte, v1.Hash, error) { if !r.done { return nil, v1.Hash{}, errors.New("must read all blobs first before getting manifest") @@ -91,6 +96,7 @@ func (r *Reader) Manifest() ([]byte, v1.Hash, error) { return r.rawManifest, r.digest, nil } +// NewReader creates a new tarball reader. func NewReader(r io.Reader) *Reader { return &Reader{ tr: tar.NewReader(r), diff --git a/pkg/distribution/tarball/target.go b/pkg/distribution/tarball/target.go index dcfa0c7f0..f26076ef6 100644 --- a/pkg/distribution/tarball/target.go +++ b/pkg/distribution/tarball/target.go @@ -26,7 +26,7 @@ func NewTarget(w io.Writer) (*Target, error) { }, nil } -// Write writes the artifact in archive format to the configured io.Writer +// Write writes the artifact in archive format to the configured io.Writer. 
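+// The artifact is streamed into a single tar written to t.writer.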
func (t *Target) Write(ctx context.Context, mdl types.ModelArtifact, progressWriter io.Writer) error { tw := tar.NewWriter(t.writer) defer tw.Close() diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index ccd78847a..4b7367812 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -32,7 +32,9 @@ const ( // MediaTypeChatTemplate indicates a Jinja chat template MediaTypeChatTemplate = types.MediaType("application/vnd.docker.ai.chat.template.jinja") - FormatGGUF = Format("gguf") + // FormatGGUF represents the GGUF format. + FormatGGUF = Format("gguf") + // FormatSafetensors represents the Safetensors format. FormatSafetensors = Format("safetensors") // OCI Annotation keys for model layers @@ -50,6 +52,7 @@ const ( AnnotationMediaTypeUntested = "org.cncf.model.file.mediatype.untested" ) +// Format represents the format of a model. type Format string // ModelConfig provides a unified interface for accessing model configuration. @@ -65,6 +68,7 @@ type ModelConfig interface { GetQuantization() string } +// ConfigFile represents a model configuration file. type ConfigFile struct { Config Config `json:"config"` Descriptor Descriptor `json:"descriptor"` diff --git a/pkg/go-containerregistry/cmd/crane/cmd/serve.go b/pkg/go-containerregistry/cmd/crane/cmd/serve.go index 4c8fbaae8..0d59c4e60 100644 --- a/pkg/go-containerregistry/cmd/crane/cmd/serve.go +++ b/pkg/go-containerregistry/cmd/crane/cmd/serve.go @@ -23,9 +23,8 @@ import ( "os" "time" - "github.com/spf13/cobra" - "github.com/docker/model-runner/pkg/go-containerregistry/pkg/registry" + "github.com/spf13/cobra" ) func NewCmdRegistry() *cobra.Command { diff --git a/pkg/go-containerregistry/pkg/authn/kubernetes/keychain_test.go b/pkg/go-containerregistry/pkg/authn/kubernetes/keychain_test.go index b804bea58..35433272a 100644 --- a/pkg/go-containerregistry/pkg/authn/kubernetes/keychain_test.go +++ b/pkg/go-containerregistry/pkg/authn/kubernetes/keychain_test.go @@ -23,9 +23,9 @@ import ( "reflect" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/authn" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" + "github.com/google/go-cmp/cmp" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" diff --git a/pkg/go-containerregistry/pkg/gcrane/copy_test.go b/pkg/go-containerregistry/pkg/gcrane/copy_test.go index d01c18193..beb137737 100644 --- a/pkg/go-containerregistry/pkg/gcrane/copy_test.go +++ b/pkg/go-containerregistry/pkg/gcrane/copy_test.go @@ -29,7 +29,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/internal/retry" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/logs" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" @@ -40,6 +39,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote/transport" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) type fakeXCR struct { diff --git a/pkg/go-containerregistry/pkg/v1/daemon/image.go b/pkg/go-containerregistry/pkg/v1/daemon/image.go index efef75b45..61c2ad07c 100644 --- a/pkg/go-containerregistry/pkg/v1/daemon/image.go +++ b/pkg/go-containerregistry/pkg/v1/daemon/image.go @@ -22,7 +22,6 @@ import ( "time" api "github.com/docker/docker/api/types/image" - 
"github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/tarball" diff --git a/pkg/go-containerregistry/pkg/v1/daemon/image_test.go b/pkg/go-containerregistry/pkg/v1/daemon/image_test.go index 52e425493..6fb057ee0 100644 --- a/pkg/go-containerregistry/pkg/v1/daemon/image_test.go +++ b/pkg/go-containerregistry/pkg/v1/daemon/image_test.go @@ -26,12 +26,11 @@ import ( api "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/storage" "github.com/docker/docker/client" - specs "github.com/moby/docker-image-spec/specs-go/v1" - "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/compare" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/tarball" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + specs "github.com/moby/docker-image-spec/specs-go/v1" ) var imagePath = "../tarball/testdata/test_image_1.tar" diff --git a/pkg/go-containerregistry/pkg/v1/daemon/write_test.go b/pkg/go-containerregistry/pkg/v1/daemon/write_test.go index 26fb79dd4..b91c3c254 100644 --- a/pkg/go-containerregistry/pkg/v1/daemon/write_test.go +++ b/pkg/go-containerregistry/pkg/v1/daemon/write_test.go @@ -24,7 +24,6 @@ import ( api "github.com/docker/docker/api/types/image" "github.com/docker/docker/client" - "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/empty" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/tarball" diff --git a/pkg/go-containerregistry/pkg/v1/google/list_test.go b/pkg/go-containerregistry/pkg/v1/google/list_test.go index 4a946f126..045483f20 100644 --- a/pkg/go-containerregistry/pkg/v1/google/list_test.go +++ b/pkg/go-containerregistry/pkg/v1/google/list_test.go @@ -26,10 +26,10 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/authn" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/logs" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" + "github.com/google/go-cmp/cmp" ) func mustParseDuration(t *testing.T, d string) time.Duration { diff --git a/pkg/go-containerregistry/pkg/v1/layout/write_test.go b/pkg/go-containerregistry/pkg/v1/layout/write_test.go index f54e9cf58..8b51a8edb 100644 --- a/pkg/go-containerregistry/pkg/v1/layout/write_test.go +++ b/pkg/go-containerregistry/pkg/v1/layout/write_test.go @@ -23,7 +23,6 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/empty" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/match" @@ -32,6 +31,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/stream" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + "github.com/google/go-cmp/cmp" ) func TestWrite(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/mutate/index_test.go b/pkg/go-containerregistry/pkg/v1/mutate/index_test.go index 8cbd81d05..68bc081a5 100644 --- a/pkg/go-containerregistry/pkg/v1/mutate/index_test.go +++ b/pkg/go-containerregistry/pkg/v1/mutate/index_test.go @@ -19,7 +19,6 @@ import ( "strings" "testing" - 
"github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/empty" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/mutate" @@ -27,6 +26,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/random" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + "github.com/google/go-cmp/cmp" ) func TestAppendIndex(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/mutate/mutate_test.go b/pkg/go-containerregistry/pkg/v1/mutate/mutate_test.go index ad033d915..168cca02a 100644 --- a/pkg/go-containerregistry/pkg/v1/mutate/mutate_test.go +++ b/pkg/go-containerregistry/pkg/v1/mutate/mutate_test.go @@ -26,8 +26,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/empty" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/match" @@ -38,6 +36,8 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/tarball" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) func TestExtractWhiteout(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/partial/compressed_test.go b/pkg/go-containerregistry/pkg/v1/partial/compressed_test.go index 074fe8b0e..005ac899d 100644 --- a/pkg/go-containerregistry/pkg/v1/partial/compressed_test.go +++ b/pkg/go-containerregistry/pkg/v1/partial/compressed_test.go @@ -21,7 +21,6 @@ import ( "path" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/registry" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" @@ -31,6 +30,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + "github.com/google/go-cmp/cmp" ) // Remote leverages a lot of compressed partials. 
diff --git a/pkg/go-containerregistry/pkg/v1/partial/with_test.go b/pkg/go-containerregistry/pkg/v1/partial/with_test.go index 779f5b31e..ccc74a47d 100644 --- a/pkg/go-containerregistry/pkg/v1/partial/with_test.go +++ b/pkg/go-containerregistry/pkg/v1/partial/with_test.go @@ -17,11 +17,11 @@ package partial_test import ( "testing" - "github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/partial" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/random" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) func TestRawConfigFile(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/platform_test.go b/pkg/go-containerregistry/pkg/v1/platform_test.go index 34e431438..e44b5e577 100644 --- a/pkg/go-containerregistry/pkg/v1/platform_test.go +++ b/pkg/go-containerregistry/pkg/v1/platform_test.go @@ -17,8 +17,8 @@ package v1_test import ( "testing" - "github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" + "github.com/google/go-cmp/cmp" ) func TestPlatformString(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/remote/catalog_test.go b/pkg/go-containerregistry/pkg/v1/remote/catalog_test.go index 20c072a98..c6836436b 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/catalog_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/catalog_test.go @@ -23,8 +23,8 @@ import ( "net/url" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" + "github.com/google/go-cmp/cmp" ) func TestCatalogPage(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/remote/descriptor_test.go b/pkg/go-containerregistry/pkg/v1/remote/descriptor_test.go index 8a485794b..7276ac522 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/descriptor_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/descriptor_test.go @@ -25,9 +25,9 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) var fakeDigest = "sha256:0000000000000000000000000000000000000000000000000000000000000000" diff --git a/pkg/go-containerregistry/pkg/v1/remote/image_test.go b/pkg/go-containerregistry/pkg/v1/remote/image_test.go index d7c01c5c5..ed555aa34 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/image_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/image_test.go @@ -27,7 +27,6 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/authn" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/logs" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" @@ -38,6 +37,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/random" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + "github.com/google/go-cmp/cmp" ) const bogusDigest = "sha256:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" diff --git a/pkg/go-containerregistry/pkg/v1/remote/index_test.go b/pkg/go-containerregistry/pkg/v1/remote/index_test.go index b123569ed..a551e7733 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/index_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/index_test.go @@ 
-23,10 +23,10 @@ import ( "net/url" "testing" - "github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/random" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) func randomIndex(t *testing.T) v1.ImageIndex { diff --git a/pkg/go-containerregistry/pkg/v1/remote/list_test.go b/pkg/go-containerregistry/pkg/v1/remote/list_test.go index 479361d77..5356cfa4e 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/list_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/list_test.go @@ -23,8 +23,8 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" + "github.com/google/go-cmp/cmp" ) func TestList(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/remote/progress_test.go b/pkg/go-containerregistry/pkg/v1/remote/progress_test.go index 8da3ef382..23fcb4e67 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/progress_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/progress_test.go @@ -23,7 +23,6 @@ import ( "sync" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/registry" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" @@ -31,6 +30,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/mutate" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/random" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) func TestWriteLayer_Progress(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/remote/referrers_test.go b/pkg/go-containerregistry/pkg/v1/remote/referrers_test.go index 5a6a313b0..4ad632aa2 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/referrers_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/referrers_test.go @@ -20,7 +20,6 @@ import ( "net/url" "testing" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/registry" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" @@ -28,6 +27,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/random" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) func TestReferrers(t *testing.T) { diff --git a/pkg/go-containerregistry/pkg/v1/remote/transport/bearer.go b/pkg/go-containerregistry/pkg/v1/remote/transport/bearer.go index 8ef1bb185..991fb5094 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/transport/bearer.go +++ b/pkg/go-containerregistry/pkg/v1/remote/transport/bearer.go @@ -27,7 +27,6 @@ import ( "sync" authchallenge "github.com/docker/distribution/registry/client/auth/challenge" - "github.com/docker/model-runner/pkg/go-containerregistry/internal/redact" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/authn" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/logs" diff --git a/pkg/go-containerregistry/pkg/v1/remote/write_test.go b/pkg/go-containerregistry/pkg/v1/remote/write_test.go index cbb59c052..671283eb6 100644 --- a/pkg/go-containerregistry/pkg/v1/remote/write_test.go +++ b/pkg/go-containerregistry/pkg/v1/remote/write_test.go @@ -30,8 +30,6 @@ import ( 
"sync/atomic" "testing" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/name" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/registry" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" @@ -44,6 +42,8 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/tarball" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/validate" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) func mustNewTag(t *testing.T, s string) name.Tag { diff --git a/pkg/go-containerregistry/pkg/v1/validate/image.go b/pkg/go-containerregistry/pkg/v1/validate/image.go index 07ad6e77d..6a137353b 100644 --- a/pkg/go-containerregistry/pkg/v1/validate/image.go +++ b/pkg/go-containerregistry/pkg/v1/validate/image.go @@ -21,9 +21,9 @@ import ( "io" "strings" - "github.com/google/go-cmp/cmp" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/partial" + "github.com/google/go-cmp/cmp" ) // Image validates that img does not violate any invariants of the image format. diff --git a/pkg/go-containerregistry/pkg/v1/validate/index.go b/pkg/go-containerregistry/pkg/v1/validate/index.go index ba60b31f1..9a3e812ee 100644 --- a/pkg/go-containerregistry/pkg/v1/validate/index.go +++ b/pkg/go-containerregistry/pkg/v1/validate/index.go @@ -20,10 +20,10 @@ import ( "fmt" "strings" - "github.com/google/go-cmp/cmp" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/logs" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/google/go-cmp/cmp" ) // Index validates that idx does not violate any invariants of the index format. diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index c4c9dfb88..0f7a73e15 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -18,6 +18,9 @@ const ( // mode. BackendModeEmbedding BackendModeReranking + // BackendModeImageGeneration indicates that the backend should run in + // image generation mode. + BackendModeImageGeneration ) type ErrGGUFParse struct { @@ -37,6 +40,8 @@ func (m BackendMode) String() string { return "embedding" case BackendModeReranking: return "reranking" + case BackendModeImageGeneration: + return "image_generation" default: return "unknown" } @@ -72,6 +77,8 @@ func ParseBackendMode(mode string) (BackendMode, bool) { return BackendModeEmbedding, true case "reranking": return BackendModeReranking, true + case "image_generation": + return BackendModeImageGeneration, true default: return BackendModeCompletion, false } @@ -101,6 +108,16 @@ type LlamaCppConfig struct { ReasoningBudget *int32 `json:"reasoning-budget,omitempty"` } +// DiffusersConfig contains diffusers-specific configuration options. +type DiffusersConfig struct { + // Device specifies the compute device (cpu, cuda, mps). + Device string `json:"device,omitempty"` + // Precision specifies the compute precision (fp16, bf16, fp32). + Precision string `json:"precision,omitempty"` + // EnableAttentionSlicing enables attention slicing for memory efficiency. 
+ EnableAttentionSlicing bool `json:"enable-attention-slicing,omitempty"` +} + type BackendConfiguration struct { // Shared configuration across all backends ContextSize *int32 `json:"context-size,omitempty"` @@ -108,8 +125,9 @@ type BackendConfiguration struct { Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"` // Backend-specific configuration - VLLM *VLLMConfig `json:"vllm,omitempty"` - LlamaCpp *LlamaCppConfig `json:"llamacpp,omitempty"` + VLLM *VLLMConfig `json:"vllm,omitempty"` + LlamaCpp *LlamaCppConfig `json:"llamacpp,omitempty"` + Diffusers *DiffusersConfig `json:"diffusers,omitempty"` } type RequiredMemory struct { diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go new file mode 100644 index 000000000..57dc15776 --- /dev/null +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -0,0 +1,178 @@ +package diffusers + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/docker/model-runner/pkg/diskusage" + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/inference/platform" + "github.com/docker/model-runner/pkg/logging" +) + +const ( + // Name is the backend name. + Name = "diffusers" + // diffusersDir is the default installation directory in Docker containers. + diffusersDir = "/opt/diffusers-env" +) + +// ErrorNotFound indicates that the diffusers Python environment was not found. +var ErrorNotFound = errors.New("diffusers Python environment not found") + +// diffusersBackend is the diffusers-based backend implementation for image generation. +type diffusersBackend struct { + // log is the associated logger. + log logging.Logger + // modelManager is the shared model manager. + modelManager *models.Manager + // serverLog is the logger to use for the diffusers server process. + serverLog logging.Logger + // config is the configuration for the diffusers backend. + config *Config + // status is the state in which the diffusers backend is in. + status string + // pythonPath is the path to the Python interpreter to use. + pythonPath string +} + +// New creates a new diffusers-based backend for image generation. +func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config) (inference.Backend, error) { + if conf == nil { + conf = NewDefaultConfig() + } + + return &diffusersBackend{ + log: log, + modelManager: modelManager, + serverLog: serverLog, + config: conf, + status: "not installed", + }, nil +} + +// Name implements inference.Backend.Name. +func (d *diffusersBackend) Name() string { + return Name +} + +// UsesExternalModelManagement implements inference.Backend.UsesExternalModelManagement. +func (d *diffusersBackend) UsesExternalModelManagement() bool { + return false +} + +// UsesTCP implements inference.Backend.UsesTCP. +func (d *diffusersBackend) UsesTCP() bool { + return false +} + +// Install implements inference.Backend.Install. 
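+	// (This trades some inference speed for lower peak memory usage.)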
+func (d *diffusersBackend) Install(_ context.Context, _ *http.Client) error {
+	if !platform.SupportsDiffusers() {
+		d.status = "not supported on this platform"
+		return errors.New("diffusers is not supported on this platform")
+	}
+
+	// Try the container installation path first
+	containerPython := filepath.Join(diffusersDir, "bin", "python3")
+	if _, err := os.Stat(containerPython); err == nil {
+		d.pythonPath = containerPython
+		d.status = "installed"
+		return nil
+	}
+
+	// Fall back to a system Python with diffusers installed
+	systemPython, err := d.findSystemPython()
+	if err != nil {
+		d.status = ErrorNotFound.Error()
+		return ErrorNotFound
+	}
+
+	d.pythonPath = systemPython
+	d.status = "installed"
+	return nil
+}
+
+// findSystemPython looks for a Python installation with diffusers available.
+func (d *diffusersBackend) findSystemPython() (string, error) {
+	pythonCandidates := []string{"python3", "python"}
+
+	// On macOS, also check common homebrew paths
+	if runtime.GOOS == "darwin" {
+		pythonCandidates = append([]string{
+			"/opt/homebrew/bin/python3",
+			"/usr/local/bin/python3",
+		}, pythonCandidates...)
+	}
+
+	for _, python := range pythonCandidates {
+		pythonPath, err := exec.LookPath(python)
+		if err != nil {
+			continue
+		}
+
+		// Only accept interpreters that can actually import diffusers.
+		if err := exec.Command(pythonPath, "-c", "import diffusers").Run(); err != nil {
+			continue
+		}
+
+		return pythonPath, nil
+	}
+
+	return "", ErrorNotFound
+}
+
+// Run implements inference.Backend.Run.
+func (d *diffusersBackend) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error {
+	if d.pythonPath == "" {
+		return ErrorNotFound
+	}
+
+	if mode != inference.BackendModeImageGeneration {
+		return fmt.Errorf("diffusers backend only supports image generation mode, got %s", mode.String())
+	}
+
+	bundle, err := d.modelManager.GetBundle(model)
+	if err != nil {
+		return fmt.Errorf("failed to get model: %w", err)
+	}
+
+	args, err := d.config.GetArgs(bundle, socket, mode, backendConfig)
+	if err != nil {
+		return fmt.Errorf("failed to get diffusers arguments: %w", err)
+	}
+
+	// Add model name arguments
+	args = append(args, "--served-model-name", model, modelRef)
+
+	return backends.RunBackend(ctx, backends.RunnerConfig{
+		BackendName:     "diffusers",
+		Socket:          socket,
+		BinaryPath:      d.pythonPath,
+		SandboxPath:     filepath.Dir(d.pythonPath),
+		SandboxConfig:   "",
+		Args:            args,
+		Logger:          d.log,
+		ServerLogWriter: d.serverLog.Writer(),
+	})
+}
+
+// Status implements inference.Backend.Status.
+func (d *diffusersBackend) Status() string {
+	return d.status
+}
+
+// GetDiskUsage implements inference.Backend.GetDiskUsage.
+func (d *diffusersBackend) GetDiskUsage() (int64, error) {
+	// Check if we're using the container installation
+	if _, err := os.Stat(diffusersDir); err == nil {
+		size, err := diskusage.Size(diffusersDir)
+		if err != nil {
+			return 0, fmt.Errorf("error while getting store size: %w", err)
+		}
+		return size, nil
+	}
+	// For system Python, report 0 since it's not managed by us
+	return 0, nil
+}
diff --git a/pkg/inference/backends/diffusers/diffusers_config.go b/pkg/inference/backends/diffusers/diffusers_config.go
new file mode 100644
index 000000000..084ee88d5
--- /dev/null
+++ b/pkg/inference/backends/diffusers/diffusers_config.go
@@ -0,0 +1,76 @@
+package diffusers
+
+import (
+	"fmt"
+
+	"github.com/docker/model-runner/pkg/distribution/types"
+	"github.com/docker/model-runner/pkg/inference"
+)
+
+// Config is the configuration for the diffusers backend.
+type Config struct {
+	// Args are the base arguments that are always included.
+	Args []string
+}
+
+// NewDefaultConfig creates a new Config with default values.
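+//
+// A minimal, illustrative sketch of wiring the backend up; the extra "-u"
+// interpreter flag here is hypothetical, not something the server requires:
+//
+//	conf := NewDefaultConfig()
+//	conf.Args = append(conf.Args, "-u") // run the Python interpreter unbuffered
+//	backend, err := New(log, modelManager, serverLog, conf)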
+func NewDefaultConfig() *Config { + return &Config{ + Args: []string{}, + } +} + +// GetArgs implements config.BackendConfig.GetArgs. +func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) { + if mode != inference.BackendModeImageGeneration { + return nil, fmt.Errorf("diffusers backend only supports image generation mode") + } + + // Start with base arguments + args := append([]string{}, c.Args...) + + // Python module entry point + args = append(args, "-m", "diffusers_server") + + // Get model path + modelPath, err := getModelPath(bundle) + if err != nil { + return nil, fmt.Errorf("failed to get model path: %w", err) + } + + // Add model path argument + args = append(args, "--model-path", modelPath) + + // Add socket argument + args = append(args, "--socket", socket) + + // Add runtime flags from backend config + if config != nil { + args = append(args, config.RuntimeFlags...) + } + + // Add diffusers-specific arguments from backend config + if config != nil && config.Diffusers != nil { + if config.Diffusers.Device != "" { + args = append(args, "--device", config.Diffusers.Device) + } + if config.Diffusers.Precision != "" { + args = append(args, "--precision", config.Diffusers.Precision) + } + if config.Diffusers.EnableAttentionSlicing { + args = append(args, "--enable-attention-slicing") + } + } + + return args, nil +} + +// getModelPath extracts the model path from the bundle. +func getModelPath(bundle types.ModelBundle) (string, error) { + rootDir := bundle.RootDir() + if rootDir != "" { + return rootDir, nil + } + + return "", fmt.Errorf("no model path found in bundle") +} diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index f0ed4106f..293fe5aba 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -65,6 +65,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference args = append(args, "--embeddings") case inference.BackendModeReranking: args = append(args, "--embeddings", "--reranking") + case inference.BackendModeImageGeneration: + return nil, fmt.Errorf("image generation mode not supported by llama.cpp backend") default: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go index 6a9e6b2f2..831e98a95 100644 --- a/pkg/inference/backends/mlx/mlx.go +++ b/pkg/inference/backends/mlx/mlx.go @@ -20,7 +20,7 @@ const ( Name = "mlx" ) -var ErrStatusNotFound = errors.New("Python or mlx-lm not found") +var ErrStatusNotFound = errors.New("python or mlx-lm not found") // mlx is the MLX-based backend implementation. 
type mlx struct { diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go index bc4f605c6..a80169adb 100644 --- a/pkg/inference/backends/mlx/mlx_config.go +++ b/pkg/inference/backends/mlx/mlx_config.go @@ -49,6 +49,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeReranking: // MLX may not support reranking mode return nil, fmt.Errorf("reranking mode not supported by MLX backend") + case inference.BackendModeImageGeneration: + return nil, fmt.Errorf("image generation mode not supported by MLX backend") default: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/runner.go b/pkg/inference/backends/runner.go index 244fa1c7b..1f040d287 100644 --- a/pkg/inference/backends/runner.go +++ b/pkg/inference/backends/runner.go @@ -89,13 +89,15 @@ func RunBackend(ctx context.Context, config RunnerConfig) error { if err != nil { return fmt.Errorf("unable to start %s: %w", config.BackendName, err) } - defer backendSandbox.Close() + defer func() { + _ = backendSandbox.Close() + }() // Handle backend process errors backendErrors := make(chan error, 1) go func() { backendErr := backendSandbox.Command().Wait() - config.ServerLogWriter.Close() + _ = config.ServerLogWriter.Close() errOutput := new(strings.Builder) if _, err := io.Copy(errOutput, tailBuf); err != nil { diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go index 4d220d96c..22f42f748 100644 --- a/pkg/inference/backends/sglang/sglang_config.go +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -50,6 +50,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeEmbedding: args = append(args, "--is-embedding") case inference.BackendModeReranking: + case inference.BackendModeImageGeneration: + return nil, fmt.Errorf("image generation mode not supported by SGLang backend") default: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index b172637f2..a0d63957a 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -48,6 +48,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference // vLLM doesn't have a specific embedding flag like llama.cpp // Embedding models are detected automatically case inference.BackendModeReranking: + case inference.BackendModeImageGeneration: + return nil, fmt.Errorf("image generation mode not supported by vLLM backend") default: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go index 49bffb75e..53ca283ce 100644 --- a/pkg/inference/platform/platform.go +++ b/pkg/inference/platform/platform.go @@ -1,3 +1,4 @@ +// Package platform provides platform-specific checks for backend support. package platform import "runtime" @@ -17,3 +18,9 @@ func SupportsMLX() bool { func SupportsSGLang() bool { return runtime.GOOS == "linux" } + +// SupportsDiffusers returns true if diffusers is supported on the current platform. +// Diffusers is supported on Linux (Docker/GPU) and macOS (MPS/CPU). 
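+//
+// Note that this is only a platform gate: Install must still locate a usable
+// Python environment with diffusers importable before the backend can serve.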
+func SupportsDiffusers() bool {
+	return runtime.GOOS == "linux" || runtime.GOOS == "darwin"
+}
diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go
index f9460dc21..798d71c19 100644
--- a/pkg/inference/scheduling/api.go
+++ b/pkg/inference/scheduling/api.go
@@ -41,6 +41,8 @@ func backendModeForRequest(path string) (inference.BackendMode, bool) {
 	} else if strings.HasSuffix(path, "/v1/messages") || strings.HasSuffix(path, "/v1/messages/count_tokens") {
 		// Anthropic Messages API - treated as completion mode
 		return inference.BackendModeCompletion, true
+	} else if strings.HasSuffix(path, "/v1/images/generations") {
+		return inference.BackendModeImageGeneration, true
 	}
 	return inference.BackendMode(0), false
 }
diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go
index 82e57b4d9..211ea4c34 100644
--- a/pkg/inference/scheduling/http_handler.go
+++ b/pkg/inference/scheduling/http_handler.go
@@ -76,8 +76,14 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
 		"POST " + inference.InferencePrefix + "/v1/messages/count_tokens",
 	}
 
+	// OpenAI Images API routes
+	imageRoutes := []string{
+		"POST " + inference.InferencePrefix + "/{backend}/v1/images/generations",
+		"POST " + inference.InferencePrefix + "/v1/images/generations",
+	}
+
 	m := make(map[string]http.HandlerFunc)
-	for _, route := range append(openAIRoutes, anthropicRoutes...) {
+	for _, route := range append(append(openAIRoutes, anthropicRoutes...), imageRoutes...) {
 		m[route] = h.handleOpenAIInference
 	}
diff --git a/pkg/inference/scheduling/images_api.go b/pkg/inference/scheduling/images_api.go
new file mode 100644
index 000000000..5f87f0c1a
--- /dev/null
+++ b/pkg/inference/scheduling/images_api.go
@@ -0,0 +1,38 @@
+package scheduling
+
+// ImageGenerationRequest represents an OpenAI-compatible image generation request.
+// See https://platform.openai.com/docs/api-reference/images/create
+type ImageGenerationRequest struct {
+	// Model is the model to use for image generation.
+	Model string `json:"model"`
+	// Prompt is the text description of the desired image(s).
+	Prompt string `json:"prompt"`
+	// N is the number of images to generate. Defaults to 1.
+	N int `json:"n,omitempty"`
+	// Size is the size of the generated images. Defaults to "1024x1024".
+	Size string `json:"size,omitempty"`
+	// Quality is the quality of the image. "standard" or "hd". Defaults to "standard".
+	Quality string `json:"quality,omitempty"`
+	// ResponseFormat is the format of the generated images: "url" or "b64_json". The
+	// OpenAI API defaults to "url", but the diffusers server only supports "b64_json".
+	ResponseFormat string `json:"response_format,omitempty"`
+	// Style is the style of the generated images. "vivid" or "natural". Defaults to "vivid".
+	Style string `json:"style,omitempty"`
+}
+
+// ImageGenerationResponse represents an OpenAI-compatible image generation response.
+type ImageGenerationResponse struct {
+	// Created is the Unix timestamp of when the images were created.
+	Created int64 `json:"created"`
+	// Data is the list of generated images.
+	Data []ImageData `json:"data"`
+}
+
+// ImageData represents a single generated image.
+type ImageData struct {
+	// URL is the URL of the generated image. Present when response_format is "url".
+	URL string `json:"url,omitempty"`
+	// B64JSON is the base64-encoded image data (a PNG from the diffusers server).
+	// Present when response_format is "b64_json".
+	B64JSON string `json:"b64_json,omitempty"`
+	// RevisedPrompt is the prompt that was used if it was revised.
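+	// (The bundled diffusers server does not revise prompts, so it leaves this empty.)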
+ RevisedPrompt string `json:"revised_prompt,omitempty"` +} diff --git a/pkg/internal/jsonutil/jsonutil.go b/pkg/internal/jsonutil/jsonutil.go index c9a2d0fcf..770609476 100644 --- a/pkg/internal/jsonutil/jsonutil.go +++ b/pkg/internal/jsonutil/jsonutil.go @@ -12,7 +12,9 @@ func ReadFile[T any](path string, result T) error { if err != nil { return err } - defer f.Close() + defer func() { + _ = f.Close() + }() dec := json.NewDecoder(f) if err := dec.Decode(&result); err != nil { return fmt.Errorf("parsing JSON: %w", err) diff --git a/pkg/internal/utils/log.go b/pkg/internal/utils/log.go index 740baf901..062485c4a 100644 --- a/pkg/internal/utils/log.go +++ b/pkg/internal/utils/log.go @@ -1,3 +1,4 @@ +// Package utils provides utility functions for internal use. package utils import ( diff --git a/pkg/responses/streaming.go b/pkg/responses/streaming.go index bd06b8e11..d31554983 100644 --- a/pkg/responses/streaming.go +++ b/pkg/responses/streaming.go @@ -486,8 +486,8 @@ func (s *StreamingResponseWriter) sendEvent(eventType string, event *StreamEvent return } - fmt.Fprintf(s.w, "event: %s\n", eventType) - fmt.Fprintf(s.w, "data: %s\n\n", data) + _, _ = fmt.Fprintf(s.w, "event: %s\n", eventType) + _, _ = fmt.Fprintf(s.w, "data: %s\n\n", data) if s.flusher != nil { s.flusher.Flush() diff --git a/python/diffusers_server/__init__.py b/python/diffusers_server/__init__.py new file mode 100644 index 000000000..4f9c08ccb --- /dev/null +++ b/python/diffusers_server/__init__.py @@ -0,0 +1,7 @@ +"""Diffusers server for Docker Model Runner. + +This package provides an OpenAI-compatible image generation API +backed by the Hugging Face diffusers library. +""" + +__version__ = "0.1.0" diff --git a/python/diffusers_server/__main__.py b/python/diffusers_server/__main__.py new file mode 100644 index 000000000..50b8521c5 --- /dev/null +++ b/python/diffusers_server/__main__.py @@ -0,0 +1,71 @@ +"""Entry point for running the diffusers server as a module. 
+ +Usage: + python -m diffusers_server --model-path /path/to/model --socket /path/to/socket +""" + +import argparse +import sys + +from .server import run_server + + +def main(): + parser = argparse.ArgumentParser( + description="Diffusers server for Docker Model Runner" + ) + parser.add_argument( + "--model-path", + required=True, + help="Path to the diffusers model directory", + ) + parser.add_argument( + "--socket", + required=True, + help="Unix socket path to listen on", + ) + parser.add_argument( + "--device", + default="auto", + choices=["auto", "cpu", "cuda", "mps"], + help="Device to run inference on (default: auto)", + ) + parser.add_argument( + "--precision", + default="auto", + choices=["auto", "fp16", "bf16", "fp32"], + help="Precision for inference (default: auto)", + ) + parser.add_argument( + "--enable-attention-slicing", + action="store_true", + help="Enable attention slicing for memory efficiency", + ) + parser.add_argument( + "--served-model-name", + nargs="*", + default=[], + help="Model names to serve (for OpenAI API compatibility)", + ) + + args = parser.parse_args() + + try: + run_server( + model_path=args.model_path, + socket_path=args.socket, + device=args.device, + precision=args.precision, + enable_attention_slicing=args.enable_attention_slicing, + served_model_names=args.served_model_name, + ) + except KeyboardInterrupt: + print("\nShutting down...") + sys.exit(0) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/python/diffusers_server/pipeline.py b/python/diffusers_server/pipeline.py new file mode 100644 index 000000000..0ab543c5c --- /dev/null +++ b/python/diffusers_server/pipeline.py @@ -0,0 +1,181 @@ +"""Diffusers pipeline wrapper for image generation.""" + +import logging +import os +from typing import List, Optional + +import torch +from PIL import Image + +logger = logging.getLogger(__name__) + + +class DiffusersPipeline: + """Wrapper for diffusers pipelines supporting various image generation models.""" + + def __init__( + self, + model_path: str, + device: str = "auto", + precision: str = "auto", + enable_attention_slicing: bool = False, + ): + """Initialize the diffusers pipeline. 
+
+        Args:
+            model_path: Path to the model directory
+            device: Device to use (auto, cpu, cuda, mps)
+            precision: Precision to use (auto, fp16, bf16, fp32)
+            enable_attention_slicing: Enable attention slicing for memory efficiency
+        """
+        self.model_path = model_path
+        self.device = self._resolve_device(device)
+        self.dtype = self._resolve_dtype(precision)
+        self.enable_attention_slicing = enable_attention_slicing
+
+        logger.info(f"Using device: {self.device}, dtype: {self.dtype}")
+
+        self.pipeline = self._load_pipeline()
+
+    def _resolve_device(self, device: str) -> str:
+        """Resolve the device to use."""
+        if device != "auto":
+            return device
+
+        if torch.cuda.is_available():
+            return "cuda"
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return "mps"
+        else:
+            return "cpu"
+
+    def _resolve_dtype(self, precision: str) -> torch.dtype:
+        """Resolve the dtype to use."""
+        if precision == "fp16":
+            return torch.float16
+        elif precision == "bf16":
+            return torch.bfloat16
+        elif precision == "fp32":
+            return torch.float32
+        elif precision == "auto":
+            # Use fp16 on GPU, fp32 on CPU
+            if self.device in ("cuda", "mps"):
+                return torch.float16
+            else:
+                return torch.float32
+        else:
+            return torch.float32
+
+    def _load_pipeline(self):
+        """Load the appropriate diffusers pipeline based on model type."""
+        from diffusers import (
+            DiffusionPipeline,
+            StableDiffusionPipeline,
+            StableDiffusionXLPipeline,
+        )
+
+        # Check for model_index.json to determine pipeline type
+        model_index_path = os.path.join(self.model_path, "model_index.json")
+
+        try:
+            if os.path.exists(model_index_path):
+                # Use auto-detection via DiffusionPipeline
+                logger.info("Loading pipeline using DiffusionPipeline.from_pretrained")
+                pipeline = DiffusionPipeline.from_pretrained(
+                    self.model_path,
+                    torch_dtype=self.dtype,
+                    local_files_only=True,
+                )
+            else:
+                # Fall back to SDXL first, then classic Stable Diffusion
+                logger.info(
+                    "No model_index.json found, trying StableDiffusionXLPipeline, "
+                    "then StableDiffusionPipeline"
+                )
+                try:
+                    pipeline = StableDiffusionXLPipeline.from_pretrained(
+                        self.model_path,
+                        torch_dtype=self.dtype,
+                        local_files_only=True,
+                    )
+                except Exception:
+                    pipeline = StableDiffusionPipeline.from_pretrained(
+                        self.model_path,
+                        torch_dtype=self.dtype,
+                        local_files_only=True,
+                    )
+        except Exception as e:
+            logger.error(f"Failed to load pipeline: {e}")
+            raise RuntimeError(f"Failed to load diffusers model from {self.model_path}: {e}") from e
+
+        # Move to device
+        pipeline = pipeline.to(self.device)
+
+        # Apply optimizations
+        if self.enable_attention_slicing:
+            logger.info("Enabling attention slicing")
+            pipeline.enable_attention_slicing()
+
+        # Enable memory efficient attention if available
+        if self.device == "cuda":
+            try:
+                pipeline.enable_xformers_memory_efficient_attention()
+                logger.info("Enabled xformers memory efficient attention")
+            except Exception:
+                logger.info("xformers not available, using default attention")
+
+        return pipeline
+
+    def generate(
+        self,
+        prompt: str,
+        num_images: int = 1,
+        width: int = 1024,
+        height: int = 1024,
+        num_inference_steps: int = 30,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[str] = None,
+        seed: Optional[int] = None,
+    ) -> List[Image.Image]:
+        """Generate images from a text prompt.
+ + Args: + prompt: The text prompt to generate images from + num_images: Number of images to generate + width: Image width + height: Image height + num_inference_steps: Number of denoising steps + guidance_scale: Guidance scale for classifier-free guidance + negative_prompt: Negative prompt for guidance + seed: Random seed for reproducibility + + Returns: + List of PIL Images + """ + generator = None + if seed is not None: + generator = torch.Generator(device=self.device).manual_seed(seed) + + logger.info( + f"Generating {num_images} image(s): {width}x{height}, " + f"steps={num_inference_steps}, guidance={guidance_scale}" + ) + + # Build kwargs based on pipeline capabilities + kwargs = { + "prompt": prompt, + "num_images_per_prompt": num_images, + "width": width, + "height": height, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + } + + if generator is not None: + kwargs["generator"] = generator + + if negative_prompt is not None: + kwargs["negative_prompt"] = negative_prompt + + # Run inference + with torch.inference_mode(): + result = self.pipeline(**kwargs) + + return result.images diff --git a/python/diffusers_server/server.py b/python/diffusers_server/server.py new file mode 100644 index 000000000..81cc41e0b --- /dev/null +++ b/python/diffusers_server/server.py @@ -0,0 +1,215 @@ +"""FastAPI server implementing OpenAI-compatible image generation API.""" + +import base64 +import io +import logging +import os +import socket +import time +from typing import List, Optional + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field +import uvicorn + +from .pipeline import DiffusersPipeline + +logger = logging.getLogger(__name__) + +app = FastAPI(title="Diffusers Server", version="0.1.0") + +# Global pipeline instance +_pipeline: Optional[DiffusersPipeline] = None +_served_model_names: List[str] = [] + + +class ImageGenerationRequest(BaseModel): + """OpenAI-compatible image generation request.""" + + model: str + prompt: str + n: int = Field(default=1, ge=1, le=10) + size: str = Field(default="1024x1024") + quality: str = Field(default="standard") + response_format: str = Field(default="b64_json") + style: str = Field(default="vivid") + + +class ImageData(BaseModel): + """Single generated image data.""" + + url: Optional[str] = None + b64_json: Optional[str] = None + revised_prompt: Optional[str] = None + + +class ImageGenerationResponse(BaseModel): + """OpenAI-compatible image generation response.""" + + created: int + data: List[ImageData] + + +class ErrorDetail(BaseModel): + """Error detail for OpenAI-compatible error response.""" + + message: str + type: str + param: Optional[str] = None + code: Optional[str] = None + + +class ErrorResponse(BaseModel): + """OpenAI-compatible error response.""" + + error: ErrorDetail + + +def parse_size(size: str) -> tuple[int, int]: + """Parse size string like '1024x1024' into (width, height).""" + try: + parts = size.lower().split("x") + if len(parts) != 2: + raise ValueError(f"Invalid size format: {size}") + width, height = int(parts[0]), int(parts[1]) + return width, height + except (ValueError, IndexError) as e: + raise ValueError(f"Invalid size format: {size}") from e + + +@app.post("/v1/images/generations", response_model=ImageGenerationResponse) +async def generate_images(request: ImageGenerationRequest) -> ImageGenerationResponse: + """Generate images from a text prompt.""" + global _pipeline, _served_model_names + + if _pipeline is None: + raise HTTPException(status_code=503, 
detail="Model not loaded")
+
+    # Validate model name if served_model_names is configured
+    if _served_model_names and request.model not in _served_model_names:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Model '{request.model}' not found. Available: {_served_model_names}",
+        )
+
+    try:
+        width, height = parse_size(request.size)
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e)) from e
+
+    # Map quality to inference steps
+    num_inference_steps = 50 if request.quality == "hd" else 30
+
+    try:
+        images = _pipeline.generate(
+            prompt=request.prompt,
+            num_images=request.n,
+            width=width,
+            height=height,
+            num_inference_steps=num_inference_steps,
+        )
+    except Exception as e:
+        logger.exception("Error generating images")
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+    # Convert images to response format
+    data = []
+    for img in images:
+        if request.response_format == "b64_json":
+            # Convert PIL Image to base64
+            buffer = io.BytesIO()
+            img.save(buffer, format="PNG")
+            b64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
+            data.append(ImageData(b64_json=b64_data))
+        else:
+            # URL format not supported without file storage
+            raise HTTPException(
+                status_code=400,
+                detail="response_format 'url' not supported, use 'b64_json'",
+            )
+
+    return ImageGenerationResponse(
+        created=int(time.time()),
+        data=data,
+    )
+
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {"status": "ok", "model_loaded": _pipeline is not None}
+
+
+@app.get("/v1/models")
+async def list_models():
+    """List available models (OpenAI-compatible)."""
+    models = []
+    for name in _served_model_names:
+        models.append(
+            {
+                "id": name,
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "diffusers",
+            }
+        )
+    return {"object": "list", "data": models}
+
+
+def run_server(
+    model_path: str,
+    socket_path: str,
+    device: str = "auto",
+    precision: str = "auto",
+    enable_attention_slicing: bool = False,
+    served_model_names: Optional[List[str]] = None,
+):
+    """Run the diffusers server on a Unix domain socket."""
+    global _pipeline, _served_model_names
+
+    # Configure logging
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    )
+
+    logger.info(f"Loading model from {model_path}")
+    logger.info(f"Device: {device}, Precision: {precision}")
+
+    # Load the pipeline
+    _pipeline = DiffusersPipeline(
+        model_path=model_path,
+        device=device,
+        precision=precision,
+        enable_attention_slicing=enable_attention_slicing,
+    )
+
+    _served_model_names = served_model_names or []
+    logger.info(f"Serving model names: {_served_model_names}")
+
+    # Remove existing socket if present
+    if os.path.exists(socket_path):
+        os.unlink(socket_path)
+
+    logger.info(f"Starting server on unix://{socket_path}")
+
+    # Create and bind the socket manually for Unix domain sockets
+    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    sock.bind(socket_path)
+    os.chmod(socket_path, 0o666)
+
+    config = uvicorn.Config(
+        app,
+        log_level="info",
+    )
+    server = uvicorn.Server(config)
+
+    # server.run accepts pre-bound sockets directly, so no further socket
+    # configuration is needed on the uvicorn config.
+    try:
+        server.run(sockets=[sock])
+    finally:
+        sock.close()
+        if os.path.exists(socket_path):
+            os.unlink(socket_path)
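+
+
+# Illustrative client sketch (the socket path and model name are placeholders;
+# httpx is not a server dependency and is used here only as an example):
+#
+#   import httpx
+#   transport = httpx.HTTPTransport(uds="/tmp/diffusers.sock")
+#   client = httpx.Client(transport=transport, base_url="http://localhost")
+#   resp = client.post("/v1/images/generations", json={
+#       "model": "my-sd-model",
+#       "prompt": "a watercolor fox",
+#       "size": "512x512",
+#   })
+#   resp.raise_for_status()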