diff --git a/desktop/desktop.go b/desktop/desktop.go index 6891ec97..84bc09f2 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -5,16 +5,17 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/docker/pinata/common/pkg/inference" - "github.com/docker/pinata/common/pkg/inference/models" - "github.com/docker/pinata/common/pkg/paths" - "github.com/pkg/errors" - "go.opentelemetry.io/otel" "html" "io" "net/http" "strconv" "strings" + + "github.com/docker/pinata/common/pkg/inference" + "github.com/docker/pinata/common/pkg/inference/models" + "github.com/docker/pinata/common/pkg/paths" + "github.com/pkg/errors" + "go.opentelemetry.io/otel" ) var ( @@ -50,6 +51,14 @@ type Status struct { Error error `json:"error"` } +// normalizeHuggingFaceModelName lowercases a model reference whose "hf.co/" registry prefix matches in any letter case (e.g. "HF.co/Org/Repo"); every other reference is returned unchanged. +func normalizeHuggingFaceModelName(model string) string { + if lower := strings.ToLower(model); strings.HasPrefix(lower, "hf.co/") { + return lower + } + return model +} + func (c *Client) Status() Status { // TODO: Query "/". 
resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) @@ -92,6 +101,7 @@ func (c *Client) Status() Status { } func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { + model = normalizeHuggingFaceModelName(model) jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { return "", false, fmt.Errorf("error marshaling request: %w", err) @@ -147,6 +157,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) } func (c *Client) Push(model string, progress func(string)) (string, bool, error) { + model = normalizeHuggingFaceModelName(model) pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( http.MethodPost, @@ -225,6 +236,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) { } func (c *Client) Inspect(model string) (Model, error) { + model = normalizeHuggingFaceModelName(model) if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -248,6 +260,7 @@ func (c *Client) Inspect(model string) (Model, error) { } func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { + model = normalizeHuggingFaceModelName(model) modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -310,6 +323,7 @@ func (c *Client) fullModelID(id string) (string, error) { } func (c *Client) Chat(model, prompt string) error { + model = normalizeHuggingFaceModelName(model) if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. 
if expanded, err := c.fullModelID(model); err == nil { @@ -387,6 +401,7 @@ func (c *Client) Chat(model, prompt string) error { func (c *Client) Remove(models []string, force bool) (string, error) { modelRemoved := "" for _, model := range models { + model = normalizeHuggingFaceModelName(model) // Check if not a model ID passed as parameter. if !strings.Contains(model, "/") { if expanded, err := c.fullModelID(model); err == nil { @@ -460,6 +475,7 @@ func (c *Client) handleQueryError(err error, path string) error { } func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { + source = normalizeHuggingFaceModelName(source) // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { // Do an extra API call to check if the model parameter might be a model ID diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go new file mode 100644 index 00000000..d95fd0f3 --- /dev/null +++ b/desktop/desktop_test.go @@ -0,0 +1,219 @@ +package desktop + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "testing" + + mockdesktop "github.com/docker/model-cli/mocks" + "github.com/docker/pinata/common/pkg/inference/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestPullHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case for pulling a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + var reqBody models.ModelCreateRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, expectedLowercase, reqBody.From) + }).Return(&http.Response{ + StatusCode: 
http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), + }, nil) + + _, _, err := client.Pull(modelName, func(s string) {}) + assert.NoError(t, err) +} + +func TestChatHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case: Chat given a mixed-case Hugging Face reference must send the lowercased name in the request body's model field + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + prompt := "Hello" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + var reqBody OpenAIChatRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, expectedLowercase, reqBody.Model) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), + }, nil) + + err := client.Chat(modelName, prompt) + assert.NoError(t, err) +} + +func TestInspectHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case: Inspect given a mixed-case Hugging Face reference must request a URL path containing the lowercased name + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id": "sha256:123456789012", + "tags": ["` + expectedLowercase + `"], + "created": 1234567890, + "config": { + "format": "gguf", + "quantization": "Q4_K_M", + "parameters": "1B", + "architecture": "llama", + "size": "1.2GB" + } + }`)), + }, nil) + + model, err := 
client.Inspect(modelName) + assert.NoError(t, err) + assert.Equal(t, expectedLowercase, model.Tags[0]) +} + +func TestNonHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case: a reference outside hf.co must reach the server unchanged; the mixed-case tag makes any accidental lowercasing fail the assertion + modelName := "docker.io/library/llama2:Q4_K_M" + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + var reqBody models.ModelCreateRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, modelName, reqBody.From) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), + }, nil) + + _, _, err := client.Pull(modelName, func(s string) {}) + assert.NoError(t, err) +} + +func TestPushHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case for pushing a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), + }, nil) + + _, _, err := client.Push(modelName, func(s string) {}) + assert.NoError(t, err) +} + +func TestRemoveHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case for removing a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + mockClient 
:= mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")), + }, nil) + + _, err := client.Remove([]string{modelName}, false) + assert.NoError(t, err) +} + +func TestTagHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case: Tag given a mixed-case Hugging Face source must request a URL path containing the lowercased name + sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + targetRepo := "myrepo" + targetTag := "latest" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusCreated, + Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), + }, nil) + + _, err := client.Tag(sourceModel, targetRepo, targetTag) + assert.NoError(t, err) +} + +func TestInspectOpenAIHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case: InspectOpenAI given a mixed-case Hugging Face reference must request a URL path containing the lowercased name + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id": "` + expectedLowercase + `", + "object": "model", + "created": 1234567890, + "owned_by": "organization" + }`)), + }, nil) + + model, err := 
client.InspectOpenAI(modelName) + assert.NoError(t, err) + assert.Equal(t, expectedLowercase, model.ID) +}