diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 52523149..d3b3a629 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -57,7 +57,7 @@ class ChatCompletionStreamChoice(BaseModel):
 class ChatCompletionRequest(CommonCompletionRequest):
     messages: List[ChatCompletionMessage]
     prompt_template: Optional[str] = None
-    add_generation_prompt: Optional[bool] = True
+    add_generation_prompt: Optional[bool] = None
     template_vars: Optional[dict] = Field(
         default={},
         validation_alias=AliasChoices("template_vars", "chat_template_kwargs"),
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index b559bb2b..cec4a37c 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -32,6 +32,29 @@ from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA



+def should_add_generation_prompt(data: ChatCompletionRequest) -> bool:
+    """
+    Determines whether a generation prompt should be added for this request.
+    - Follows `data.add_generation_prompt` explicitly when it is set.
+    - Defaults to `False` if the last message is from the assistant, to avoid a double prompt.
+    - Defaults to `True` otherwise.
+    """
+    if data.add_generation_prompt is not None:
+        return data.add_generation_prompt
+    if data.messages and data.messages[-1].role == "assistant":
+        return False
+    return True
+
+
+def preprocess_stream_chunk(data: dict, inject_thinking: bool, is_first_chunk: bool):
+    """Prepends '<think>' to the first chunk of a stream if needed."""
+    if inject_thinking and is_first_chunk:
+        updated = data.copy()
+        updated["text"] = "<think>" + updated.get("text", "")
+        return updated
+    return data
+
+
 def _create_response(
     request_id: str, generations: List[dict], model_name: Optional[str]
 ):
@@ -54,10 +77,10 @@ def _create_response(
             logprobs = unwrap(generation.get("logprobs"), [])

             collected_token_probs = []
-            for index, token in enumerate(token_probs.keys()):
+            for i, token in enumerate(token_probs.keys()):
                 top_logprobs = [
-                    ChatCompletionLogprob(token=token, logprob=logprob)
-                    for token, logprob in logprobs[index].items()
+                    ChatCompletionLogprob(token=t, logprob=lp)
+                    for t, lp in logprobs[i].items()
                 ]

                 collected_token_probs.append(
@@ -258,7 +281,7 @@ async def apply_chat_template(data: ChatCompletionRequest):
     try:
         data.template_vars.update(
             {
-                "add_generation_prompt": data.add_generation_prompt,
+                "add_generation_prompt": should_add_generation_prompt(data),
                 "tools": tools,
                 "functions": data.functions,
             }
@@ -324,6 +347,8 @@ async def stream_generate_chat_completion(
     try:
         logger.info(f"Received chat completion streaming request {request.state.id}")

+        inject_thinking = "<think>" in prompt[-11:] and should_add_generation_prompt(data)
+
         for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
             request_id = _parse_gen_request_id(data.n, request.state.id, idx)
@@ -342,8 +367,8 @@

             gen_tasks.append(gen_task)

-        # Text accumulation for tool calls
-        current_generation_text = ""
+        # Track which generation indices have already streamed their first chunk
+        seen_first_chunk_indices = set()

         # Consumer loop
         while True:
@@ -353,30 +378,36 @@
             generation = await gen_queue.get()

             # Handle options if a tool model is present
-            if tool_start:
-                if "stop_str" in generation:
-                    generations = await generate_tool_calls(
-                        prompt,
-                        embeddings,
-                        data,
-                        [generation],
-                        request,
-                    )
-
-                    # Only one generation present in this case
-                    generation = generations[0]
-                elif "text" in generation:
-                    current_generation_text += generation["text"]
+            if tool_start and "stop_str" in generation:
+                generations = await generate_tool_calls(
+                    prompt,
+                    embeddings,
+                    data,
+                    [generation],
+                    request,
+                )
+                # Only one generation present in this case
+                generation = generations[0]

             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
                 raise generation

+            index = generation.get("index", 0)
+            is_first_for_this_index = index not in seen_first_chunk_indices
+
+            processed_generation = preprocess_stream_chunk(
+                generation, inject_thinking, is_first_for_this_index
+            )
+
             response = _create_stream_chunk(
-                request.state.id, generation, model_path.name
+                request.state.id, processed_generation, model_path.name
             )
             yield response.model_dump_json()

+            if is_first_for_this_index:
+                seen_first_chunk_indices.add(index)
+
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 # Send a usage chunk
@@ -442,6 +473,12 @@ async def generate_chat_completion(
                 prompt, embeddings, data, generations, request
             )

+        # Prepend "<think>" after generation and tool calls are complete.
+        if "<think>" in prompt[-11:] and should_add_generation_prompt(data):
+            for gen in generations:
+                if "text" in gen:
+                    gen["text"] = "<think>" + gen["text"]
+
         response = _create_response(request.state.id, generations, model_path.name)

         logger.info(f"Finished chat completion request {request.state.id}")
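Reviewer note, not part of the patch: a minimal standalone sketch of the decision table that `should_add_generation_prompt` implements. The helper body is copied from this diff; the `Msg` and `Req` stand-ins are illustrative assumptions, not types from the codebase, so the snippet runs without the server.

from typing import List, Optional


class Msg:
    """Stand-in for ChatCompletionMessage; only `role` is read by the helper."""

    def __init__(self, role: str):
        self.role = role


class Req:
    """Stand-in for ChatCompletionRequest; only the two fields below are read."""

    def __init__(
        self, messages: List[Msg], add_generation_prompt: Optional[bool] = None
    ):
        self.messages = messages
        self.add_generation_prompt = add_generation_prompt


def should_add_generation_prompt(data) -> bool:
    # Same logic as the helper added in endpoints/OAI/utils/chat_completion.py
    if data.add_generation_prompt is not None:
        return data.add_generation_prompt
    if data.messages and data.messages[-1].role == "assistant":
        return False
    return True


# Normal chat: the last message is from the user -> append a generation prompt.
assert should_add_generation_prompt(Req([Msg("user")])) is True

# Assistant prefill / continuation: the last message is from the assistant ->
# skip the generation prompt so the model continues the partial reply.
assert should_add_generation_prompt(Req([Msg("user"), Msg("assistant")])) is False

# An explicit value from the client always wins over the role-based default.
assert should_add_generation_prompt(Req([Msg("user"), Msg("assistant")], True)) is True
assert should_add_generation_prompt(Req([Msg("user")], False)) is False

Changing the request default from `True` to `None` is what makes the middle case reachable: a client that omits the field gets the role-based default, while an explicit `true`/`false` still wins.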
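A second sketch, also illustrative only, for the streaming path: `preprocess_stream_chunk` is copied from this diff, while the chunk dicts, indices, and text below are invented; only the `index`/`text` keys mirror what the consumer loop reads. It shows that when the rendered prompt ends in `<think>` and a generation prompt was added, exactly the first chunk of each generation index gets the tag prepended, matching the non-streaming path that prepends it once per generation.

def preprocess_stream_chunk(data: dict, inject_thinking: bool, is_first_chunk: bool):
    # Same logic as the helper added in endpoints/OAI/utils/chat_completion.py
    if inject_thinking and is_first_chunk:
        updated = data.copy()
        updated["text"] = "<think>" + updated.get("text", "")
        return updated
    return data


# Invented chunks for two parallel generations (n=2), interleaved as the
# consumer loop would see them coming off the queue.
chunks = [
    {"index": 0, "text": "Okay,"},
    {"index": 1, "text": "Sure,"},
    {"index": 0, "text": " let me reason through this."},
    {"index": 1, "text": " thinking about it step by step."},
]

# Corresponds to: "<think>" closed the rendered prompt and a generation
# prompt was added for this request.
inject_thinking = True

seen_first_chunk_indices = set()
streams = {0: "", 1: ""}

for chunk in chunks:
    index = chunk.get("index", 0)
    is_first = index not in seen_first_chunk_indices
    out = preprocess_stream_chunk(chunk, inject_thinking, is_first)
    if is_first:
        seen_first_chunk_indices.add(index)
    streams[index] += out["text"]

# Only the first chunk of each index was tagged.
assert streams[0] == "<think>Okay, let me reason through this."
assert streams[1] == "<think>Sure, thinking about it step by step."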