diff --git a/commands/root.go b/commands/root.go index 76ffe1cc..0c45f3d0 100644 --- a/commands/root.go +++ b/commands/root.go @@ -24,7 +24,7 @@ func NewRootCmd() *cobra.Command { fmt.Println("Failed to create Docker client:", err) os.Exit(1) } - desktopClient := desktop.New(dockerClient.HTTPClient()) + desktopClient := desktop.New(dockerClient.HTTPClient(), os.Getenv("DMR_HOST")) rootCmd.AddCommand( newVersionCmd(), newStatusCmd(desktopClient), diff --git a/commands/status_test.go b/commands/status_test.go index f92f3fd6..8714f19d 100644 --- a/commands/status_test.go +++ b/commands/status_test.go @@ -79,12 +79,12 @@ func TestStatus(t *testing.T) { t.Run(test.name, func(t *testing.T) { client := mockdesktop.NewMockDockerHttpClient(ctrl) - req, err := http.NewRequest(http.MethodGet, desktop.URL(inference.ModelsPrefix), nil) + req, err := http.NewRequest(http.MethodGet, desktop.URL(inference.ModelsPrefix, ""), nil) require.NoError(t, err) client.EXPECT().Do(req).Return(test.doResponse, test.doErr) if test.doResponse != nil && test.doResponse.StatusCode == http.StatusOK { - req, err = http.NewRequest(http.MethodGet, desktop.URL(inference.InferencePrefix+"/status"), nil) + req, err = http.NewRequest(http.MethodGet, desktop.URL(inference.InferencePrefix+"/status", ""), nil) require.NoError(t, err) client.EXPECT().Do(req).Return(&http.Response{Body: mockBody}, test.doErr) } @@ -97,7 +97,7 @@ func TestStatus(t *testing.T) { } defer func() { osExit = originalOsExit }() - cmd := newStatusCmd(desktop.New(client)) + cmd := newStatusCmd(desktop.New(client, "")) buf := new(bytes.Buffer) cmd.SetOut(buf) cmd.SetErr(buf) diff --git a/desktop/desktop.go b/desktop/desktop.go index 84bc09f2..d7677764 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -34,6 +34,7 @@ func init() { type Client struct { dockerClient DockerHttpClient + dmrHost string } //go:generate mockgen -source=desktop.go -destination=../mocks/mock_desktop.go -package=mocks DockerHttpClient @@ -41,8 +42,11 @@ type DockerHttpClient interface { Do(req *http.Request) (*http.Response, error) } -func New(dockerClient DockerHttpClient) *Client { - return &Client{dockerClient} +func New(dockerClient DockerHttpClient, dmrHost string) *Client { + if dmrHost != "" { + dockerClient = http.DefaultClient + } + return &Client{dockerClient, dmrHost} } type Status struct { @@ -440,13 +444,16 @@ func (c *Client) Remove(models []string, force bool) (string, error) { return modelRemoved, nil } -func URL(path string) string { +func URL(path string, dmrHost string) string { + if dmrHost != "" { + return fmt.Sprintf("%s%s", dmrHost, path) + } return fmt.Sprintf("http://localhost" + inference.ExperimentalEndpointsPrefix + path) } // doRequest is a helper function that performs HTTP requests and handles 503 responses func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { - req, err := http.NewRequest(method, URL(path), body) + req, err := http.NewRequest(method, URL(path, c.dmrHost), body) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go index d95fd0f3..db48cca8 100644 --- a/desktop/desktop_test.go +++ b/desktop/desktop_test.go @@ -23,7 +23,7 @@ func TestPullHuggingFaceModel(t *testing.T) { expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { var reqBody models.ModelCreateRequest @@ -49,7 +49,7 @@ func TestChatHuggingFaceModel(t *testing.T) { prompt := "Hello" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { var reqBody OpenAIChatRequest @@ -74,7 +74,7 @@ func TestInspectHuggingFaceModel(t *testing.T) { expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { assert.Contains(t, req.URL.Path, expectedLowercase) @@ -106,7 +106,7 @@ func TestNonHuggingFaceModel(t *testing.T) { // Test case for a non-Hugging Face model (should not be converted to lowercase) modelName := "docker.io/library/llama2" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { var reqBody models.ModelCreateRequest @@ -131,7 +131,7 @@ func TestPushHuggingFaceModel(t *testing.T) { expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { assert.Contains(t, req.URL.Path, expectedLowercase) @@ -153,7 +153,7 @@ func TestRemoveHuggingFaceModel(t *testing.T) { expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { assert.Contains(t, req.URL.Path, expectedLowercase) @@ -177,7 +177,7 @@ func TestTagHuggingFaceModel(t *testing.T) { targetTag := "latest" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { assert.Contains(t, req.URL.Path, expectedLowercase) @@ -199,7 +199,7 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) { expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - client := New(mockClient) + client := New(mockClient, "") mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { assert.Contains(t, req.URL.Path, expectedLowercase)