Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 0 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ WORKDIR /app
# Copy go mod/sum first for better caching
COPY --link go.mod go.sum ./

# Copy pkg/go-containerregistry for the replace directive in go.mod
COPY --link pkg/go-containerregistry ./pkg/go-containerregistry

# Download dependencies (with cache mounts)
RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \
Expand Down
40 changes: 30 additions & 10 deletions cmd/cli/commands/configure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ func TestConfigureCmdHfOverridesFlag(t *testing.T) {
hfOverridesFlag := cmd.Flags().Lookup("hf_overrides")
if hfOverridesFlag == nil {
t.Fatal("--hf_overrides flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := hfOverridesFlag.DefValue

// Verify the default value is empty
if hfOverridesFlag.DefValue != "" {
t.Errorf("Expected default hf_overrides value to be empty, got '%s'", hfOverridesFlag.DefValue)
if defValue != "" {
t.Errorf("Expected default hf_overrides value to be empty, got '%s'", defValue)
}

// Verify the flag type
Expand All @@ -33,11 +37,15 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
contextSizeFlag := cmd.Flags().Lookup("context-size")
if contextSizeFlag == nil {
t.Fatal("--context-size flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := contextSizeFlag.DefValue

// Verify the default value is empty (nil pointer)
if contextSizeFlag.DefValue != "" {
t.Errorf("Expected default context-size value to be '' (nil), got '%s'", contextSizeFlag.DefValue)
if defValue != "" {
t.Errorf("Expected default context-size value to be '' (nil), got '%s'", defValue)
}

// Test setting the flag value
Expand Down Expand Up @@ -83,11 +91,15 @@ func TestConfigureCmdModeFlag(t *testing.T) {
modeFlag := cmd.Flags().Lookup("mode")
if modeFlag == nil {
t.Fatal("--mode flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := modeFlag.DefValue

// Verify the default value is empty
if modeFlag.DefValue != "" {
t.Errorf("Expected default mode value to be empty, got '%s'", modeFlag.DefValue)
if defValue != "" {
t.Errorf("Expected default mode value to be empty, got '%s'", defValue)
}

// Verify the flag type
Expand All @@ -104,11 +116,15 @@ func TestConfigureCmdThinkFlag(t *testing.T) {
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag == nil {
t.Fatal("--think flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := thinkFlag.DefValue

// Verify the default value is empty
if thinkFlag.DefValue != "" {
t.Errorf("Expected default think value to be empty (nil), got '%s'", thinkFlag.DefValue)
if defValue != "" {
t.Errorf("Expected default think value to be empty (nil), got '%s'", defValue)
}

// Verify the flag type
Expand Down Expand Up @@ -136,11 +152,15 @@ func TestConfigureCmdGPUMemoryUtilizationFlag(t *testing.T) {
gpuMemFlag := cmd.Flags().Lookup("gpu-memory-utilization")
if gpuMemFlag == nil {
t.Fatal("--gpu-memory-utilization flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := gpuMemFlag.DefValue

// Verify the default value is empty (nil pointer)
if gpuMemFlag.DefValue != "" {
t.Errorf("Expected default gpu-memory-utilization value to be '' (nil), got '%s'", gpuMemFlag.DefValue)
if defValue != "" {
t.Errorf("Expected default gpu-memory-utilization value to be '' (nil), got '%s'", defValue)
}

// Verify the flag type
Expand Down
16 changes: 12 additions & 4 deletions cmd/cli/commands/install-runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ func TestInstallRunnerHostFlag(t *testing.T) {
hostFlag := cmd.Flags().Lookup("host")
if hostFlag == nil {
t.Fatal("--host flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := hostFlag.DefValue

// Verify the default value
if hostFlag.DefValue != "127.0.0.1" {
t.Errorf("Expected default host value to be '127.0.0.1', got '%s'", hostFlag.DefValue)
if defValue != "127.0.0.1" {
t.Errorf("Expected default host value to be '127.0.0.1', got '%s'", defValue)
}

// Verify the flag type
Expand Down Expand Up @@ -77,11 +81,15 @@ func TestInstallRunnerBackendFlag(t *testing.T) {
backendFlag := cmd.Flags().Lookup("backend")
if backendFlag == nil {
t.Fatal("--backend flag not found")
return // unreachable but satisfies staticcheck SA5011
}

// Get values to avoid potential nil dereference flagged by linter
defValue := backendFlag.DefValue

// Verify the default value
if backendFlag.DefValue != "" {
t.Errorf("Expected default backend value to be empty, got '%s'", backendFlag.DefValue)
if defValue != "" {
t.Errorf("Expected default backend value to be empty, got '%s'", defValue)
}

// Verify the flag type
Expand Down
141 changes: 132 additions & 9 deletions cmd/cli/commands/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/pkg/types"
"github.com/docker/model-runner/pkg/distribution/builder"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
Expand Down Expand Up @@ -110,6 +111,11 @@ func generateReferenceTestCases(info modelInfo) []referenceTestCase {
func setupTestEnv(t *testing.T) *testEnv {
ctx := context.Background()

// Set environment variables for the test process to match the DMR container.
// This ensures CLI functions use the same default registry when parsing references.
t.Setenv("DEFAULT_REGISTRY", "registry.local:5000")
t.Setenv("INSECURE_REGISTRY", "true")

// Create a custom network for container communication
net, err := network.New(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -149,16 +155,26 @@ func ociRegistry(t *testing.T, ctx context.Context, net *testcontainers.DockerNe
return registryURL
}

func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork) string {
// dmrConfig holds configuration options for Docker Model Runner container.
type dmrConfig struct {
envVars map[string]string // Optional environment variables to set
logMsg string // Custom log message (defaults to "Starting DMR container...")
}

// startDockerModelRunner starts a DMR container with the given configuration.
// If config.envVars is nil or empty, no extra environment variables are set.
func startDockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork, config dmrConfig) string {
containerCustomizerOpts := []testcontainers.ContainerCustomizer{
testcontainers.WithExposedPorts("12434/tcp"),
testcontainers.WithWaitStrategy(wait.ForHTTP("/engines/status").WithPort("12434/tcp").WithStartupTimeout(10 * time.Second)),
testcontainers.WithEnv(map[string]string{
"DEFAULT_REGISTRY": "registry.local:5000",
"INSECURE_REGISTRY": "true",
}),
network.WithNetwork([]string{"dmr"}, net),
}

// Add environment variables if provided
if len(config.envVars) > 0 {
containerCustomizerOpts = append(containerCustomizerOpts, testcontainers.WithEnv(config.envVars))
}

if os.Getenv("BUILD_DMR") == "1" {
t.Log("Building DMR container...")
out, err := exec.CommandContext(ctx, "make", "-C", "../../..", "docker-build").CombinedOutput()
Expand All @@ -169,7 +185,13 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do
// Always pull the image if it's not build locally.
containerCustomizerOpts = append(containerCustomizerOpts, testcontainers.WithAlwaysPull())
}
t.Log("Starting DMR container...")

logMsg := config.logMsg
if logMsg == "" {
logMsg = "Starting DMR container..."
}
t.Log(logMsg)

ctr, err := testcontainers.Run(
ctx, "docker/model-runner:latest",
containerCustomizerOpts...,
Expand All @@ -185,6 +207,17 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do
return dmrURL
}

// dockerModelRunner starts a DMR container configured for local registry tests.
// Sets DEFAULT_REGISTRY and INSECURE_REGISTRY environment variables.
func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.DockerNetwork) string {
return startDockerModelRunner(t, ctx, net, dmrConfig{
envVars: map[string]string{
"DEFAULT_REGISTRY": "registry.local:5000",
"INSECURE_REGISTRY": "true",
},
})
}

// removeModel removes a model from the local store
func removeModel(client *desktop.Client, modelID string, force bool) error {
_, err := client.Remove([]string{modelID}, force)
Expand Down Expand Up @@ -1037,7 +1070,7 @@ func TestIntegration_PackageModel(t *testing.T) {
model, err := env.client.Inspect(targetTag, false)
require.NoError(t, err, "Failed to inspect packaged model by tag: %s", targetTag)
require.NotEmpty(t, model.ID, "Model ID should not be empty")
require.Contains(t, model.Tags, targetTag, "Model should have the expected tag")
require.Contains(t, model.Tags, normalizeRef(t, targetTag), "Model should have the expected tag")

t.Logf("✓ Successfully packaged and tagged model: %s (ID: %s)", targetTag, model.ID[7:19])

Expand Down Expand Up @@ -1070,7 +1103,7 @@ func TestIntegration_PackageModel(t *testing.T) {
// Verify the model was loaded and tagged
model, err := env.client.Inspect(targetTag, false)
require.NoError(t, err, "Failed to inspect packaged model")
require.Contains(t, model.Tags, targetTag, "Model should have the expected tag")
require.Contains(t, model.Tags, normalizeRef(t, targetTag), "Model should have the expected tag")

t.Logf("✓ Successfully packaged model with context size: %s", targetTag)

Expand Down Expand Up @@ -1100,7 +1133,7 @@ func TestIntegration_PackageModel(t *testing.T) {
// Verify the model was loaded and tagged
model, err := env.client.Inspect(targetTag, false)
require.NoError(t, err, "Failed to inspect packaged model")
require.Contains(t, model.Tags, targetTag, "Model should have the expected tag")
require.Contains(t, model.Tags, normalizeRef(t, targetTag), "Model should have the expected tag")

t.Logf("✓ Successfully packaged model with custom org: %s", targetTag)

Expand All @@ -1118,3 +1151,93 @@ func TestIntegration_PackageModel(t *testing.T) {
func int32ptr(n int32) *int32 {
return &n
}

// setupDockerHubTestEnv creates a test environment for Docker Hub tests.
// Unlike setupTestEnv, this does NOT set DEFAULT_REGISTRY, so it uses
// the real Docker Hub (index.docker.io) as the default registry.
// This is used to test that pulling from Docker Hub works correctly.
func setupDockerHubTestEnv(t *testing.T) *testEnv {
ctx := context.Background()

// Create a custom network for container communication
net, err := network.New(ctx)
require.NoError(t, err)
testcontainers.CleanupNetwork(t, net)

// dockerModelRunnerForDockerHub starts a DMR container configured for Docker Hub tests.
// it uses the real Docker Hub as the default registry.
dmrURL := startDockerModelRunner(t, ctx, net, dmrConfig{
logMsg: "Starting DMR container for Docker Hub tests (no DEFAULT_REGISTRY)...",
})

modelRunnerCtx, err := desktop.NewContextForTest(dmrURL, nil, types.ModelRunnerEngineKindMoby)
require.NoError(t, err, "Failed to create model runner context")

client := desktop.New(modelRunnerCtx)
if !client.Status().Running {
t.Fatal("DMR is not running")
}

return &testEnv{
ctx: ctx,
client: client,
net: net,
}
}

// TestIntegration_PullFromDockerHub is a smoke test that pulls a real model
// from Docker Hub to verify that the OCI registry code works correctly
// with the real Docker Hub registry (index.docker.io -> registry-1.docker.io).
//
// This test catches regressions where the code doesn't properly handle
// Docker Hub's hostname remapping requirements.
func TestIntegration_PullFromDockerHub(t *testing.T) {
env := setupDockerHubTestEnv(t)

// Ensure no models exist initially
models, err := listModels(false, env.client, true, false, "")
require.NoError(t, err)
if len(models) != 0 {
t.Fatal("Expected no initial models, but found some")
}

// Pull a small model from Docker Hub
// ai/smollm2:135M-Q4_0 is a small model that's quick to download
modelRef := "ai/smollm2:135M-Q4_0"
t.Logf("Pulling model from Docker Hub: %s", modelRef)

err = pullModel(newPullCmd(), env.client, modelRef)
require.NoError(t, err, "Failed to pull model from Docker Hub: %s", modelRef)

// Verify the model was pulled
t.Log("Verifying model was pulled successfully")
models, err = listModels(false, env.client, true, false, "")
require.NoError(t, err)
require.NotEmpty(t, strings.TrimSpace(models), "Model should exist after pull from Docker Hub")

// Verify we can inspect the model
model, err := env.client.Inspect(modelRef, false)
require.NoError(t, err, "Failed to inspect model pulled from Docker Hub")
require.NotEmpty(t, model.ID, "Model ID should not be empty")

t.Logf("✓ Successfully pulled model from Docker Hub: %s (ID: %s)", modelRef, model.ID[7:19])

// Cleanup: remove the model
t.Logf("Cleaning up: removing model %s", model.ID[7:19])
err = removeModel(env.client, model.ID, true)
require.NoError(t, err, "Failed to remove model")

// Verify model was removed
models, err = listModels(false, env.client, true, false, "")
require.NoError(t, err)
require.Empty(t, strings.TrimSpace(models), "Model should be removed after cleanup")
}

// normalizeRef normalizes a reference to its fully qualified form.
// This is used in tests to compare against the stored tags which are always normalized.
func normalizeRef(t *testing.T, ref string) string {
t.Helper()
parsed, err := reference.ParseReference(ref, registry.GetDefaultRegistryOptions()...)
require.NoError(t, err, "Failed to parse reference: %s", ref)
return parsed.String()
}
8 changes: 4 additions & 4 deletions cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ import (
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/pkg/distribution/builder"
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/packaging"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/go-containerregistry/pkg/name"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -434,7 +434,7 @@ func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Clien
// modelRunnerTarget loads model to Docker Model Runner via models/load endpoint
type modelRunnerTarget struct {
client *desktop.Client
tag name.Tag
tag *reference.Tag
}

func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarget, error) {
Expand All @@ -443,7 +443,7 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge
}
if tag != "" {
var err error
target.tag, err = name.NewTag(tag)
target.tag, err = reference.NewTag(tag, registry.GetDefaultRegistryOptions()...)
if err != nil {
return nil, fmt.Errorf("invalid tag: %w", err)
}
Expand Down Expand Up @@ -477,7 +477,7 @@ func (t *modelRunnerTarget) Write(ctx context.Context, mdl types.ModelArtifact,
if err != nil {
return fmt.Errorf("get model ID: %w", err)
}
if t.tag.String() != "" {
if t.tag != nil {
if err := t.client.Tag(id, parseRepo(t.tag), t.tag.TagStr()); err != nil {
return fmt.Errorf("tag model: %w", err)
}
Expand Down
Loading
Loading