From 1a4488b8ffc802aee48715a35d5c6e207ddd68a6 Mon Sep 17 00:00:00 2001 From: ilopezluna Date: Tue, 29 Apr 2025 13:54:57 +0200 Subject: [PATCH 1/4] Add case conversion for Hugging Face model names --- desktop/desktop.go | 25 ++++- desktop/desktop_test.go | 204 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 5 deletions(-) create mode 100644 desktop/desktop_test.go diff --git a/desktop/desktop.go b/desktop/desktop.go index 6891ec97..b43b749c 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -5,16 +5,17 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/docker/pinata/common/pkg/inference" - "github.com/docker/pinata/common/pkg/inference/models" - "github.com/docker/pinata/common/pkg/paths" - "github.com/pkg/errors" - "go.opentelemetry.io/otel" "html" "io" "net/http" "strconv" "strings" + + "github.com/docker/pinata/common/pkg/inference" + "github.com/docker/pinata/common/pkg/inference/models" + "github.com/docker/pinata/common/pkg/paths" + "github.com/pkg/errors" + "go.opentelemetry.io/otel" ) var ( @@ -50,6 +51,14 @@ type Status struct { Error error `json:"error"` } +// normalizeModelName converts Hugging Face model names to lowercase +func normalizeModelName(model string) string { + if strings.HasPrefix(model, "hf.co/") { + return strings.ToLower(model) + } + return model +} + func (c *Client) Status() Status { // TODO: Query "/". resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) @@ -92,6 +101,7 @@ func (c *Client) Status() Status { } func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { + model = normalizeModelName(model) jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { return "", false, fmt.Errorf("error marshaling request: %w", err) @@ -147,6 +157,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) } func (c *Client) Push(model string, progress func(string)) (string, bool, error) { + model = normalizeModelName(model) pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( http.MethodPost, @@ -225,6 +236,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) { } func (c *Client) Inspect(model string) (Model, error) { + model = normalizeModelName(model) if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -310,6 +322,7 @@ func (c *Client) fullModelID(id string) (string, error) { } func (c *Client) Chat(model, prompt string) error { + model = normalizeModelName(model) if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. if expanded, err := c.fullModelID(model); err == nil { @@ -387,6 +400,7 @@ func (c *Client) Chat(model, prompt string) error { func (c *Client) Remove(models []string, force bool) (string, error) { modelRemoved := "" for _, model := range models { + model = normalizeModelName(model) // Check if not a model ID passed as parameter. if !strings.Contains(model, "/") { if expanded, err := c.fullModelID(model); err == nil { @@ -460,6 +474,7 @@ func (c *Client) handleQueryError(err error, path string) error { } func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { + source = normalizeModelName(source) // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { // Do an extra API call to check if the model parameter might be a model ID diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go new file mode 100644 index 00000000..8aea6c4a --- /dev/null +++ b/desktop/desktop_test.go @@ -0,0 +1,204 @@ +package desktop + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/docker/pinata/common/pkg/inference/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockHTTPClient struct { + doFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + return m.doFunc(req) +} + +func TestPullHuggingFaceModel(t *testing.T) { + // Test case for pulling a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is converted to lowercase in the request + var reqBody models.ModelCreateRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, expectedLowercase, reqBody.From) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), + }, nil + }, + }, + } + + _, _, err := client.Pull(modelName, func(s string) {}) + assert.NoError(t, err) +} + +func TestChatHuggingFaceModel(t *testing.T) { + // Test case for chatting with a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + prompt := "Hello" + + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is converted to lowercase in the request + var reqBody OpenAIChatRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, expectedLowercase, reqBody.Model) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), + }, nil + }, + }, + } + + err := client.Chat(modelName, prompt) + assert.NoError(t, err) +} + +func TestInspectHuggingFaceModel(t *testing.T) { + // Test case for inspecting a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is converted to lowercase in the request URL + assert.Contains(t, req.URL.Path, expectedLowercase) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id": "sha256:123456789012", + "tags": ["` + expectedLowercase + `"], + "created": 1234567890, + "config": { + "format": "gguf", + "quantization": "Q4_K_M", + "parameters": "1B", + "architecture": "llama", + "size": "1.2GB" + } + }`)), + }, nil + }, + }, + } + + model, err := client.Inspect(modelName) + assert.NoError(t, err) + assert.Equal(t, expectedLowercase, model.Tags[0]) +} + +func TestNonHuggingFaceModel(t *testing.T) { + // Test case for a non-Hugging Face model (should not be converted to lowercase) + modelName := "docker.io/library/llama2" + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is not converted to lowercase + var reqBody models.ModelCreateRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, modelName, reqBody.From) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), + }, nil + }, + }, + } + + _, _, err := client.Pull(modelName, func(s string) {}) + assert.NoError(t, err) +} + +func TestPushHuggingFaceModel(t *testing.T) { + // Test case for pushing a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is converted to lowercase in the request URL + assert.Contains(t, req.URL.Path, expectedLowercase) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), + }, nil + }, + }, + } + + _, _, err := client.Push(modelName, func(s string) {}) + assert.NoError(t, err) +} + +func TestRemoveHuggingFaceModel(t *testing.T) { + // Test case for removing a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is converted to lowercase in the request URL + assert.Contains(t, req.URL.Path, expectedLowercase) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")), + }, nil + }, + }, + } + + _, err := client.Remove([]string{modelName}, false) + assert.NoError(t, err) +} + +func TestTagHuggingFaceModel(t *testing.T) { + // Test case for tagging a Hugging Face model with mixed case + sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + targetRepo := "myrepo" + targetTag := "latest" + + client := &Client{ + dockerClient: &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + // Verify the model name is converted to lowercase in the request URL + assert.Contains(t, req.URL.Path, expectedLowercase) + + return &http.Response{ + StatusCode: http.StatusCreated, + Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), + }, nil + }, + }, + } + + _, err := client.Tag(sourceModel, targetRepo, targetTag) + assert.NoError(t, err) +} From 2ecd43dbad12b7ed1df7510f78001a114029f2b6 Mon Sep 17 00:00:00 2001 From: ilopezluna Date: Tue, 29 Apr 2025 15:40:58 +0200 Subject: [PATCH 2/4] Using existing desktop mock --- desktop/desktop_test.go | 221 +++++++++++++++++++--------------------- 1 file changed, 104 insertions(+), 117 deletions(-) diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go index 8aea6c4a..ac5bc355 100644 --- a/desktop/desktop_test.go +++ b/desktop/desktop_test.go @@ -7,101 +7,92 @@ import ( "net/http" "testing" + mockdesktop "github.com/docker/model-cli/mocks" "github.com/docker/pinata/common/pkg/inference/models" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) -type mockHTTPClient struct { - doFunc func(req *http.Request) (*http.Response, error) -} - -func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - return m.doFunc(req) -} - func TestPullHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for pulling a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is converted to lowercase in the request - var reqBody models.ModelCreateRequest - err := json.NewDecoder(req.Body).Decode(&reqBody) - require.NoError(t, err) - assert.Equal(t, expectedLowercase, reqBody.From) - - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), - }, nil - }, - }, - } + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + var reqBody models.ModelCreateRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, expectedLowercase, reqBody.From) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), + }, nil) _, _, err := client.Pull(modelName, func(s string) {}) assert.NoError(t, err) } func TestChatHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for chatting with a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" prompt := "Hello" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is converted to lowercase in the request - var reqBody OpenAIChatRequest - err := json.NewDecoder(req.Body).Decode(&reqBody) - require.NoError(t, err) - assert.Equal(t, expectedLowercase, reqBody.Model) - - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), - }, nil - }, - }, - } + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + var reqBody OpenAIChatRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, expectedLowercase, reqBody.Model) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), + }, nil) err := client.Chat(modelName, prompt) assert.NoError(t, err) } func TestInspectHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for inspecting a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is converted to lowercase in the request URL - assert.Contains(t, req.URL.Path, expectedLowercase) - - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{ - "id": "sha256:123456789012", - "tags": ["` + expectedLowercase + `"], - "created": 1234567890, - "config": { - "format": "gguf", - "quantization": "Q4_K_M", - "parameters": "1B", - "architecture": "llama", - "size": "1.2GB" - } - }`)), - }, nil - }, - }, - } + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id": "sha256:123456789012", + "tags": ["` + expectedLowercase + `"], + "created": 1234567890, + "config": { + "format": "gguf", + "quantization": "Q4_K_M", + "parameters": "1B", + "architecture": "llama", + "size": "1.2GB" + } + }`)), + }, nil) model, err := client.Inspect(modelName) assert.NoError(t, err) @@ -109,95 +100,91 @@ func TestInspectHuggingFaceModel(t *testing.T) { } func TestNonHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for a non-Hugging Face model (should not be converted to lowercase) modelName := "docker.io/library/llama2" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is not converted to lowercase - var reqBody models.ModelCreateRequest - err := json.NewDecoder(req.Body).Decode(&reqBody) - require.NoError(t, err) - assert.Equal(t, modelName, reqBody.From) - - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), - }, nil - }, - }, - } + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + var reqBody models.ModelCreateRequest + err := json.NewDecoder(req.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, modelName, reqBody.From) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), + }, nil) _, _, err := client.Pull(modelName, func(s string) {}) assert.NoError(t, err) } func TestPushHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for pushing a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is converted to lowercase in the request URL - assert.Contains(t, req.URL.Path, expectedLowercase) + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), - }, nil - }, - }, - } + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), + }, nil) _, _, err := client.Push(modelName, func(s string) {}) assert.NoError(t, err) } func TestRemoveHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for removing a Hugging Face model with mixed case modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is converted to lowercase in the request URL - assert.Contains(t, req.URL.Path, expectedLowercase) + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")), - }, nil - }, - }, - } + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")), + }, nil) _, err := client.Remove([]string{modelName}, false) assert.NoError(t, err) } func TestTagHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // Test case for tagging a Hugging Face model with mixed case sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" targetRepo := "myrepo" targetTag := "latest" - client := &Client{ - dockerClient: &mockHTTPClient{ - doFunc: func(req *http.Request) (*http.Response, error) { - // Verify the model name is converted to lowercase in the request URL - assert.Contains(t, req.URL.Path, expectedLowercase) - - return &http.Response{ - StatusCode: http.StatusCreated, - Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), - }, nil - }, - }, - } + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusCreated, + Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), + }, nil) _, err := client.Tag(sourceModel, targetRepo, targetTag) assert.NoError(t, err) From cf16bc21131c9fc07ccf643d097923b5adda7e81 Mon Sep 17 00:00:00 2001 From: ilopezluna Date: Tue, 29 Apr 2025 15:44:58 +0200 Subject: [PATCH 3/4] Add name normalization to InspectOpenAI --- desktop/desktop.go | 1 + desktop/desktop_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/desktop/desktop.go b/desktop/desktop.go index b43b749c..2d692178 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -260,6 +260,7 @@ func (c *Client) Inspect(model string) (Model, error) { } func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { + model = normalizeModelName(model) modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go index ac5bc355..d95fd0f3 100644 --- a/desktop/desktop_test.go +++ b/desktop/desktop_test.go @@ -189,3 +189,31 @@ func TestTagHuggingFaceModel(t *testing.T) { _, err := client.Tag(sourceModel, targetRepo, targetTag) assert.NoError(t, err) } + +func TestInspectOpenAIHuggingFaceModel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test case for inspecting a Hugging Face model with mixed case + modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(mockClient) + + mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { + assert.Contains(t, req.URL.Path, expectedLowercase) + }).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id": "` + expectedLowercase + `", + "object": "model", + "created": 1234567890, + "owned_by": "organization" + }`)), + }, nil) + + model, err := client.InspectOpenAI(modelName) + assert.NoError(t, err) + assert.Equal(t, expectedLowercase, model.ID) +} From 6785c7c5d7bb69424f35479f6eb889149d6cac59 Mon Sep 17 00:00:00 2001 From: ilopezluna Date: Tue, 29 Apr 2025 15:47:06 +0200 Subject: [PATCH 4/4] Rename to normalizeHuggingFaceModelName because normalization will only apply to HF repositories --- desktop/desktop.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/desktop/desktop.go b/desktop/desktop.go index 2d692178..84bc09f2 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -51,8 +51,8 @@ type Status struct { Error error `json:"error"` } -// normalizeModelName converts Hugging Face model names to lowercase -func normalizeModelName(model string) string { +// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase +func normalizeHuggingFaceModelName(model string) string { if strings.HasPrefix(model, "hf.co/") { return strings.ToLower(model) } @@ -101,7 +101,7 @@ func (c *Client) Status() Status { } func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { - model = normalizeModelName(model) + model = normalizeHuggingFaceModelName(model) jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { return "", false, fmt.Errorf("error marshaling request: %w", err) @@ -157,7 +157,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) } func (c *Client) Push(model string, progress func(string)) (string, bool, error) { - model = normalizeModelName(model) + model = normalizeHuggingFaceModelName(model) pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( http.MethodPost, @@ -236,7 +236,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) { } func (c *Client) Inspect(model string) (Model, error) { - model = normalizeModelName(model) + model = normalizeHuggingFaceModelName(model) if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -260,7 +260,7 @@ func (c *Client) Inspect(model string) (Model, error) { } func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { - model = normalizeModelName(model) + model = normalizeHuggingFaceModelName(model) modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -323,7 +323,7 @@ func (c *Client) fullModelID(id string) (string, error) { } func (c *Client) Chat(model, prompt string) error { - model = normalizeModelName(model) + model = normalizeHuggingFaceModelName(model) if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. if expanded, err := c.fullModelID(model); err == nil { @@ -401,7 +401,7 @@ func (c *Client) Chat(model, prompt string) error { func (c *Client) Remove(models []string, force bool) (string, error) { modelRemoved := "" for _, model := range models { - model = normalizeModelName(model) + model = normalizeHuggingFaceModelName(model) // Check if not a model ID passed as parameter. if !strings.Contains(model, "/") { if expanded, err := c.fullModelID(model); err == nil { @@ -475,7 +475,7 @@ func (c *Client) handleQueryError(err error, path string) error { } func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { - source = normalizeModelName(source) + source = normalizeHuggingFaceModelName(source) // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { // Do an extra API call to check if the model parameter might be a model ID