diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go
index 213862964..7aa3ffacd 100644
--- a/cmd/cli/commands/compose.go
+++ b/cmd/cli/commands/compose.go
@@ -177,7 +177,7 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string
 			printer := desktop.NewSimplePrinter(func(s string) {
 				_ = sendInfo(s)
 			})
-			_, _, err = desktopClient.Pull(model, false, printer)
+			_, _, err = desktopClient.Pull(model, printer)
 			if err != nil {
 				_ = sendErrorf("Failed to pull model: %v", err)
 				return fmt.Errorf("Failed to pull model: %w\n", err)
diff --git a/cmd/cli/commands/integration_test.go b/cmd/cli/commands/integration_test.go
index 4b74d014f..787321aca 100644
--- a/cmd/cli/commands/integration_test.go
+++ b/cmd/cli/commands/integration_test.go
@@ -371,7 +371,7 @@ func TestIntegration_PullModel(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			// Pull the model using the test case reference
 			t.Logf("Pulling model with reference: %s", tc.ref)
-			err := pullModel(newPullCmd(), env.client, tc.ref, true)
+			err := pullModel(newPullCmd(), env.client, tc.ref)
 			require.NoError(t, err, "Failed to pull model with reference: %s", tc.ref)
 
 			// List models and verify the expected model is present
@@ -426,7 +426,7 @@ func TestIntegration_InspectModel(t *testing.T) {
 	// Pull the model using a short reference
 	pullRef := "inspect-test"
 	t.Logf("Pulling model with reference: %s", pullRef)
-	err = pullModel(newPullCmd(), env.client, pullRef, true)
+	err = pullModel(newPullCmd(), env.client, pullRef)
 	require.NoError(t, err, "Failed to pull model")
 
 	// Verify the model was pulled
@@ -485,7 +485,7 @@ func TestIntegration_TagModel(t *testing.T) {
 	// Pull the model using a simple reference
 	pullRef := "tag-test"
 	t.Logf("Pulling model with reference: %s", pullRef)
-	err = pullModel(newPullCmd(), env.client, pullRef, true)
+	err = pullModel(newPullCmd(), env.client, pullRef)
 	require.NoError(t, err, "Failed to pull model")
 
 	// Verify the model was pulled
@@ -663,7 +663,7 @@ func TestIntegration_PushModel(t *testing.T) {
 	// Pull the model using a simple reference
 	pullRef := "tag-test"
 	t.Logf("Pulling model with reference: %s", pullRef)
-	err = pullModel(newPullCmd(), env.client, pullRef, true)
+	err = pullModel(newPullCmd(), env.client, pullRef)
 	require.NoError(t, err, "Failed to pull model")
 
 	// Verify the model was pulled
@@ -814,7 +814,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
 	// Pull the model
 	pullRef := "rm-test"
 	t.Logf("Pulling model with reference: %s", pullRef)
-	err := pullModel(newPullCmd(), env.client, pullRef, true)
+	err := pullModel(newPullCmd(), env.client, pullRef)
 	require.NoError(t, err, "Failed to pull model")
 
 	// Verify model exists
@@ -848,11 +848,11 @@ func TestIntegration_RemoveModel(t *testing.T) {
 		// Pull both models
 		t.Logf("Pulling first model: rm-multi-1")
-		err := pullModel(newPullCmd(), env.client, "rm-multi-1", true)
+		err := pullModel(newPullCmd(), env.client, "rm-multi-1")
 		require.NoError(t, err, "Failed to pull first model")
 
 		t.Logf("Pulling second model: rm-multi-2")
-		err = pullModel(newPullCmd(), env.client, "rm-multi-2", true)
+		err = pullModel(newPullCmd(), env.client, "rm-multi-2")
 		require.NoError(t, err, "Failed to pull second model")
 
 		// Verify both models exist
@@ -878,7 +878,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
 	t.Run("remove specific tag keeps other tags", func(t *testing.T) {
 		// Pull the model
 		t.Logf("Pulling model: rm-test")
-		err := pullModel(newPullCmd(), env.client, "rm-test", true)
+		err := pullModel(newPullCmd(), env.client, "rm-test")
 		require.NoError(t, err, "Failed to pull model")
 
 		// Add multiple tags to the same model
@@ -940,7 +940,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
 	t.Run("remove by model ID removes all tags", func(t *testing.T) {
 		// Pull the model
 		t.Logf("Pulling model: rm-test")
-		err := pullModel(newPullCmd(), env.client, "rm-test", true)
+		err := pullModel(newPullCmd(), env.client, "rm-test")
 		require.NoError(t, err, "Failed to pull model")
 
 		// Add multiple tags
@@ -971,7 +971,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
 	t.Run("force flag", func(t *testing.T) {
 		// Pull the model
 		t.Logf("Pulling model: rm-test")
-		err := pullModel(newPullCmd(), env.client, "rm-test", true)
+		err := pullModel(newPullCmd(), env.client, "rm-test")
 		require.NoError(t, err, "Failed to pull model")
 
 		// Test removal with force flag
diff --git a/cmd/cli/commands/pull.go b/cmd/cli/commands/pull.go
index 8fb44f4b7..09d8474b2 100644
--- a/cmd/cli/commands/pull.go
+++ b/cmd/cli/commands/pull.go
@@ -10,8 +10,6 @@ import (
 )
 
 func newPullCmd() *cobra.Command {
-	var ignoreRuntimeMemoryCheck bool
-
 	c := &cobra.Command{
 		Use:   "pull MODEL",
 		Short: "Pull a model from Docker Hub or HuggingFace to your local environment",
@@ -20,19 +18,17 @@ func newPullCmd() *cobra.Command {
 			if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), asPrinter(cmd), false); err != nil {
 				return fmt.Errorf("unable to initialize standalone model runner: %w", err)
 			}
-			return pullModel(cmd, desktopClient, args[0], ignoreRuntimeMemoryCheck)
+			return pullModel(cmd, desktopClient, args[0])
 		},
 		ValidArgsFunction: completion.NoComplete,
 	}
 
-	c.Flags().BoolVar(&ignoreRuntimeMemoryCheck, "ignore-runtime-memory-check", false, "Do not block pull if estimated runtime memory for model exceeds system resources.")
-
 	return c
 }
 
-func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
+func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
 	printer := asPrinter(cmd)
 
-	response, _, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, printer)
+	response, _, err := desktopClient.Pull(model, printer)
 	if err != nil {
 		return handleClientError(err, "Failed to pull model")
diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go
index 8be379174..125d982a6 100644
--- a/cmd/cli/commands/run.go
+++ b/cmd/cli/commands/run.go
@@ -570,7 +570,6 @@ func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *de
 
 func newRunCmd() *cobra.Command {
 	var debug bool
-	var ignoreRuntimeMemoryCheck bool
 	var colorMode string
 	var detach bool
 
@@ -686,7 +685,7 @@ func newRunCmd() *cobra.Command {
 				return handleClientError(err, "Failed to inspect model")
 			}
 			cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
-			if err := pullModel(cmd, desktopClient, model, ignoreRuntimeMemoryCheck); err != nil {
+			if err := pullModel(cmd, desktopClient, model); err != nil {
 				return err
 			}
 		}
@@ -733,7 +732,6 @@ func newRunCmd() *cobra.Command {
 	c.Args = requireMinArgs(1, "run", cmdArgs)
 
 	c.Flags().BoolVar(&debug, "debug", false, "Enable debug logging")
-	c.Flags().BoolVar(&ignoreRuntimeMemoryCheck, "ignore-runtime-memory-check", false, "Do not block pull if estimated runtime memory for model exceeds system resources.")
 	c.Flags().StringVar(&colorMode, "color", "no", "Use colored output (auto|yes|no)")
 	c.Flags().BoolVarP(&detach, "detach", "d", false, "Load the model in the background without interaction")
diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go
index 76d4b8d65..f47150dd3 100644
--- a/cmd/cli/desktop/desktop.go
+++ b/cmd/cli/desktop/desktop.go
@@ -105,7 +105,7 @@ func (c *Client) Status() Status {
 	}
 }
 
-func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) {
+func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, bool, error) {
 	model = normalizeHuggingFaceModelName(model)
 
 	// Check if this is a Hugging Face model and if HF_TOKEN is set
@@ -116,9 +116,8 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer stand
 
 	return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) {
 		jsonData, err := json.Marshal(dmrm.ModelCreateRequest{
-			From:                     model,
-			IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck,
-			BearerToken:              hfToken,
+			From:        model,
+			BearerToken: hfToken,
 		})
 		if err != nil {
 			// Marshaling errors are not retryable
diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go
index c04dd7e5c..d09eee4bd 100644
--- a/cmd/cli/desktop/desktop_test.go
+++ b/cmd/cli/desktop/desktop_test.go
@@ -39,7 +39,7 @@ func TestPullHuggingFaceModel(t *testing.T) {
 	}, nil)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.NoError(t, err)
 }
 
@@ -126,7 +126,7 @@ func TestNonHuggingFaceModel(t *testing.T) {
 	}, nil)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.NoError(t, err)
 }
 
@@ -250,7 +250,7 @@ func TestPullRetryOnNetworkError(t *testing.T) {
 	)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.NoError(t, err)
 }
 
@@ -270,7 +270,7 @@ func TestPullNoRetryOn4xxError(t *testing.T) {
 	}, nil).Times(1)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.Error(t, err)
 	assert.Contains(t, err.Error(), "Model not found")
 }
@@ -297,7 +297,7 @@ func TestPullRetryOn5xxError(t *testing.T) {
 	)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.NoError(t, err)
 }
 
@@ -324,7 +324,7 @@ func TestPullRetryOnServiceUnavailable(t *testing.T) {
 	)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.NoError(t, err)
 }
 
@@ -341,7 +341,7 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
 	mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.EOF).Times(4)
 
 	printer := NewSimplePrinter(func(s string) {})
-	_, _, err := client.Pull(modelName, false, printer)
+	_, _, err := client.Pull(modelName, printer)
 	assert.Error(t, err)
 	assert.Contains(t, err.Error(), "failed to download after 3 retries")
 }
diff --git a/cmd/cli/docs/reference/docker_model_pull.yaml b/cmd/cli/docs/reference/docker_model_pull.yaml
index ec05709b2..c93930458 100644
--- a/cmd/cli/docs/reference/docker_model_pull.yaml
+++ b/cmd/cli/docs/reference/docker_model_pull.yaml
@@ -5,18 +5,6 @@ long: |
 usage: docker model pull MODEL
 pname: docker model
 plink: docker_model.yaml
-options:
-    - option: ignore-runtime-memory-check
-      value_type: bool
-      default_value: "false"
-      description: |
-        Do not block pull if estimated runtime memory for model exceeds system resources.
-      deprecated: false
-      hidden: false
-      experimental: false
-      experimentalcli: false
-      kubernetes: false
-      swarm: false
 examples: |-
 
     ### Pulling a model from Docker Hub
diff --git a/cmd/cli/docs/reference/docker_model_run.yaml b/cmd/cli/docs/reference/docker_model_run.yaml
index e279f09c9..781c2bae3 100644
--- a/cmd/cli/docs/reference/docker_model_run.yaml
+++ b/cmd/cli/docs/reference/docker_model_run.yaml
@@ -41,17 +41,6 @@ options:
       experimentalcli: false
       kubernetes: false
       swarm: false
-    - option: ignore-runtime-memory-check
-      value_type: bool
-      default_value: "false"
-      description: |
-        Do not block pull if estimated runtime memory for model exceeds system resources.
-      deprecated: false
-      hidden: false
-      experimental: false
-      experimentalcli: false
-      kubernetes: false
-      swarm: false
 examples: |-
 
     ### One-time prompt
diff --git a/cmd/cli/docs/reference/model_pull.md b/cmd/cli/docs/reference/model_pull.md
index a8f6a9291..246cc59d7 100644
--- a/cmd/cli/docs/reference/model_pull.md
+++ b/cmd/cli/docs/reference/model_pull.md
@@ -3,12 +3,6 @@
 
 Pull a model from Docker Hub or HuggingFace to your local environment
 
-### Options
-
-| Name                            | Type   | Default | Description                                                                        |
-|:--------------------------------|:-------|:--------|:----------------------------------------------------------------------------------|
-| `--ignore-runtime-memory-check` | `bool` |         | Do not block pull if estimated runtime memory for model exceeds system resources. |
-
diff --git a/cmd/cli/docs/reference/model_run.md b/cmd/cli/docs/reference/model_run.md
index b1712298d..0d271dc9f 100644
--- a/cmd/cli/docs/reference/model_run.md
+++ b/cmd/cli/docs/reference/model_run.md
@@ -5,12 +5,11 @@
 
 ### Options
 
-| Name                            | Type     | Default | Description                                                                        |
-|:--------------------------------|:---------|:--------|:----------------------------------------------------------------------------------|
-| `--color`                       | `string` | `no`    | Use colored output (auto\|yes\|no)                                                 |
-| `--debug`                       | `bool`   |         | Enable debug logging                                                               |
-| `-d`, `--detach`                | `bool`   |         | Load the model in the background without interaction                               |
-| `--ignore-runtime-memory-check` | `bool`   |         | Do not block pull if estimated runtime memory for model exceeds system resources. |
+| Name             | Type     | Default | Description                                           |
+|:-----------------|:---------|:--------|:------------------------------------------------------|
+| `--color`        | `string` | `no`    | Use colored output (auto\|yes\|no)                    |
+| `--debug`        | `bool`   |         | Enable debug logging                                  |
+| `-d`, `--detach` | `bool`   |         | Load the model in the background without interaction  |
 
diff --git a/go.mod b/go.mod
index f3ef7e560..bc41d4583 100644
--- a/go.mod
+++ b/go.mod
@@ -7,7 +7,6 @@ require (
 	github.com/containerd/platforms v1.0.0-rc.1
 	github.com/docker/go-units v0.5.0
 	github.com/docker/model-runner/pkg/go-containerregistry v0.0.0-20251121150728-6951a2a36575
-	github.com/elastic/go-sysinfo v1.15.4
 	github.com/gpustack/gguf-parser-go v0.22.1
 	github.com/jaypipes/ghw v0.19.1
 	github.com/kolesnikovae/go-winjob v1.0.0
@@ -30,7 +29,6 @@ require (
 	github.com/docker/cli v28.3.0+incompatible // indirect
 	github.com/docker/distribution v2.8.3+incompatible // indirect
 	github.com/docker/docker-credential-helpers v0.9.3 // indirect
-	github.com/elastic/go-windows v1.0.2 // indirect
 	github.com/felixge/httpsnoop v1.0.4 // indirect
 	github.com/go-logr/logr v1.4.3 // indirect
 	github.com/go-logr/stdr v1.2.2 // indirect
@@ -47,7 +45,6 @@ require (
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
-	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect
 	github.com/vbatts/tar-split v0.12.1 // indirect
 	github.com/yusufpapurcu/wmi v1.2.4 // indirect
diff --git a/go.sum b/go.sum
index 9d6ade104..0de375876 100644
--- a/go.sum
+++ b/go.sum
@@ -42,10 +42,6 @@ github.com/docker/go-winjob v0.0.0-20250829235554-57b487ebcbc5 h1:dxSFEb0EEmvceI
 github.com/docker/go-winjob v0.0.0-20250829235554-57b487ebcbc5/go.mod h1:ICOGmIXdwhfid7rQP+tLvDJqVg0lHdEk3pI5nsapTtg=
 github.com/docker/model-runner/pkg/go-containerregistry v0.0.0-20251121150728-6951a2a36575 h1:N2yLWYSZFTVLkLTh8ux1Z0Nug/F78pXsl2KDtbWhe+Y=
 github.com/docker/model-runner/pkg/go-containerregistry v0.0.0-20251121150728-6951a2a36575/go.mod h1:gbdiY0X8gr0J88OfUuRD29JXCWT9jgHzPmrqTlO15BM=
-github.com/elastic/go-sysinfo v1.15.4 h1:A3zQcunCxik14MgXu39cXFXcIw2sFXZ0zL886eyiv1Q=
-github.com/elastic/go-sysinfo v1.15.4/go.mod h1:ZBVXmqS368dOn/jvijV/zHLfakWTYHBZPk3G244lHrU=
-github.com/elastic/go-windows v1.0.2 h1:yoLLsAsV5cfg9FLhZ9EXZ2n2sQFKeDYrHenkcivY4vI=
-github.com/elastic/go-windows v1.0.2/go.mod h1:bGcDpBzXgYSqM0Gx3DM4+UxFj300SZLixie9u9ixLM8=
 github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
 github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -112,8 +108,6 @@ github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNw
 github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
 github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
 github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
-github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
-github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
 github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
 github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@@ -121,8 +115,6 @@ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs
 github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY=
 github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
-github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
diff --git a/main.go b/main.go
index 369ff3835..90aacbb0b 100644
--- a/main.go
+++ b/main.go
@@ -11,13 +11,11 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/mlx"
 	"github.com/docker/model-runner/pkg/inference/backends/vllm"
 	"github.com/docker/model-runner/pkg/inference/config"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/inference/scheduling"
 	"github.com/docker/model-runner/pkg/metrics"
@@ -65,15 +63,6 @@ func main() {
 		llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
 	}
 
-	gpuInfo := gpuinfo.New(llamaServerPath)
-
-	sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
-	if err != nil {
-		log.Fatalf("unable to initialize system memory info: %v", err)
-	}
-
-	memEstimator := memory.NewEstimator(sysMemInfo)
-
 	// Create a proxy-aware HTTP transport
 	// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
 	var baseTransport *http.Transport
@@ -94,7 +83,6 @@ func main() {
 		log,
 		modelManager,
 		nil,
-		memEstimator,
 	)
 
 	log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
@@ -118,12 +106,6 @@ func main() {
 		log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
 	}
 
-	if os.Getenv("MODEL_RUNNER_RUNTIME_MEMORY_CHECK") == "1" {
-		memory.SetRuntimeMemoryCheck(true)
-	}
-
-	memEstimator.SetDefaultBackend(llamaCppBackend)
-
 	vllmBackend, err := vllm.New(
 		log,
 		modelManager,
@@ -160,7 +142,6 @@ func main() {
 			"",
 			false,
 		),
-		sysMemInfo,
 	)
 
 	// Create the HTTP handler for the scheduler
diff --git a/pkg/inference/memory/estimator.go b/pkg/inference/memory/estimator.go
deleted file mode 100644
index 5043a8f1c..000000000
--- a/pkg/inference/memory/estimator.go
+++ /dev/null
@@ -1,53 +0,0 @@
-package memory
-
-import (
-	"context"
-	"errors"
-	"fmt"
-
-	"github.com/docker/model-runner/pkg/inference"
-)
-
-type MemoryEstimator interface {
-	SetDefaultBackend(MemoryEstimatorBackend)
-	GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (inference.RequiredMemory, error)
-	HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error)
-}
-
-type MemoryEstimatorBackend interface {
-	GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (inference.RequiredMemory, error)
-}
-
-type memoryEstimator struct {
-	systemMemoryInfo SystemMemoryInfo
-	defaultBackend   MemoryEstimatorBackend
-}
-
-func NewEstimator(systemMemoryInfo SystemMemoryInfo) MemoryEstimator {
-	return &memoryEstimator{systemMemoryInfo: systemMemoryInfo}
-}
-
-func (m *memoryEstimator) SetDefaultBackend(backend MemoryEstimatorBackend) {
-	m.defaultBackend = backend
-}
-
-func (m *memoryEstimator) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) {
-	if m.defaultBackend == nil {
-		return inference.RequiredMemory{}, errors.New("default backend not configured")
-	}
-
-	return m.defaultBackend.GetRequiredMemoryForModel(ctx, model, config)
-}
-
-func (m *memoryEstimator) HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error) {
-	req, err := m.GetRequiredMemoryForModel(ctx, model, config)
-	if err != nil {
-		return false, inference.RequiredMemory{}, inference.RequiredMemory{}, fmt.Errorf("estimating required memory for model: %w", err)
-	}
-
-	ok, err := m.systemMemoryInfo.HaveSufficientMemory(req)
-	if err != nil {
-		return false, req, inference.RequiredMemory{}, fmt.Errorf("checking if system has sufficient memory: %w", err)
-	}
-	return ok, req, m.systemMemoryInfo.GetTotalMemory(), nil
-}
diff --git a/pkg/inference/memory/settings.go b/pkg/inference/memory/settings.go
deleted file mode 100644
index 5da85c469..000000000
--- a/pkg/inference/memory/settings.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package memory
-
-import "sync"
-
-var runtimeMemoryCheck bool
-var runtimeMemoryCheckLock sync.Mutex
-
-func SetRuntimeMemoryCheck(enabled bool) {
-	runtimeMemoryCheckLock.Lock()
-	defer runtimeMemoryCheckLock.Unlock()
-	runtimeMemoryCheck = enabled
-}
-
-func RuntimeMemoryCheckEnabled() bool {
-	runtimeMemoryCheckLock.Lock()
-	defer runtimeMemoryCheckLock.Unlock()
-	return runtimeMemoryCheck
-}
diff --git a/pkg/inference/memory/system.go b/pkg/inference/memory/system.go
deleted file mode 100644
index 4618159c2..000000000
--- a/pkg/inference/memory/system.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package memory
-
-import (
-	"errors"
-
-	"github.com/docker/model-runner/pkg/gpuinfo"
-	"github.com/docker/model-runner/pkg/inference"
-	"github.com/docker/model-runner/pkg/logging"
-	"github.com/elastic/go-sysinfo"
-)
-
-type SystemMemoryInfo interface {
-	HaveSufficientMemory(inference.RequiredMemory) (bool, error)
-	GetTotalMemory() inference.RequiredMemory
-}
-
-type systemMemoryInfo struct {
-	log         logging.Logger
-	totalMemory inference.RequiredMemory
-}
-
-func NewSystemMemoryInfo(log logging.Logger, gpuInfo *gpuinfo.GPUInfo) (SystemMemoryInfo, error) {
-	// Compute the amount of available memory.
-	// TODO(p1-0tr): improve error handling
-	vramSize, err := gpuInfo.GetVRAMSize()
-	if err != nil {
-		vramSize = 1
-		log.Warnf("Could not read VRAM size: %s", err)
-	} else {
-		log.Infof("Running on system with %d MB VRAM", vramSize/1024/1024)
-	}
-	ramSize := uint64(1)
-	hostInfo, err := sysinfo.Host()
-	if err != nil {
-		log.Warnf("Could not read host info: %s", err)
-	} else {
-		ram, err := hostInfo.Memory()
-		if err != nil {
-			log.Warnf("Could not read host RAM size: %s", err)
-		} else {
-			ramSize = ram.Total
-			log.Infof("Running on system with %d MB RAM", ramSize/1024/1024)
-		}
-	}
-	return &systemMemoryInfo{
-		log:         log,
-		totalMemory: inference.RequiredMemory{RAM: ramSize, VRAM: vramSize},
-	}, nil
-}
-
-func (s *systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) (bool, error) {
-	// Sentinel value of 1 indicates unknown RAM/VRAM
-	if req.RAM > 1 && s.totalMemory.RAM == 1 {
-		return false, errors.New("system RAM unknown")
-	}
-	if req.VRAM > 1 && s.totalMemory.VRAM == 1 {
-		return false, errors.New("system VRAM unknown")
-	}
-	return req.RAM <= s.totalMemory.RAM && req.VRAM <= s.totalMemory.VRAM, nil
-}
-
-func (s *systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
-	return s.totalMemory
-}
diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go
index 32e8c4a30..8c11dcf3d 100644
--- a/pkg/inference/models/api.go
+++ b/pkg/inference/models/api.go
@@ -14,9 +14,6 @@ import (
 type ModelCreateRequest struct {
 	// From is the name of the model to pull.
 	From string `json:"from"`
-	// IgnoreRuntimeMemoryCheck indicates whether the server should check if it has sufficient
-	// memory to run the given model (assuming default configuration).
-	IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"`
 	// BearerToken is an optional bearer token for authentication.
 	BearerToken string `json:"bearer-token,omitempty"`
 }
diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go
index dd03958a9..03332fcaa 100644
--- a/pkg/inference/models/handler_test.go
+++ b/pkg/inference/models/handler_test.go
@@ -17,23 +17,10 @@ import (
 	"github.com/docker/model-runner/pkg/distribution/builder"
 	reg "github.com/docker/model-runner/pkg/distribution/registry"
 	"github.com/docker/model-runner/pkg/inference"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/sirupsen/logrus"
 )
 
-type mockMemoryEstimator struct{}
-
-func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}
-
-func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) {
-	return inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
-}
-
-func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error) {
-	return true, inference.RequiredMemory{}, inference.RequiredMemory{}, nil
-}
-
 // getProjectRoot returns the absolute path to the project root directory
 func getProjectRoot(t *testing.T) string {
 	// Start from the current test file's directory
@@ -123,12 +110,11 @@ func TestPullModel(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			log := logrus.NewEntry(logrus.StandardLogger())
-			memEstimator := &mockMemoryEstimator{}
 			manager := NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), ClientConfig{
 				StoreRootPath: tempDir,
 				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
 			})
-			handler := NewHTTPHandler(log, manager, nil, memEstimator)
+			handler := NewHTTPHandler(log, manager, nil)
 
 			r := httptest.NewRequest(http.MethodPost, "/models/create", strings.NewReader(`{"from": "`+tag+`"}`))
 			if tt.acceptHeader != "" {
@@ -235,14 +221,13 @@ func TestHandleGetModel(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			log := logrus.NewEntry(logrus.StandardLogger())
-			memEstimator := &mockMemoryEstimator{}
 			manager := NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), ClientConfig{
 				StoreRootPath: tempDir,
 				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
 				Transport:     http.DefaultTransport,
 				UserAgent:     "test-agent",
 			})
-			handler := NewHTTPHandler(log, manager, nil, memEstimator)
+			handler := NewHTTPHandler(log, manager, nil)
 
 			// First pull the model if we're testing local access
 			if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
@@ -324,7 +309,6 @@ func TestCors(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.path, func(t *testing.T) {
 			t.Parallel()
-			memEstimator := &mockMemoryEstimator{}
 			discard := logrus.New()
 			discard.SetOutput(io.Discard)
 			log := logrus.NewEntry(discard)
@@ -332,7 +316,7 @@ func TestCors(t *testing.T) {
 				StoreRootPath: tempDir,
 				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
 			})
-			m := NewHTTPHandler(log, manager, []string{"*"}, memEstimator)
+			m := NewHTTPHandler(log, manager, []string{"*"})
 			req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
 			req.Header.Set("Origin", "docker.com")
 			w := httptest.NewRecorder()
diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go
index f849299b3..1a65f25f8 100644
--- a/pkg/inference/models/http_handler.go
+++ b/pkg/inference/models/http_handler.go
@@ -15,7 +15,6 @@ import (
 	"github.com/docker/model-runner/pkg/distribution/distribution"
 	"github.com/docker/model-runner/pkg/distribution/registry"
 	"github.com/docker/model-runner/pkg/inference"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/internal/utils"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/middleware"
@@ -38,8 +37,6 @@ type HTTPHandler struct {
 	httpHandler http.Handler
 	// lock is used to synchronize access to the models manager's router.
 	lock sync.RWMutex
-	// memoryEstimator is used to calculate runtime memory requirements for models.
-	memoryEstimator memory.MemoryEstimator
 	// manager handles business logic for model operations.
 	manager *Manager
 }
@@ -56,12 +53,11 @@ type ClientConfig struct {
 }
 
 // NewHTTPHandler creates a new model's handler.
-func NewHTTPHandler(log logging.Logger, manager *Manager, allowedOrigins []string, memoryEstimator memory.MemoryEstimator) *HTTPHandler {
+func NewHTTPHandler(log logging.Logger, manager *Manager, allowedOrigins []string) *HTTPHandler {
 	m := &HTTPHandler{
-		log:             log,
-		router:          http.NewServeMux(),
-		memoryEstimator: memoryEstimator,
-		manager:         manager,
+		log:     log,
+		router:  http.NewServeMux(),
+		manager: manager,
 	}
 
 	// Register routes.
@@ -162,23 +158,7 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request)
 	// Normalize the model name to add defaults
 	request.From = NormalizeModelName(request.From)
 
-	// Pull the model. In the future, we may support additional operations here
-	// besides pulling (such as model building).
-	if memory.RuntimeMemoryCheckEnabled() && !request.IgnoreRuntimeMemoryCheck {
-		h.log.Infof("Will estimate memory required for %q", request.From)
-		proceed, req, totalMem, err := h.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
-		if err != nil {
-			h.log.Warnf("Failed to validate sufficient system memory for model %q: %s", request.From, err)
-			// Prefer staying functional in case of unexpected estimation errors.
-			proceed = true
-		}
-		if !proceed {
-			errstr := fmt.Sprintf("Runtime memory requirement for model %q exceeds total system memory: required %d RAM %d VRAM, system %d RAM %d VRAM", request.From, req.RAM, req.VRAM, totalMem.RAM, totalMem.VRAM)
-			h.log.Warnf(errstr)
-			http.Error(w, errstr, http.StatusInsufficientStorage)
-			return
-		}
-	}
+	// Pull the model
 	if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil {
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			h.log.Infof("Request canceled/timed out while pulling model %q", request.From)
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
index 9827a2ad2..3d6c4d3bd 100644
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -11,7 +11,6 @@ import (
 
 	"github.com/docker/model-runner/pkg/environment"
 	"github.com/docker/model-runner/pkg/inference"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
@@ -31,9 +30,6 @@ const (
 var (
 	// errLoadsDisabled indicates that backend loads are currently disabled.
 	errLoadsDisabled = errors.New("backend loading disabled")
-	// errModelTooBig indicates that the model is too big to ever load into the
-	// available system memory.
-	errModelTooBig = errors.New("model too big")
 	// errRunnerAlreadyActive indicates that a given runner is already active
 	// and therefore can't be reconfigured for example
 	errRunnerAlreadyActive = errors.New("runner already active")
@@ -97,8 +93,6 @@ type loader struct {
 	modelManager *models.Manager
 	// runnerIdleTimeout is the loader-specific default runner idle timeout.
 	runnerIdleTimeout time.Duration
-	// totalMemory is the total system memory allocated to the loader.
-	totalMemory inference.RequiredMemory
 	// idleCheck is used to signal the run loop when timestamps have updated.
 	idleCheck chan struct{}
 	// guard is a sempahore controlling access to all subsequent fields. It is
@@ -108,8 +102,6 @@ type loader struct {
 	guard chan struct{}
 	// loadsEnabled signals that loads are currently enabled.
 	loadsEnabled bool
-	// availableMemory is the available portion of the loader's total memory.
-	availableMemory inference.RequiredMemory
 	// waiters is the set of signal channels associated with waiting loaders. We
 	// use a set of signaling channels (instead of a sync.Cond) to enable
 	// polling. Each signaling channel should be buffered (with size 1).
@@ -121,8 +113,6 @@ type loader struct {
 	slots []*runner
 	// references maps slot indices to reference counts.
 	references []uint
-	// allocations maps slot indices to memory allocation sizes.
-	allocations []inference.RequiredMemory
 	// timestamps maps slot indices to last usage times. Values in this slice
 	// are only valid if the corresponding reference count is zero.
 	timestamps []time.Time
@@ -138,7 +128,6 @@ func newLoader(
 	backends map[string]inference.Backend,
 	modelManager *models.Manager,
 	openAIRecorder *metrics.OpenAIRecorder,
-	sysMemInfo memory.SystemMemoryInfo,
 ) *loader {
 	// Compute the number of runner slots to allocate. Because of RAM and VRAM
 	// limitations, it's unlikely that we'll ever be able to fully populate
@@ -159,24 +148,18 @@ func newLoader(
 		runnerIdleTimeout = 8 * time.Hour
 	}
 
-	// Compute the amount of available memory.
-	totalMemory := sysMemInfo.GetTotalMemory()
-
 	// Create the loader.
 	l := &loader{
 		log:               log,
 		backends:          backends,
 		modelManager:      modelManager,
 		runnerIdleTimeout: runnerIdleTimeout,
-		totalMemory:       totalMemory,
 		idleCheck:         make(chan struct{}, 1),
 		guard:             make(chan struct{}, 1),
-		availableMemory:   totalMemory,
 		waiters:           make(map[chan<- struct{}]bool),
 		runners:           make(map[runnerKey]runnerInfo, nSlots),
 		slots:             make([]*runner, nSlots),
 		references:        make([]uint, nSlots),
-		allocations:       make([]inference.RequiredMemory, nSlots),
 		timestamps:        make([]time.Time, nSlots),
 		runnerConfigs:     make(map[runnerKey]inference.BackendConfiguration),
 		openAIRecorder:    openAIRecorder,
@@ -211,23 +194,11 @@ func (l *loader) broadcast() {
 	}
 }
 
-// formatMemorySize formats a memory size in bytes as a string.
-// Values of 0 or 1 are treated as sentinel values for "unknown" memory size.
-func formatMemorySize(bytes uint64) string {
-	if bytes <= 1 {
-		return "unknown"
-	}
-	return fmt.Sprintf("%d MB", bytes/1024/1024)
-}
-
-// freeRunnerSlot frees a runner slot and reclaims its memory.
+// freeRunnerSlot frees a runner slot.
 // The caller must hold the loader lock.
 func (l *loader) freeRunnerSlot(slot int, key runnerKey) {
 	l.slots[slot].terminate()
 	l.slots[slot] = nil
-	l.availableMemory.RAM += l.allocations[slot].RAM
-	l.availableMemory.VRAM += l.allocations[slot].VRAM
-	l.allocations[slot] = inference.RequiredMemory{RAM: 0, VRAM: 0}
 	l.timestamps[slot] = time.Time{}
 	delete(l.runners, key)
 }
@@ -264,10 +235,7 @@ func (l *loader) evict(idleOnly bool) int {
 		}
 	}
 	if evictedCount > 0 {
-		l.log.Infof("Evicted %d runner(s). Available memory: %s RAM, %s VRAM",
-			evictedCount,
-			formatMemorySize(l.availableMemory.RAM),
-			formatMemorySize(l.availableMemory.VRAM))
+		l.log.Infof("Evicted %d runner(s)", evictedCount)
 	}
 	return len(l.runners)
 }
@@ -447,8 +415,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 		return nil, ErrBackendNotFound
 	}
 
-	// Estimate the amount of memory that will be used by the model and check
-	// that we're even capable of loading it.
+	// Get runner configuration if available
 	var runnerConfig *inference.BackendConfiguration
 	draftModelID := ""
 	if rc, ok := l.runnerConfigs[makeConfigKey(backendName, modelID, mode)]; ok {
@@ -465,44 +432,8 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 			}
 		}
 	}
-	memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
-	var parseErr *inference.ErrGGUFParse
-	if errors.As(err, &parseErr) {
-		// TODO(p1-0tr): For now override memory checks in case model can't be parsed
-		// e.g. model is too new for gguf-parser-go to know. We should provide a cleaner
-		// way to bypass these checks.
-		l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it. Error: %s", modelID, parseErr)
-		memory = inference.RequiredMemory{
-			RAM:  0,
-			VRAM: 0,
-		}
-	} else if err != nil {
-		return nil, err
-	}
-	l.log.Infof("Loading %s, which will require %s RAM and %s VRAM on a system with %s RAM and %s VRAM",
-		modelID,
-		formatMemorySize(memory.RAM), formatMemorySize(memory.VRAM),
-		formatMemorySize(l.totalMemory.RAM), formatMemorySize(l.totalMemory.VRAM))
-
-	if l.totalMemory.RAM == 1 {
-		l.log.Warnf("RAM size unknown. Assume model will fit, but only one.")
-		memory.RAM = 1
-	}
-	if l.totalMemory.VRAM == 1 {
-		l.log.Warnf("VRAM size unknown. Assume model will fit, but only one.")
-		memory.VRAM = 1
-	}
-	// Validate if model could fit.
-	// On Windows, llamacpp can use up to half of system RAM as shared GPU memory
-	// if it runs out of dedicated VRAM.
-	totalVRAM := l.totalMemory.VRAM
-	if runtime.GOOS == "windows" {
-		totalVRAM += l.totalMemory.RAM / 2
-	}
-	if memory.RAM > l.totalMemory.RAM || memory.VRAM > totalVRAM {
-		return nil, errModelTooBig
-	}
+
+	l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode)
 
 	// Acquire the loader lock and defer its release.
 	if !l.lock(ctx) {
@@ -521,14 +452,6 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	// Loop until we can satisfy the request or an error occurs.
 	for {
 		slot := -1
-		availableVRAM := l.availableMemory.VRAM
-		if runtime.GOOS == "windows" {
-			sharedRAM := l.totalMemory.RAM / 2
-			if l.availableMemory.RAM < sharedRAM {
-				sharedRAM = l.availableMemory.RAM
-			}
-			availableVRAM += sharedRAM
-		}
 
 		// If loads are disabled, then there's nothing we can do.
 		if !l.loadsEnabled {
@@ -555,25 +478,20 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 			}
 		}
 
-		// If there's not sufficient memory or all slots are full, then try
-		// evicting unused runners.
-		if memory.RAM > l.availableMemory.RAM || memory.VRAM > availableVRAM || len(l.runners) == len(l.slots) {
-			l.log.Infof("Evicting to make room: need %s RAM, %s VRAM; have %s RAM, %s VRAM available; %d/%d slots used",
-				formatMemorySize(memory.RAM), formatMemorySize(memory.VRAM),
-				formatMemorySize(l.availableMemory.RAM),
-				formatMemorySize(availableVRAM),
+		// If all slots are full, try evicting unused runners.
+		if len(l.runners) == len(l.slots) {
+			l.log.Infof("Evicting to make room: %d/%d slots used",
 				len(l.runners), len(l.slots))
 			runnerCountAtLoopStart := len(l.runners)
 			remainingRunners := l.evict(false)
-			// Restart the loop if eviction happened to recompute availableVRAM
-			// and re-evaluate all conditions with the updated state.
+			// Restart the loop if eviction happened
 			if remainingRunners < runnerCountAtLoopStart {
 				continue
 			}
 		}
 
-		// If there's sufficient memory and a free slot, then find the slot.
-		if memory.RAM <= l.availableMemory.RAM && memory.VRAM <= availableVRAM && len(l.runners) < len(l.slots) {
+		// If there's a free slot, then find the slot.
+		if len(l.runners) < len(l.slots) {
 			for s, runner := range l.slots {
 				if runner == nil {
 					slot = s
@@ -583,18 +501,13 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 		}
 
 		if slot < 0 {
-			l.log.Debugf("Cannot load model yet: need %s RAM, %s VRAM; have %s RAM, %s VRAM available; %d/%d slots used",
-				formatMemorySize(memory.RAM), formatMemorySize(memory.VRAM),
-				formatMemorySize(l.availableMemory.RAM),
-				formatMemorySize(availableVRAM),
+			l.log.Debugf("Cannot load model yet: %d/%d slots used",
 				len(l.runners), len(l.slots))
 		}
 
 		// If we've identified a slot, then we're ready to start a runner.
 		if slot >= 0 {
-			// runnerConfig was already retrieved earlier (lines 401-405), no need to look it up again
 			// Create the runner.
-			l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode)
 			runner, err := run(l.log, backend, modelID, modelRef, mode, slot, runnerConfig, l.openAIRecorder)
 			if err != nil {
 				l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
@@ -618,13 +531,9 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 		}
 
 		// Perform registration and return the runner.
-		l.availableMemory.RAM -= memory.RAM
-		l.availableMemory.VRAM -= memory.VRAM
 		l.runners[makeRunnerKey(backendName, modelID, draftModelID, mode)] = runnerInfo{slot, modelRef}
 		l.slots[slot] = runner
 		l.references[slot] = 1
-		l.allocations[slot].RAM = memory.RAM
-		l.allocations[slot].VRAM = memory.VRAM
 		return runner, nil
 	}
diff --git a/pkg/inference/scheduling/loader_test.go b/pkg/inference/scheduling/loader_test.go
index 9c60234ac..7ac5841b2 100644
--- a/pkg/inference/scheduling/loader_test.go
+++ b/pkg/inference/scheduling/loader_test.go
@@ -54,19 +54,6 @@ func (b *fastFailBackend) Run(ctx context.Context, socket, model string, modelRe
 	return errors.New("boom")
 }
 
-// mockSystemMemoryInfo implements memory.SystemMemoryInfo for testing
-type mockSystemMemoryInfo struct {
-	totalMemory inference.RequiredMemory
-}
-
-func (m *mockSystemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) (bool, error) {
-	return req.RAM <= m.totalMemory.RAM && req.VRAM <= m.totalMemory.VRAM, nil
-}
-
-func (m *mockSystemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
-	return m.totalMemory
-}
-
 // createTestLogger creates a logger for testing
 func createTestLogger() *logrus.Entry {
 	log := logrus.New()
@@ -138,136 +125,6 @@ func createAliveTerminableMockRunner(log *logrus.Entry, backend inference.Backen
 	}
 }
 
-// TestFormatMemorySize tests the formatMemorySize helper function
-func TestFormatMemorySize(t *testing.T) {
-	tests := []struct {
-		name     string
-		bytes    uint64
-		expected string
-	}{
-		{
-			name:     "sentinel value 0 is unknown",
-			bytes:    0,
-			expected: "unknown",
-		},
-		{
-			name:     "sentinel value 1 is unknown",
-			bytes:    1,
-			expected: "unknown",
-		},
-		{
-			name:     "2 bytes is still unknown (edge case)",
-			bytes:    2,
-			expected: "0 MB",
-		},
-		{
-			name:     "1 MB",
-			bytes:    1024 * 1024,
-			expected: "1 MB",
-		},
-		{
-			name:     "512 MB",
-			bytes:    512 * 1024 * 1024,
-			expected: "512 MB",
-		},
-		{
-			name:     "1 GB",
-			bytes:    1024 * 1024 * 1024,
-			expected: "1024 MB",
-		},
-		{
-			name:     "8 GB",
-			bytes:    8 * 1024 * 1024 * 1024,
-			expected: "8192 MB",
-		},
-		{
-			name:     "fractional MB rounds down",
-			bytes:    1024*1024 + 512*1024, // 1.5 MB
-			expected: "1 MB",
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			result := formatMemorySize(tt.bytes)
-			if result != tt.expected {
-				t.Errorf("formatMemorySize(%d) = %q, want %q", tt.bytes, result, tt.expected)
-			}
-		})
-	}
-}
-
-// TestTotalMemoryWithUnknownVRAM tests that unknown VRAM (sentinel value 1) is handled correctly
-func TestTotalMemoryWithUnknownVRAM(t *testing.T) {
-	sysMemInfo := &mockSystemMemoryInfo{
-		totalMemory: inference.RequiredMemory{
-			RAM:  16 * 1024 * 1024 * 1024, // 16 GB
-			VRAM: 1,                       // unknown (sentinel)
-		},
-	}
-
-	totalMem := sysMemInfo.GetTotalMemory()
-	if totalMem.VRAM != 1 {
-		t.Errorf("Expected VRAM to be 1 (unknown sentinel), got %d", totalMem.VRAM)
-	}
-
-	vramStr := formatMemorySize(totalMem.VRAM)
-	if vramStr != "unknown" {
-		t.Errorf("Expected VRAM to format as 'unknown', got %q", vramStr)
-	}
-
-	ramStr := formatMemorySize(totalMem.RAM)
-	if ramStr == "unknown" {
-		t.Errorf("Expected RAM to format as numeric value, got %q", ramStr)
-	}
-}
-
-// TestMemoryCalculation tests memory requirement calculations
-func TestMemoryCalculation(t *testing.T) {
-	sysMemInfo := &mockSystemMemoryInfo{
-		totalMemory: inference.RequiredMemory{
-			RAM:  2 * 1024 * 1024 * 1024, // 2 GB
-			VRAM: 4 * 1024 * 1024 * 1024, // 4 GB
-		},
-	}
-
-	totalMem := sysMemInfo.GetTotalMemory()
-	if totalMem.RAM != 2*1024*1024*1024 {
-		t.Errorf("Expected RAM to be 2 GB, got %d", totalMem.RAM)
-	}
-	if totalMem.VRAM != 4*1024*1024*1024 {
-		t.Errorf("Expected VRAM to be 4 GB, got %d", totalMem.VRAM)
-	}
-
-	// Test sufficient memory check
-	required := inference.RequiredMemory{
-		RAM:  1 * 1024 * 1024 * 1024, // 1 GB
-		VRAM: 2 * 1024 * 1024 * 1024, // 2 GB
-	}
-
-	sufficient, err := sysMemInfo.HaveSufficientMemory(required)
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
-	if !sufficient {
-		t.Error("Expected sufficient memory for 1GB RAM / 2GB VRAM on 2GB RAM / 4GB VRAM system")
-	}
-
-	// Test insufficient memory
-	tooMuch := inference.RequiredMemory{
-		RAM:  3 * 1024 * 1024 * 1024, // 3 GB (more than available)
-		VRAM: 2 * 1024 * 1024 * 1024, // 2 GB
-	}
-
-	sufficient, err = sysMemInfo.HaveSufficientMemory(tooMuch)
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
-	if sufficient {
-		t.Error("Expected insufficient memory for 3GB RAM on 2GB RAM system")
-	}
-}
-
 // TestMakeRunnerKey tests that runner keys are created correctly
 func TestMakeRunnerKey(t *testing.T) {
 	tests := []struct {
@@ -371,17 +228,9 @@ func TestDefunctRunnerEvictionTriggersRetry(t *testing.T) {
 		},
 	}}
 
-	// Create system memory info with exactly 1GB RAM and 1GB VRAM (only enough for one model)
-	sysMemInfo := &mockSystemMemoryInfo{
-		totalMemory: inference.RequiredMemory{
-			RAM:  1 * GB,
-			VRAM: 1 * GB,
-		},
-	}
-
 	// Create the loader with minimal dependencies (nil model manager is fine for this test)
 	backends := map[string]inference.Backend{"test-backend": backend}
-	loader := newLoader(log, backends, nil, nil, sysMemInfo)
+	loader := newLoader(log, backends, nil, nil)
 
 	// Enable loads directly under the lock (no background run loop needed)
 	if !loader.lock(context.Background()) {
@@ -397,7 +246,7 @@ func TestDefunctRunnerEvictionTriggersRetry(t *testing.T) {
 
 	defunctRunner := createDefunctMockRunner(log, backend)
 
-	// Register the defunct runner in slot 0, consuming all available memory
+	// Register the defunct runner in slot 0
 	slot := 0
 	loader.slots[slot] = defunctRunner
 	loader.runners[makeRunnerKey("test-backend", "model1", "", inference.BackendModeCompletion)] = runnerInfo{
@@ -405,9 +254,6 @@ func TestDefunctRunnerEvictionTriggersRetry(t *testing.T) {
 		modelRef: "model1:latest",
 	}
 	loader.references[slot] = 0 // Mark as unused (so it can be evicted)
-	loader.allocations[slot] = inference.RequiredMemory{RAM: 1 * GB, VRAM: 1 * GB}
-	loader.availableMemory.RAM = 0  // All RAM consumed by defunct runner
-	loader.availableMemory.VRAM = 0 // All VRAM consumed by defunct runner
 	loader.timestamps[slot] = time.Now()
 	loader.unlock()
 
@@ -439,16 +285,8 @@ func TestUnusedRunnerEvictionTriggersRetry(t *testing.T) {
 		},
 	}}
 
-	// System has exactly enough memory for one runner
-	sysMemInfo := &mockSystemMemoryInfo{
-		totalMemory: inference.RequiredMemory{
-			RAM:  1 * GB,
-			VRAM: 1 * GB,
-		},
-	}
-
 	backends := map[string]inference.Backend{"test-backend": backend}
-	loader := newLoader(log, backends, nil, nil, sysMemInfo)
+	loader := newLoader(log, backends, nil, nil)
 
 	// Enable loads directly
 	if !loader.lock(context.Background()) {
@@ -470,9 +308,6 @@ func TestUnusedRunnerEvictionTriggersRetry(t *testing.T) {
 		modelRef: "modelX:latest",
 	}
 	loader.references[slot] = 0 // unused
-	loader.allocations[slot] = inference.RequiredMemory{RAM: 1 * GB, VRAM: 1 * GB}
-	loader.availableMemory.RAM = 0
-	loader.availableMemory.VRAM = 0
 	loader.timestamps[slot] = time.Now()
 	loader.unlock()
diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go
index d75abca5f..1c29defb8 100644
--- a/pkg/inference/scheduling/scheduler.go
+++ b/pkg/inference/scheduling/scheduler.go
@@ -10,7 +10,6 @@ import (
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/vllm"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/internal/utils"
 	"github.com/docker/model-runner/pkg/logging"
@@ -47,7 +46,6 @@ func NewScheduler(
 	modelManager *models.Manager,
 	httpClient *http.Client,
 	tracker *metrics.Tracker,
-	sysMemInfo memory.SystemMemoryInfo,
 ) *Scheduler {
 	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"),
 		modelManager)
@@ -58,7 +56,7 @@ func NewScheduler(
 		defaultBackend: defaultBackend,
 		modelManager:   modelManager,
 		installer:      newInstaller(log, backends, httpClient),
-		loader:         newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
+		loader:         newLoader(log, backends, modelManager, openAIRecorder),
 		tracker:        tracker,
 		openAIRecorder: openAIRecorder,
 	}
diff --git a/pkg/inference/scheduling/scheduler_test.go b/pkg/inference/scheduling/scheduler_test.go
index be6da393b..7e8e3fd3f 100644
--- a/pkg/inference/scheduling/scheduler_test.go
+++ b/pkg/inference/scheduling/scheduler_test.go
@@ -6,20 +6,9 @@ import (
 	"net/http/httptest"
 	"testing"
 
-	"github.com/docker/model-runner/pkg/inference"
 	"github.com/sirupsen/logrus"
 )
 
-type systemMemoryInfo struct{}
-
-func (i systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) (bool, error) {
-	return true, nil
-}
-
-func (i systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
-	return inference.RequiredMemory{}
-}
-
 func TestCors(t *testing.T) {
 	// Verify that preflight requests work against non-existing handlers or
 	// method-specific handlers that do not support OPTIONS
@@ -44,7 +33,7 @@ func TestCors(t *testing.T) {
 		discard := logrus.New()
 		discard.SetOutput(io.Discard)
 		log := logrus.NewEntry(discard)
-		s := NewScheduler(log, nil, nil, nil, nil, nil, systemMemoryInfo{})
+		s := NewScheduler(log, nil, nil, nil, nil, nil)
 		httpHandler := NewHTTPHandler(s, nil, []string{"*"})
 		req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
 		req.Header.Set("Origin", "docker.com")