@@ -365,7 +365,7 @@ func (c *Client) fullModelID(id string) (string, error) {
365365}
366366
367367// Chat performs a chat request and streams the response content with selective markdown rendering.
368- func (c * Client ) Chat (backend , model , prompt , apiKey string , outputFunc func (string )) error {
368+ func (c * Client ) Chat (backend , model , prompt , apiKey string , outputFunc func (string ), shouldUseMarkdown bool ) error {
369369 model = normalizeHuggingFaceModelName (model )
370370 if ! strings .Contains (strings .Trim (model , "/" ), "/" ) {
371371 // Do an extra API call to check if the model parameter isn't a model ID.
@@ -422,7 +422,14 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
422422 )
423423
424424 printerState := chatPrinterNone
425- reasoningFmt := color .New (color .FgWhite ).Add (color .Italic )
425+ reasoningFmt := color .New ().Add (color .Italic )
426+
427+ var finalUsage * struct {
428+ CompletionTokens int `json:"completion_tokens"`
429+ PromptTokens int `json:"prompt_tokens"`
430+ TotalTokens int `json:"total_tokens"`
431+ }
432+
426433 scanner := bufio .NewScanner (resp .Body )
427434 for scanner .Scan () {
428435 line := scanner .Text ()
@@ -445,6 +452,10 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
445452 return fmt .Errorf ("error parsing stream response: %w" , err )
446453 }
447454
455+ if streamResp .Usage != nil {
456+ finalUsage = streamResp .Usage
457+ }
458+
448459 if len (streamResp .Choices ) > 0 {
449460 if streamResp .Choices [0 ].Delta .ReasoningContent != "" {
450461 chunk := streamResp .Choices [0 ].Delta .ReasoningContent
@@ -454,14 +465,14 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
454465 if printerState != chatPrinterReasoning {
455466 const thinkingHeader = "Thinking:\n "
456467 if reasoningFmt != nil {
457- outputFunc ( reasoningFmt .Sprint (thinkingHeader ) )
468+ reasoningFmt .Print (thinkingHeader )
458469 } else {
459470 outputFunc (thinkingHeader )
460471 }
461472 }
462473 printerState = chatPrinterReasoning
463474 if reasoningFmt != nil {
464- outputFunc ( reasoningFmt .Sprint (chunk ) )
475+ reasoningFmt .Print (chunk )
465476 } else {
466477 outputFunc (chunk )
467478 }
@@ -481,6 +492,19 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
481492 return fmt .Errorf ("error reading response stream: %w" , err )
482493 }
483494
495+ if finalUsage != nil {
496+ usageInfo := fmt .Sprintf ("\n \n Token usage: %d prompt + %d completion = %d total" ,
497+ finalUsage .PromptTokens ,
498+ finalUsage .CompletionTokens ,
499+ finalUsage .TotalTokens )
500+
501+ usageFmt := color .New (color .FgHiBlack )
502+ if ! shouldUseMarkdown {
503+ usageFmt .DisableColor ()
504+ }
505+ outputFunc (usageFmt .Sprint (usageInfo ))
506+ }
507+
484508 return nil
485509}
486510
0 commit comments