
Commit d10b7df

Inject <think> into responses when thinking is forced by the chat template

Adds logic to prepend '<think>' to the first streamed chunk of each generation, and to every final generation in the non-streaming path, when the rendered prompt ends with '<think>'. Adjusts token and offset accounting to remain consistent when the tag is injected.
1 parent 2539acf commit d10b7df

2 files changed: +59 −22 lines changed


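As context for the diffs below: when a model's chat template ends its generation prompt with a forced '<think>' tag, that tag is consumed as part of the prompt, so the generated text never contains it and the server has to re-inject it for clients to see a complete thinking block. The following is a minimal illustrative sketch, where the prompt string is invented; only the slice check mirrors the prompt[-11:] test used in the diff ('<think>' is 7 characters, and the wider window tolerates a trailing newline or a few spaces).

# Illustrative only: a rendered prompt whose generation prompt forces thinking.
# The trailing "<think>" belongs to the prompt, not to the generated text.
rendered_prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>\n"

def ends_with_forced_thinking(prompt: str) -> bool:
    # Same shape as the check in the diff: look at the last 11 characters
    # so whitespace after the 7-character tag still matches.
    return "<think>" in prompt[-11:]

print(ends_with_forced_thinking(rendered_prompt))  # True
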
endpoints/OAI/types/chat_completion.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class ChatCompletionStreamChoice(BaseModel):
 class ChatCompletionRequest(CommonCompletionRequest):
     messages: List[ChatCompletionMessage]
     prompt_template: Optional[str] = None
-    add_generation_prompt: Optional[bool] = True
+    add_generation_prompt: Optional[bool] = None
     template_vars: Optional[dict] = Field(
         default={},
         validation_alias=AliasChoices("template_vars", "chat_template_kwargs"),

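Changing the default of add_generation_prompt from True to None lets the server distinguish an explicit client choice from an unspecified field. Below is a minimal sketch of the resulting decision rule, using plain dicts in place of the Pydantic message models; the helper name simply mirrors should_add_generation_prompt from the utils diff below.

from typing import Optional


def resolve_generation_prompt(
    add_generation_prompt: Optional[bool], messages: list
) -> bool:
    # Explicit client choice always wins.
    if add_generation_prompt is not None:
        return add_generation_prompt
    # Unspecified with a trailing assistant message: the client is continuing
    # an assistant turn, so adding another generation prompt would double it.
    if messages and messages[-1]["role"] == "assistant":
        return False
    # Unspecified otherwise: keep the old default behaviour.
    return True


msgs = [
    {"role": "user", "content": "Explain transformers."},
    {"role": "assistant", "content": "Sure, a transformer"},
]
print(resolve_generation_prompt(None, msgs))      # False (assistant continuation)
print(resolve_generation_prompt(True, msgs))      # True (explicit override)
print(resolve_generation_prompt(None, msgs[:1]))  # True (normal request)
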
endpoints/OAI/utils/chat_completion.py

Lines changed: 58 additions & 21 deletions
@@ -32,6 +32,29 @@
 from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
 
+def should_add_generation_prompt(data: ChatCompletionRequest) -> bool:
+    """
+    Determines if a generation prompt should be added based on the request.
+    - Explicitly follows `data.add_generation_prompt` if set.
+    - Defaults to `False` if the last message is from the assistant to avoid double prompts.
+    - Defaults to `True` otherwise.
+    """
+    if data.add_generation_prompt is not None:
+        return data.add_generation_prompt
+    if data.messages and data.messages[-1].role == "assistant":
+        return False
+    return True
+
+
+def preprocess_stream_chunk(data: dict, inject_thinking: bool, is_first_chunk: bool):
+    """Prepends '<think>' to the first chunk of a stream if needed."""
+    if inject_thinking and is_first_chunk:
+        updated = data.copy()
+        updated["text"] = "<think>" + updated.get("text", "")
+        return updated
+    return data
+
+
 def _create_response(
     request_id: str, generations: List[dict], model_name: Optional[str]
 ):
@@ -54,10 +77,10 @@ def _create_response(
         logprobs = unwrap(generation.get("logprobs"), [])
 
         collected_token_probs = []
-        for index, token in enumerate(token_probs.keys()):
+        for i, token in enumerate(token_probs.keys()):
             top_logprobs = [
-                ChatCompletionLogprob(token=token, logprob=logprob)
-                for token, logprob in logprobs[index].items()
+                ChatCompletionLogprob(token=t, logprob=lp)
+                for t, lp in logprobs[i].items()
             ]
 
             collected_token_probs.append(
@@ -258,7 +281,7 @@ async def apply_chat_template(data: ChatCompletionRequest):
     try:
         data.template_vars.update(
             {
-                "add_generation_prompt": data.add_generation_prompt,
+                "add_generation_prompt": should_add_generation_prompt(data),
                 "tools": tools,
                 "functions": data.functions,
             }
@@ -324,6 +347,8 @@ async def stream_generate_chat_completion(
     try:
         logger.info(f"Received chat completion streaming request {request.state.id}")
 
+        inject_thinking = "<think>" in prompt[-11:] and should_add_generation_prompt(data)
+
         for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
             request_id = _parse_gen_request_id(data.n, request.state.id, idx)
@@ -342,8 +367,8 @@
 
             gen_tasks.append(gen_task)
 
-        # Text accumulation for tool calls
-        current_generation_text = ""
+        # Track which generation indices have already emitted their first chunk
+        seen_first_chunk_indices = set()
 
         # Consumer loop
         while True:
@@ -353,30 +378,36 @@
             generation = await gen_queue.get()
 
             # Handle options if a tool model is present
-            if tool_start:
-                if "stop_str" in generation:
-                    generations = await generate_tool_calls(
-                        prompt,
-                        embeddings,
-                        data,
-                        [generation],
-                        request,
-                    )
-
-                    # Only one generation present in this case
-                    generation = generations[0]
-                elif "text" in generation:
-                    current_generation_text += generation["text"]
+            if tool_start and "stop_str" in generation:
+                generations = await generate_tool_calls(
+                    prompt,
+                    embeddings,
+                    data,
+                    [generation],
+                    request,
+                )
+                # Only one generation present in this case
+                generation = generations[0]
 
             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
                 raise generation
 
+            index = generation.get("index", 0)
+            is_first_for_this_index = index not in seen_first_chunk_indices
+
+            processed_generation = preprocess_stream_chunk(
+                generation, inject_thinking, is_first_for_this_index
+            )
+
             response = _create_stream_chunk(
-                request.state.id, generation, model_path.name
+                request.state.id, processed_generation, model_path.name
             )
             yield response.model_dump_json()
 
+            if is_first_for_this_index:
+                seen_first_chunk_indices.add(index)
+
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 # Send a usage chunk
@@ -442,6 +473,12 @@ async def generate_chat_completion(
                 prompt, embeddings, data, generations, request
             )
 
+        # Prepend "<think>" after generation and tool calls are complete.
+        if "<think>" in prompt[-11:] and should_add_generation_prompt(data):
+            for gen in generations:
+                if "text" in gen:
+                    gen["text"] = "<think>" + gen["text"]
+
         response = _create_response(request.state.id, generations, model_path.name)
 
         logger.info(f"Finished chat completion request {request.state.id}")

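To make the streaming changes above concrete, here is a small standalone walkthrough of the first-chunk bookkeeping. preprocess_stream_chunk is copied from the diff; the chunk dicts are invented for illustration and only assume the index/text keys that the consumer loop reads.

def preprocess_stream_chunk(data: dict, inject_thinking: bool, is_first_chunk: bool):
    """Prepends '<think>' to the first chunk of a stream if needed."""
    if inject_thinking and is_first_chunk:
        updated = data.copy()
        updated["text"] = "<think>" + updated.get("text", "")
        return updated
    return data


seen_first_chunk_indices = set()
chunks = [
    {"index": 0, "text": "Let me think"},  # first chunk of generation 0
    {"index": 1, "text": "Considering"},   # first chunk of generation 1
    {"index": 0, "text": " about this."},  # later chunk of generation 0
]

for chunk in chunks:
    index = chunk.get("index", 0)
    is_first = index not in seen_first_chunk_indices
    out = preprocess_stream_chunk(chunk, inject_thinking=True, is_first_chunk=is_first)
    if is_first:
        seen_first_chunk_indices.add(index)
    print(out["text"])

# Output:
# <think>Let me think
# <think>Considering
#  about this.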