diff --git a/config.example.yaml b/config.example.yaml index 348aabd846..108d56c0cc 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -178,6 +178,7 @@ nonstream-keepalive-interval: 0 # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places. # prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials # base-url: "https://openrouter.ai/api/v1" # The base URL of the provider. +# # force-upstream-stream: true # optional: always call upstream with stream=true and aggregate SSE for downstream non-stream requests # headers: # X-Custom-Header: "custom-value" # api-key-entries: diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 503179c11c..0b471e4b85 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -397,6 +397,7 @@ func (h *Handler) PatchOpenAICompat(c *gin.Context) { Name *string `json:"name"` Prefix *string `json:"prefix"` BaseURL *string `json:"base-url"` + ForceUpstreamStream *bool `json:"force-upstream-stream"` APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"` Models *[]config.OpenAICompatibilityModel `json:"models"` Headers *map[string]string `json:"headers"` @@ -445,6 +446,9 @@ func (h *Handler) PatchOpenAICompat(c *gin.Context) { } entry.BaseURL = trimmed } + if body.Value.ForceUpstreamStream != nil { + entry.ForceUpstreamStream = *body.Value.ForceUpstreamStream + } if body.Value.APIKeyEntries != nil { entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...) } diff --git a/internal/config/config.go b/internal/config/config.go index 5a6595f778..67f3ad9e52 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -466,6 +466,9 @@ type OpenAICompatibility struct { // BaseURL is the base URL for the external OpenAI-compatible API endpoint. BaseURL string `yaml:"base-url" json:"base-url"` + // ForceUpstreamStream forces upstream stream=true for non-stream downstream requests. + ForceUpstreamStream bool `yaml:"force-upstream-stream,omitempty" json:"force-upstream-stream,omitempty"` + // APIKeyEntries defines API keys with optional per-key proxy configuration. APIKeyEntries []OpenAICompatibilityAPIKey `yaml:"api-key-entries,omitempty" json:"api-key-entries,omitempty"` diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index d28b36251a..6f36284891 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -93,14 +93,20 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A originalPayloadSource = opts.OriginalRequest } originalPayload := originalPayloadSource - originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream) - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream) + forceUpstreamStream := e.shouldForceUpstreamStream(auth) && opts.Alt != "responses/compact" + upstreamStream := forceUpstreamStream + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, upstreamStream) + translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, upstreamStream) requestedModel := payloadRequestedModel(opts, req.Model) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) if opts.Alt == "responses/compact" { if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil { translated = updated } + } else if forceUpstreamStream { + if updated, errSet := sjson.SetBytes(translated, "stream", true); errSet == nil { + translated = updated + } } translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) @@ -118,6 +124,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A httpReq.Header.Set("Authorization", "Bearer "+apiKey) } httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat") + if forceUpstreamStream { + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + } var attrs map[string]string if auth != nil { attrs = auth.Attributes @@ -166,12 +176,21 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return resp, err } appendAPIResponseChunk(ctx, e.cfg, body) - reporter.publish(ctx, parseOpenAIUsage(body)) + bodyForTranslation := body + usageDetail := parseOpenAIUsage(body) + if forceUpstreamStream { + bodyForTranslation, usageDetail, err = aggregateOpenAIChatCompletionSSE(body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + } + reporter.publish(ctx, usageDetail) // Ensure we at least record the request even if upstream doesn't return usage reporter.ensurePublished(ctx) // Translate response back to source format when needed var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyForTranslation, ¶m) resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } @@ -374,6 +393,11 @@ func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *con return nil } +func (e *OpenAICompatExecutor) shouldForceUpstreamStream(auth *cliproxyauth.Auth) bool { + compat := e.resolveCompatConfig(auth) + return compat != nil && compat.ForceUpstreamStream +} + func (e *OpenAICompatExecutor) overrideModel(payload []byte, model string) []byte { if len(payload) == 0 || model == "" { return payload diff --git a/internal/runtime/executor/openai_compat_executor_stream_aggregate_test.go b/internal/runtime/executor/openai_compat_executor_stream_aggregate_test.go new file mode 100644 index 0000000000..4248ab40bf --- /dev/null +++ b/internal/runtime/executor/openai_compat_executor_stream_aggregate_test.go @@ -0,0 +1,185 @@ +package executor + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestOpenAICompatExecutor_ForceUpstreamStreamAggregatesReasoningAndContent(t *testing.T) { + var gotAccept string + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + body, _ := io.ReadAll(r.Body) + gotBody = body + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1710000000,\"model\":\"glm-5\",\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"r1\"}}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1710000000,\"model\":\"glm-5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi \"}}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1710000000,\"model\":\"glm-5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"there\"}}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n")) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + })) + defer server.Close() + + cfg := &config.Config{OpenAICompatibility: []config.OpenAICompatibility{{ + Name: "iamai", + BaseURL: server.URL + "/v1", + ForceUpstreamStream: true, + Models: []config.OpenAICompatibilityModel{{ + Name: "glm-5", + Alias: "glm-5", + }}, + }}} + + executor := NewOpenAICompatExecutor("openai-compatibility", cfg) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + "compat_name": "iamai", + }} + payload := []byte(`{"model":"glm-5","messages":[{"role":"user","content":"hi"}]}`) + + resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "glm-5", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: false, + }) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + if gotAccept != "text/event-stream" { + t.Fatalf("expected Accept text/event-stream, got %q", gotAccept) + } + if !gjson.GetBytes(gotBody, "stream").Bool() { + t.Fatalf("expected upstream payload to include stream=true") + } + if !gjson.ValidBytes(resp.Payload) { + t.Fatalf("expected valid JSON response, got: %s", string(resp.Payload)) + } + if gjson.GetBytes(resp.Payload, "choices.0.message.content").String() != "hi there" { + t.Fatalf("content mismatch: %s", gjson.GetBytes(resp.Payload, "choices.0.message.content").String()) + } + if gjson.GetBytes(resp.Payload, "choices.0.message.reasoning_content").String() != "r1" { + t.Fatalf("reasoning mismatch: %s", gjson.GetBytes(resp.Payload, "choices.0.message.reasoning_content").String()) + } + if gjson.GetBytes(resp.Payload, "choices.0.finish_reason").String() != "stop" { + t.Fatalf("expected finish_reason stop, got %s", gjson.GetBytes(resp.Payload, "choices.0.finish_reason").String()) + } + if gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int() != 1 { + t.Fatalf("expected usage prompt_tokens") + } +} + +func TestOpenAICompatExecutor_ForceUpstreamStream_ToolCallsOnlyFinishReasonIsToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-2\",\"object\":\"chat.completion.chunk\",\"created\":1710000001,\"model\":\"glm-5\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"read\",\"arguments\":\"{\\\"path\\\": \\\"\"}}]}}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-2\",\"object\":\"chat.completion.chunk\",\"created\":1710000001,\"model\":\"glm-5\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"/tmp/test\\\"}\"}}]}}]}\n\n")) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + })) + defer server.Close() + + cfg := &config.Config{OpenAICompatibility: []config.OpenAICompatibility{{ + Name: "iamai", + BaseURL: server.URL + "/v1", + ForceUpstreamStream: true, + Models: []config.OpenAICompatibilityModel{{ + Name: "glm-5", + Alias: "glm-5", + }}, + }}} + + executor := NewOpenAICompatExecutor("openai-compatibility", cfg) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + "compat_name": "iamai", + }} + payload := []byte(`{"model":"glm-5","messages":[{"role":"user","content":"hi"}]}`) + + resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "glm-5", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: false, + }) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + if !gjson.ValidBytes(resp.Payload) { + t.Fatalf("expected valid JSON response") + } + calls := gjson.GetBytes(resp.Payload, "choices.0.message.tool_calls") + if !calls.Exists() || len(calls.Array()) != 1 { + t.Fatalf("expected one tool_call, got: %s", calls.String()) + } + if gjson.GetBytes(resp.Payload, "choices.0.message.tool_calls.0.function.name").String() != "read" { + t.Fatalf("tool_call name mismatch") + } + if gjson.GetBytes(resp.Payload, "choices.0.message.tool_calls.0.function.arguments").String() != `{"path": "/tmp/test"}` { + t.Fatalf("tool_call arguments mismatch: %s", gjson.GetBytes(resp.Payload, "choices.0.message.tool_calls.0.function.arguments").String()) + } + if gjson.GetBytes(resp.Payload, "choices.0.finish_reason").String() != "tool_calls" { + t.Fatalf("expected finish_reason tool_calls, got: %s", gjson.GetBytes(resp.Payload, "choices.0.finish_reason").String()) + } +} + +func TestOpenAICompatExecutor_DefaultBehaviorUnchanged(t *testing.T) { + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + gotBody = body + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"chatcmpl-3","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}]}`)) + })) + defer server.Close() + + cfg := &config.Config{OpenAICompatibility: []config.OpenAICompatibility{{ + Name: "iamai", + BaseURL: server.URL + "/v1", + Models: []config.OpenAICompatibilityModel{{ + Name: "glm-5", + Alias: "glm-5", + }}, + }}} + + executor := NewOpenAICompatExecutor("openai-compatibility", cfg) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + "compat_name": "iamai", + }} + payload := []byte(`{"model":"glm-5","messages":[{"role":"user","content":"hi"}]}`) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + resp, err := executor.Execute(ctx, auth, cliproxyexecutor.Request{ + Model: "glm-5", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: false, + }) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + if gjson.GetBytes(gotBody, "stream").Exists() { + t.Fatalf("did not expect stream=true in payload") + } + if !gjson.ValidBytes(resp.Payload) { + t.Fatalf("expected valid JSON response") + } +} diff --git a/internal/runtime/executor/openai_compat_sse_aggregate.go b/internal/runtime/executor/openai_compat_sse_aggregate.go new file mode 100644 index 0000000000..ce6915e6cf --- /dev/null +++ b/internal/runtime/executor/openai_compat_sse_aggregate.go @@ -0,0 +1,176 @@ +package executor + +import ( + "bytes" + "fmt" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type toolCallAggregate struct { + ID string + Type string + FuncName string + Arguments string +} + +// aggregateOpenAIChatCompletionSSE converts OpenAI-style chat.completion.chunk SSE into +// a final chat.completion JSON payload and returns parsed usage details. +func aggregateOpenAIChatCompletionSSE(body []byte) ([]byte, usage.Detail, error) { + lines := bytes.Split(body, []byte("\n")) + var ( + id string + model string + created int64 + content strings.Builder + reasoning strings.Builder + finishReason string + nativeFinish string + usageRaw string + hasAny bool + toolCallsByIndex = map[int]*toolCallAggregate{} + orderedToolIdx []int + ) + + for _, line := range lines { + payload := jsonPayload(line) + if len(payload) == 0 || !gjson.ValidBytes(payload) { + continue + } + hasAny = true + + if id == "" { + if v := gjson.GetBytes(payload, "id"); v.Exists() { + id = v.String() + } + } + if model == "" { + if v := gjson.GetBytes(payload, "model"); v.Exists() { + model = v.String() + } + } + if created == 0 { + if v := gjson.GetBytes(payload, "created"); v.Exists() { + created = v.Int() + } + } + + if v := gjson.GetBytes(payload, "choices.0.delta.content"); v.Exists() { + content.WriteString(v.String()) + } + if v := gjson.GetBytes(payload, "choices.0.delta.reasoning_content"); v.Exists() { + reasoning.WriteString(v.String()) + } + + if v := gjson.GetBytes(payload, "choices.0.finish_reason"); v.Exists() { + trimmed := strings.TrimSpace(v.String()) + if trimmed != "" { + finishReason = trimmed + } + } + if v := gjson.GetBytes(payload, "choices.0.native_finish_reason"); v.Exists() { + trimmed := strings.TrimSpace(v.String()) + if trimmed != "" { + nativeFinish = trimmed + } + } + + if v := gjson.GetBytes(payload, "usage"); v.Exists() { + usageRaw = v.Raw + } + + if v := gjson.GetBytes(payload, "choices.0.delta.tool_calls"); v.Exists() { + v.ForEach(func(_, item gjson.Result) bool { + idx := int(item.Get("index").Int()) + agg, ok := toolCallsByIndex[idx] + if !ok { + agg = &toolCallAggregate{} + toolCallsByIndex[idx] = agg + orderedToolIdx = append(orderedToolIdx, idx) + } + if idv := item.Get("id"); idv.Exists() { + agg.ID = idv.String() + } + if tv := item.Get("type"); tv.Exists() { + agg.Type = tv.String() + } + if nv := item.Get("function.name"); nv.Exists() { + agg.FuncName = nv.String() + } + if av := item.Get("function.arguments"); av.Exists() { + agg.Arguments += av.String() + } + return true + }) + } + } + + if !hasAny { + return nil, usage.Detail{}, fmt.Errorf("openai compat: no SSE payloads to aggregate") + } + + if finishReason == "" { + if len(toolCallsByIndex) > 0 { + finishReason = "tool_calls" + } else { + finishReason = "stop" + } + } + + result := []byte(`{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`) + if id != "" { + result, _ = sjson.SetBytes(result, "id", id) + } + if created != 0 { + result, _ = sjson.SetBytes(result, "created", created) + } + if model != "" { + result, _ = sjson.SetBytes(result, "model", model) + } + if content.Len() > 0 { + result, _ = sjson.SetBytes(result, "choices.0.message.content", content.String()) + } + if reasoning.Len() > 0 { + result, _ = sjson.SetBytes(result, "choices.0.message.reasoning_content", reasoning.String()) + } + if finishReason != "" { + result, _ = sjson.SetBytes(result, "choices.0.finish_reason", finishReason) + } + if nativeFinish != "" { + result, _ = sjson.SetBytes(result, "choices.0.native_finish_reason", nativeFinish) + } + + if len(toolCallsByIndex) > 0 { + sort.Ints(orderedToolIdx) + toolCalls := make([]map[string]any, 0, len(orderedToolIdx)) + for _, idx := range orderedToolIdx { + agg := toolCallsByIndex[idx] + if agg == nil { + continue + } + entry := map[string]any{ + "id": agg.ID, + "type": agg.Type, + "function": map[string]any{ + "name": agg.FuncName, + "arguments": agg.Arguments, + }, + } + toolCalls = append(toolCalls, entry) + } + if len(toolCalls) > 0 { + result, _ = sjson.SetBytes(result, "choices.0.message.tool_calls", toolCalls) + } + } + + if strings.TrimSpace(usageRaw) != "" && gjson.Valid(usageRaw) { + result, _ = sjson.SetRawBytes(result, "usage", []byte(usageRaw)) + } + + usageDetail := parseOpenAIUsage(result) + return result, usageDetail, nil +} diff --git a/internal/watcher/diff/openai_compat.go b/internal/watcher/diff/openai_compat.go index 6b01aed296..677ed18b02 100644 --- a/internal/watcher/diff/openai_compat.go +++ b/internal/watcher/diff/openai_compat.go @@ -75,6 +75,9 @@ func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibi if !equalStringMap(oldEntry.Headers, newEntry.Headers) { details = append(details, "headers updated") } + if oldEntry.ForceUpstreamStream != newEntry.ForceUpstreamStream { + details = append(details, fmt.Sprintf("force-upstream-stream %t -> %t", oldEntry.ForceUpstreamStream, newEntry.ForceUpstreamStream)) + } if len(details) == 0 { return "" } @@ -175,6 +178,10 @@ func openAICompatSignature(entry config.OpenAICompatibility) string { parts = append(parts, fmt.Sprintf("api_keys=%d", count)) } + if entry.ForceUpstreamStream || len(parts) > 0 { + parts = append(parts, fmt.Sprintf("force_stream=%t", entry.ForceUpstreamStream)) + } + if len(parts) == 0 { return "" } diff --git a/internal/watcher/diff/openai_compat_test.go b/internal/watcher/diff/openai_compat_test.go index db33db1487..47a5f41b2a 100644 --- a/internal/watcher/diff/openai_compat_test.go +++ b/internal/watcher/diff/openai_compat_test.go @@ -30,7 +30,8 @@ func TestDiffOpenAICompatibility(t *testing.T) { {Name: "m1"}, {Name: "m2"}, }, - Headers: map[string]string{"X-Test": "1"}, + Headers: map[string]string{"X-Test": "1"}, + ForceUpstreamStream: true, }, { Name: "provider-b", @@ -40,7 +41,7 @@ func TestDiffOpenAICompatibility(t *testing.T) { changes := DiffOpenAICompatibility(oldList, newList) expectContains(t, changes, "provider added: provider-b (api-keys=1, models=0)") - expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)") + expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated, force-upstream-stream false -> true)") } func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) { @@ -161,6 +162,12 @@ func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) { if sigC := openAICompatSignature(c); sigC == sigB { t.Fatalf("expected signature to change when models change, got %s", sigC) } + + d := b + d.ForceUpstreamStream = true + if sigD := openAICompatSignature(d); sigD == sigB { + t.Fatalf("expected signature to change when force-upstream-stream changes, got %s", sigD) + } } func TestCountOpenAIModelsSkipsBlanks(t *testing.T) {