diff --git a/desktop/context.go b/desktop/context.go index e302185d..aa70d837 100644 --- a/desktop/context.go +++ b/desktop/context.go @@ -3,7 +3,6 @@ package desktop import ( "context" "fmt" - "github.com/docker/model-cli/pkg/types" "net/http" "net/url" "os" @@ -16,6 +15,7 @@ import ( "github.com/docker/cli/cli/context/docker" clientpkg "github.com/docker/docker/client" "github.com/docker/model-cli/pkg/standalone" + "github.com/docker/model-cli/pkg/types" "github.com/docker/model-runner/pkg/inference" ) @@ -155,6 +155,10 @@ func DetectContext(ctx context.Context, cli *command.DockerCli) (*ModelRunnerCon client = http.DefaultClient } + if userAgent := os.Getenv("USER_AGENT"); userAgent != "" { + setUserAgent(client, userAgent) + } + // Success. return &ModelRunnerContext{ kind: kind, @@ -183,3 +187,39 @@ func (c *ModelRunnerContext) URL(path string) string { func (c *ModelRunnerContext) Client() DockerHttpClient { return c.client } + +func setUserAgent(client DockerHttpClient, userAgent string) { + if httpClient, ok := client.(*http.Client); ok { + transport := httpClient.Transport + if transport == nil { + transport = http.DefaultTransport + } + + httpClient.Transport = &userAgentTransport{ + userAgent: userAgent, + transport: transport, + } + } +} + +type userAgentTransport struct { + userAgent string + transport http.RoundTripper +} + +func (u *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + reqClone := req.Clone(req.Context()) + + existingUA := reqClone.UserAgent() + + var newUA string + if existingUA != "" { + newUA = existingUA + " " + u.userAgent + } else { + newUA = u.userAgent + } + + reqClone.Header.Set("User-Agent", newUA) + + return u.transport.RoundTrip(reqClone) +}