 from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
 
+def should_add_generation_prompt(data: ChatCompletionRequest) -> bool:
+    """
+    Determines if a generation prompt should be added based on the request.
+    - Explicitly follows `data.add_generation_prompt` if set.
+    - Defaults to `False` if the last message is from the assistant to avoid double prompts.
+    - Defaults to `True` otherwise.
+    """
+    if data.add_generation_prompt is not None:
+        return data.add_generation_prompt
+    if data.messages and data.messages[-1].role == "assistant":
+        return False
+    return True
+
+
+def preprocess_stream_chunk(data: dict, inject_thinking: bool, is_first_chunk: bool):
+    """Prepends '<think>' to the first chunk of a stream if needed."""
+    if inject_thinking and is_first_chunk:
+        updated = data.copy()
+        updated["text"] = "<think>" + updated.get("text", "")
+        return updated
+    return data
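+
+# Illustrative behaviour of the helpers above (hypothetical requests, with
+# add_generation_prompt left unset):
+#   should_add_generation_prompt(request_ending_in_user_turn)      -> True
+#   should_add_generation_prompt(request_ending_in_assistant_turn) -> False
+#   preprocess_stream_chunk({"text": "foo"}, True, True)   -> {"text": "<think>foo"}
+#   preprocess_stream_chunk({"text": "bar"}, True, False)  -> {"text": "bar"}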
+
+
 def _create_response(
     request_id: str, generations: List[dict], model_name: Optional[str]
 ):
@@ -54,10 +77,10 @@ def _create_response(
             logprobs = unwrap(generation.get("logprobs"), [])
 
             collected_token_probs = []
-            for index, token in enumerate(token_probs.keys()):
+            for i, token in enumerate(token_probs.keys()):
                 top_logprobs = [
-                    ChatCompletionLogprob(token=token, logprob=logprob)
-                    for token, logprob in logprobs[index].items()
+                    ChatCompletionLogprob(token=t, logprob=lp)
+                    for t, lp in logprobs[i].items()
                 ]
 
                 collected_token_probs.append(
@@ -258,7 +281,7 @@ async def apply_chat_template(data: ChatCompletionRequest):
     try:
         data.template_vars.update(
             {
-                "add_generation_prompt": data.add_generation_prompt,
+                "add_generation_prompt": should_add_generation_prompt(data),
                 "tools": tools,
                 "functions": data.functions,
             }
@@ -324,6 +347,8 @@ async def stream_generate_chat_completion(
     try:
         logger.info(f"Received chat completion streaming request {request.state.id}")
 
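+        # A prompt ending in "<think>" (searched in the last 11 characters,
+        # presumably to tolerate trailing whitespace) indicates the chat template
+        # already opened a reasoning block; the tag is re-injected into the
+        # streamed output below so clients still receive it.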
+        inject_thinking = "<think>" in prompt[-11:] and should_add_generation_prompt(data)
+
         for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
             request_id = _parse_gen_request_id(data.n, request.state.id, idx)
@@ -342,8 +367,8 @@ async def stream_generate_chat_completion(
 
             gen_tasks.append(gen_task)
 
-        # Text accumulation for tool calls
-        current_generation_text = ""
+        # Track which generation indices have already streamed their first chunk
+        seen_first_chunk_indices = set()
 
         # Consumer loop
         while True:
@@ -353,30 +378,36 @@ async def stream_generate_chat_completion(
             generation = await gen_queue.get()
 
             # Handle options if a tool model is present
-            if tool_start:
-                if "stop_str" in generation:
-                    generations = await generate_tool_calls(
-                        prompt,
-                        embeddings,
-                        data,
-                        [generation],
-                        request,
-                    )
-
-                    # Only one generation present in this case
-                    generation = generations[0]
-                elif "text" in generation:
-                    current_generation_text += generation["text"]
+            if tool_start and "stop_str" in generation:
+                generations = await generate_tool_calls(
+                    prompt,
+                    embeddings,
+                    data,
+                    [generation],
+                    request,
+                )
+                # Only one generation present in this case
+                generation = generations[0]
 
             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
                 raise generation
 
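+            # Each parallel generation (data.n > 1) reports an "index"; prepend the
+            # "<think>" tag only to the first chunk seen for that index.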
+            index = generation.get("index", 0)
+            is_first_for_this_index = index not in seen_first_chunk_indices
+
+            processed_generation = preprocess_stream_chunk(
+                generation, inject_thinking, is_first_for_this_index
+            )
+
             response = _create_stream_chunk(
-                request.state.id, generation, model_path.name
+                request.state.id, processed_generation, model_path.name
             )
             yield response.model_dump_json()
 
+            if is_first_for_this_index:
+                seen_first_chunk_indices.add(index)
+
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 # Send a usage chunk
@@ -442,6 +473,12 @@ async def generate_chat_completion(
                 prompt, embeddings, data, generations, request
             )
 
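+        # Mirrors the streaming path: the template already emitted the opening
+        # "<think>" tag in the prompt, so it is restored in the response text too.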
+        # Prepend "<think>" after generation and tool calls are complete.
+        if "<think>" in prompt[-11:] and should_add_generation_prompt(data):
+            for gen in generations:
+                if "text" in gen:
+                    gen["text"] = "<think>" + gen["text"]
+
         response = _create_response(request.state.id, generations, model_path.name)
 
         logger.info(f"Finished chat completion request {request.state.id}")