Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/docker/pinata/common/pkg/inference"
"github.com/docker/pinata/common/pkg/inference/models"
"github.com/docker/pinata/common/pkg/paths"
"github.com/pkg/errors"
"go.opentelemetry.io/otel"
"html"
"io"
"net/http"
"strconv"
"strings"

"github.com/docker/pinata/common/pkg/inference"
"github.com/docker/pinata/common/pkg/inference/models"
"github.com/docker/pinata/common/pkg/paths"
"github.com/pkg/errors"
"go.opentelemetry.io/otel"
)

var (
Expand Down Expand Up @@ -50,6 +51,14 @@ type Status struct {
Error error `json:"error"`
}

// normalizeHuggingFaceModelName lowercases Hugging Face model references so
// that case variants (e.g. "hf.co/Org/Model-GGUF") resolve to the same model.
// The "hf.co/" prefix is matched case-insensitively, so inputs like
// "HF.co/Org/Model" are normalized as well. References to other registries
// are returned unchanged, since they may be case-sensitive.
func normalizeHuggingFaceModelName(model string) string {
	const hfPrefix = "hf.co/"
	if len(model) >= len(hfPrefix) && strings.EqualFold(model[:len(hfPrefix)], hfPrefix) {
		return strings.ToLower(model)
	}
	return model
}

func (c *Client) Status() Status {
// TODO: Query "/".
resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil)
Expand Down Expand Up @@ -92,6 +101,7 @@ func (c *Client) Status() Status {
}

func (c *Client) Pull(model string, progress func(string)) (string, bool, error) {
model = normalizeHuggingFaceModelName(model)
jsonData, err := json.Marshal(models.ModelCreateRequest{From: model})
if err != nil {
return "", false, fmt.Errorf("error marshaling request: %w", err)
Expand Down Expand Up @@ -147,6 +157,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error)
}

func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
model = normalizeHuggingFaceModelName(model)
pushPath := inference.ModelsPrefix + "/" + model + "/push"
resp, err := c.doRequest(
http.MethodPost,
Expand Down Expand Up @@ -225,6 +236,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) {
}

func (c *Client) Inspect(model string) (Model, error) {
model = normalizeHuggingFaceModelName(model)
if model != "" {
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
Expand All @@ -248,6 +260,7 @@ func (c *Client) Inspect(model string) (Model, error) {
}

func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) {
model = normalizeHuggingFaceModelName(model)
modelsRoute := inference.InferencePrefix + "/v1/models"
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
Expand Down Expand Up @@ -310,6 +323,7 @@ func (c *Client) fullModelID(id string) (string, error) {
}

func (c *Client) Chat(model, prompt string) error {
model = normalizeHuggingFaceModelName(model)
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
if expanded, err := c.fullModelID(model); err == nil {
Expand Down Expand Up @@ -387,6 +401,7 @@ func (c *Client) Chat(model, prompt string) error {
func (c *Client) Remove(models []string, force bool) (string, error) {
modelRemoved := ""
for _, model := range models {
model = normalizeHuggingFaceModelName(model)
// Check if not a model ID passed as parameter.
if !strings.Contains(model, "/") {
if expanded, err := c.fullModelID(model); err == nil {
Expand Down Expand Up @@ -460,6 +475,7 @@ func (c *Client) handleQueryError(err error, path string) error {
}

func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) {
source = normalizeHuggingFaceModelName(source)
// Check if the source is a model ID, and expand it if necessary
if !strings.Contains(strings.Trim(source, "/"), "/") {
// Do an extra API call to check if the model parameter might be a model ID
Expand Down
219 changes: 219 additions & 0 deletions desktop/desktop_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package desktop

import (
"bytes"
"encoding/json"
"io"
"net/http"
"testing"

mockdesktop "github.com/docker/model-cli/mocks"
"github.com/docker/pinata/common/pkg/inference/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestPullHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// A mixed-case Hugging Face reference must reach the backend lowercased.
	const (
		modelName         = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		// Inspect the create request body to verify the normalized name.
		var body models.ModelCreateRequest
		require.NoError(t, json.NewDecoder(req.Body).Decode(&body))
		assert.Equal(t, expectedLowercase, body.From)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
	}, nil)

	_, _, err := c.Pull(modelName, func(string) {})
	assert.NoError(t, err)
}

func TestChatHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Chatting with a mixed-case Hugging Face model must send the lowercased name.
	const (
		modelName         = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
		prompt            = "Hello"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n"
	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		var body OpenAIChatRequest
		require.NoError(t, json.NewDecoder(req.Body).Decode(&body))
		assert.Equal(t, expectedLowercase, body.Model)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(stream)),
	}, nil)

	assert.NoError(t, c.Chat(modelName, prompt))
}

func TestInspectHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Inspecting a mixed-case Hugging Face model must hit the lowercased route.
	const (
		modelName         = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	respBody := `{
		"id": "sha256:123456789012",
		"tags": ["` + expectedLowercase + `"],
		"created": 1234567890,
		"config": {
			"format": "gguf",
			"quantization": "Q4_K_M",
			"parameters": "1B",
			"architecture": "llama",
			"size": "1.2GB"
		}
	}`
	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		assert.Contains(t, req.URL.Path, expectedLowercase)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(respBody)),
	}, nil)

	model, err := c.Inspect(modelName)
	assert.NoError(t, err)
	assert.Equal(t, expectedLowercase, model.Tags[0])
}

func TestNonHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// A non-Hugging Face reference must pass through with its case preserved.
	const modelName = "docker.io/library/llama2"

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		var body models.ModelCreateRequest
		require.NoError(t, json.NewDecoder(req.Body).Decode(&body))
		assert.Equal(t, modelName, body.From)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
	}, nil)

	_, _, err := c.Pull(modelName, func(string) {})
	assert.NoError(t, err)
}

func TestPushHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Pushing a mixed-case Hugging Face model must hit the lowercased route.
	const (
		modelName         = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		assert.Contains(t, req.URL.Path, expectedLowercase)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)),
	}, nil)

	_, _, err := c.Push(modelName, func(string) {})
	assert.NoError(t, err)
}

func TestRemoveHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Removing a mixed-case Hugging Face model must hit the lowercased route.
	const (
		modelName         = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		assert.Contains(t, req.URL.Path, expectedLowercase)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString("Model removed successfully")),
	}, nil)

	_, err := c.Remove([]string{modelName}, false)
	assert.NoError(t, err)
}

func TestTagHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Tagging a mixed-case Hugging Face source must hit the lowercased route.
	const (
		sourceModel       = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
		targetRepo        = "myrepo"
		targetTag         = "latest"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		assert.Contains(t, req.URL.Path, expectedLowercase)
	}).Return(&http.Response{
		StatusCode: http.StatusCreated,
		Body:       io.NopCloser(bytes.NewBufferString("Tag created successfully")),
	}, nil)

	_, err := c.Tag(sourceModel, targetRepo, targetTag)
	assert.NoError(t, err)
}

func TestInspectOpenAIHuggingFaceModel(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// The OpenAI-style inspect endpoint must also see the lowercased name.
	const (
		modelName         = "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
		expectedLowercase = "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
	)

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	c := New(httpMock)

	respBody := `{
		"id": "` + expectedLowercase + `",
		"object": "model",
		"created": 1234567890,
		"owned_by": "organization"
	}`
	httpMock.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		assert.Contains(t, req.URL.Path, expectedLowercase)
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(respBody)),
	}, nil)

	model, err := c.InspectOpenAI(modelName)
	assert.NoError(t, err)
	assert.Equal(t, expectedLowercase, model.ID)
}