Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/docker/pinata/common/pkg/inference"
"github.com/docker/pinata/common/pkg/inference/models"
"github.com/docker/pinata/common/pkg/paths"
"github.com/pkg/errors"
"go.opentelemetry.io/otel"
"html"
"io"
"net/http"
"strconv"
"strings"

"github.com/docker/pinata/common/pkg/inference"
"github.com/docker/pinata/common/pkg/inference/models"
"github.com/docker/pinata/common/pkg/paths"
"github.com/pkg/errors"
"go.opentelemetry.io/otel"
)

var (
Expand Down Expand Up @@ -50,6 +51,14 @@ type Status struct {
Error error `json:"error"`
}

// normalizeModelName converts Hugging Face model names to lowercase
func normalizeModelName(model string) string {
if strings.HasPrefix(model, "hf.co/") {
return strings.ToLower(model)
}
return model
}

func (c *Client) Status() Status {
// TODO: Query "/".
resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil)
Expand Down Expand Up @@ -92,6 +101,7 @@ func (c *Client) Status() Status {
}

func (c *Client) Pull(model string, progress func(string)) (string, bool, error) {
model = normalizeModelName(model)
jsonData, err := json.Marshal(models.ModelCreateRequest{From: model})
if err != nil {
return "", false, fmt.Errorf("error marshaling request: %w", err)
Expand Down Expand Up @@ -147,6 +157,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error)
}

func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
model = normalizeModelName(model)
pushPath := inference.ModelsPrefix + "/" + model + "/push"
resp, err := c.doRequest(
http.MethodPost,
Expand Down Expand Up @@ -225,6 +236,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) {
}

func (c *Client) Inspect(model string) (Model, error) {
model = normalizeModelName(model)
if model != "" {
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
Expand Down Expand Up @@ -310,6 +322,7 @@ func (c *Client) fullModelID(id string) (string, error) {
}

func (c *Client) Chat(model, prompt string) error {
model = normalizeModelName(model)
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
if expanded, err := c.fullModelID(model); err == nil {
Expand Down Expand Up @@ -387,6 +400,7 @@ func (c *Client) Chat(model, prompt string) error {
func (c *Client) Remove(models []string, force bool) (string, error) {
modelRemoved := ""
for _, model := range models {
model = normalizeModelName(model)
// Check if not a model ID passed as parameter.
if !strings.Contains(model, "/") {
if expanded, err := c.fullModelID(model); err == nil {
Expand Down Expand Up @@ -460,6 +474,7 @@ func (c *Client) handleQueryError(err error, path string) error {
}

func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) {
source = normalizeModelName(source)
// Check if the source is a model ID, and expand it if necessary
if !strings.Contains(strings.Trim(source, "/"), "/") {
// Do an extra API call to check if the model parameter might be a model ID
Expand Down
204 changes: 204 additions & 0 deletions desktop/desktop_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package desktop

import (
"bytes"
"encoding/json"
"io"
"net/http"
"testing"

"github.com/docker/pinata/common/pkg/inference/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type mockHTTPClient struct {
doFunc func(req *http.Request) (*http.Response, error)
}

func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
return m.doFunc(req)
}

func TestPullHuggingFaceModel(t *testing.T) {
// Test case for pulling a Hugging Face model with mixed case
modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"

client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is converted to lowercase in the request
var reqBody models.ModelCreateRequest
err := json.NewDecoder(req.Body).Decode(&reqBody)
require.NoError(t, err)
assert.Equal(t, expectedLowercase, reqBody.From)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
}, nil
},
},
}

_, _, err := client.Pull(modelName, func(s string) {})
assert.NoError(t, err)
}

func TestChatHuggingFaceModel(t *testing.T) {
// Test case for chatting with a Hugging Face model with mixed case
modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
prompt := "Hello"

client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is converted to lowercase in the request
var reqBody OpenAIChatRequest
err := json.NewDecoder(req.Body).Decode(&reqBody)
require.NoError(t, err)
assert.Equal(t, expectedLowercase, reqBody.Model)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")),
}, nil
},
},
}

err := client.Chat(modelName, prompt)
assert.NoError(t, err)
}

func TestInspectHuggingFaceModel(t *testing.T) {
// Test case for inspecting a Hugging Face model with mixed case
modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"

client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is converted to lowercase in the request URL
assert.Contains(t, req.URL.Path, expectedLowercase)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`{
"id": "sha256:123456789012",
"tags": ["` + expectedLowercase + `"],
"created": 1234567890,
"config": {
"format": "gguf",
"quantization": "Q4_K_M",
"parameters": "1B",
"architecture": "llama",
"size": "1.2GB"
}
}`)),
}, nil
},
},
}

model, err := client.Inspect(modelName)
assert.NoError(t, err)
assert.Equal(t, expectedLowercase, model.Tags[0])
}

func TestNonHuggingFaceModel(t *testing.T) {
// Test case for a non-Hugging Face model (should not be converted to lowercase)
modelName := "docker.io/library/llama2"
client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is not converted to lowercase
var reqBody models.ModelCreateRequest
err := json.NewDecoder(req.Body).Decode(&reqBody)
require.NoError(t, err)
assert.Equal(t, modelName, reqBody.From)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
}, nil
},
},
}

_, _, err := client.Pull(modelName, func(s string) {})
assert.NoError(t, err)
}

func TestPushHuggingFaceModel(t *testing.T) {
// Test case for pushing a Hugging Face model with mixed case
modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"

client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is converted to lowercase in the request URL
assert.Contains(t, req.URL.Path, expectedLowercase)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)),
}, nil
},
},
}

_, _, err := client.Push(modelName, func(s string) {})
assert.NoError(t, err)
}

func TestRemoveHuggingFaceModel(t *testing.T) {
// Test case for removing a Hugging Face model with mixed case
modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"

client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is converted to lowercase in the request URL
assert.Contains(t, req.URL.Path, expectedLowercase)

return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")),
}, nil
},
},
}

_, err := client.Remove([]string{modelName}, false)
assert.NoError(t, err)
}

func TestTagHuggingFaceModel(t *testing.T) {
// Test case for tagging a Hugging Face model with mixed case
sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
targetRepo := "myrepo"
targetTag := "latest"

client := &Client{
dockerClient: &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the model name is converted to lowercase in the request URL
assert.Contains(t, req.URL.Path, expectedLowercase)

return &http.Response{
StatusCode: http.StatusCreated,
Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")),
}, nil
},
},
}

_, err := client.Tag(sourceModel, targetRepo, targetTag)
assert.NoError(t, err)
}