diff --git a/omlx/adapter/harmony.py b/omlx/adapter/harmony.py
index 8588243e..ca1c796b 100644
--- a/omlx/adapter/harmony.py
+++ b/omlx/adapter/harmony.py
@@ -351,9 +351,9 @@ def current_recipient(self) -> str | None:
 def parse_tool_calls_from_tokens(
     token_ids: list[int],
     prepend_start: bool = True,
-) -> tuple[str, list[dict[str, str]]]:
+) -> tuple[str, str, list[dict[str, str]]]:
     """
-    Parse tool calls from complete token sequence (non-streaming).
+    Parse a complete Harmony token sequence (non-streaming).
 
     Args:
         token_ids: Model output token ID list
@@ -361,12 +361,13 @@ def parse_tool_calls_from_tokens(
             Set to False if token_ids already includes start tokens.
 
     Returns:
-        (output_text, tool_calls)
-        - output_text: Text from final channel
+        (output_text, analysis_text, tool_calls)
+        - output_text: Text from the final channel
+        - analysis_text: Chain-of-thought text from the analysis channel
         - tool_calls: [{"name": "...", "arguments": "..."}]
     """
     if not token_ids:
-        return "", []
+        return "", "", []
 
     try:
         encoding = load_harmony_encoding("HarmonyGptOss")
@@ -399,6 +400,7 @@ def parse_tool_calls_from_tokens(
         )
 
         output_text = ""
+        analysis_text = ""
         tool_calls = []
 
         for msg in messages:
@@ -414,6 +416,13 @@ def parse_tool_calls_from_tokens(
                     if isinstance(text, str):
                         output_text += text
 
+            elif msg.channel == "analysis":
+                # Extract chain-of-thought text from analysis channel
+                for content in msg_content:
+                    text = getattr(content, "text", None)
+                    if isinstance(text, str):
+                        analysis_text += text
+
             elif msg.recipient and msg.recipient.startswith("functions."):
                 # Extract tool calls from commentary channel
                 name = msg.recipient[10:]  # Remove "functions." prefix
@@ -424,8 +433,8 @@ def parse_tool_calls_from_tokens(
                         arguments += text
                 tool_calls.append({"name": name, "arguments": arguments})
 
-        return output_text, tool_calls
+        return output_text, analysis_text, tool_calls
 
     except Exception as e:
         logger.warning(f"Error parsing tool calls from tokens: {e}")
-        return "", []
+        return "", "", []
diff --git a/omlx/adapter/output_parser.py b/omlx/adapter/output_parser.py
index b03bc83c..a9f8acd4 100644
--- a/omlx/adapter/output_parser.py
+++ b/omlx/adapter/output_parser.py
@@ -37,6 +37,7 @@ class OutputParserFinalizeResult:
 
     stream_text: str = ""
     visible_text: str = ""
+    output_text_prefix: str = ""
     tool_calls: list[dict[str, str]] = field(default_factory=list)
     finish_reason: str | None = None
 
@@ -123,12 +124,19 @@ def finalize(self) -> OutputParserFinalizeResult:
         if self._parser.current_channel == "final":
             visible_text += final_text
 
-        _, tool_calls = parse_tool_calls_from_tokens(self._raw_token_ids)
+        _, analysis_text, tool_calls = parse_tool_calls_from_tokens(
+            self._raw_token_ids
+        )
         finish_reason = "tool_calls" if tool_calls else None
 
+        output_text_prefix = (
+            f"<think>\n{analysis_text}\n</think>\n\n" if analysis_text else ""
+        )
+
         return OutputParserFinalizeResult(
             stream_text=stream_text,
             visible_text=visible_text,
+            output_text_prefix=output_text_prefix,
             tool_calls=tool_calls,
             finish_reason=finish_reason,
         )
diff --git a/omlx/scheduler.py b/omlx/scheduler.py
index 30e02875..8ccfcf7a 100644
--- a/omlx/scheduler.py
+++ b/omlx/scheduler.py
@@ -3164,6 +3164,10 @@ def _process_batch_responses(
                 output.new_text += final_result.stream_text
             if final_result.visible_text:
                 request.output_text += final_result.visible_text
+            if final_result.output_text_prefix:
+                request.output_text = (
+                    final_result.output_text_prefix + request.output_text
+                )
             if final_result.tool_calls:
                 output.tool_calls = final_result.tool_calls
             if final_result.finish_reason:
diff --git a/tests/test_harmony.py b/tests/test_harmony.py
index 134de9c1..cd99dc5c 100644
--- a/tests/test_harmony.py
+++ b/tests/test_harmony.py
@@ -294,7 +294,7 @@ def test_extracts_tool_call(self, encoding):
             allowed_special="all",
         )
         # prepend_start=True adds <|start|>assistant
-        output_text, tool_calls = parse_tool_calls_from_tokens(
+        output_text, analysis_text, tool_calls = parse_tool_calls_from_tokens(
             tokens, prepend_start=True
         )
         assert isinstance(tool_calls, list)
@@ -305,7 +305,7 @@ def test_extracts_final_text(self, encoding):
             "<|channel|>final<|message|>Hello world<|end|>",
             allowed_special="all",
         )
-        output_text, tool_calls = parse_tool_calls_from_tokens(
+        output_text, analysis_text, tool_calls = parse_tool_calls_from_tokens(
             tokens, prepend_start=True
         )
         assert "Hello world" in output_text
diff --git a/tests/test_output_parser.py b/tests/test_output_parser.py
index a089d291..8a00dced 100644
--- a/tests/test_output_parser.py
+++ b/tests/test_output_parser.py
@@ -190,3 +190,39 @@ def test_harmony_wrapper_regression(self):
         assert "<think>\n" in "".join(stream)
         assert "</think>\n" in "".join(stream)
         assert "".join(visible) == "Answer"
+
+    def test_harmony_non_streaming_preserves_reasoning(self):
+        """Non-streaming output_text retains analysis-channel reasoning."""
+        from omlx.api.thinking import extract_thinking
+
+        encoding = load_harmony_encoding("HarmonyGptOss")
+        tokenizer = HarmonyTokenizer(encoding)
+        factory = detect_output_parser(
+            "gpt-oss-20b",
+            tokenizer,
+            {"model_type": "gpt_oss"},
+        )
+        session = factory.create_session(tokenizer)
+
+        tokens = encoding.encode(
+            "<|channel|>analysis<|message|>Let me think about this<|end|>"
+            "<|start|>assistant<|channel|>final<|message|>Four<|return|>",
+            allowed_special="all",
+        )
+
+        visible_parts = []
+        for token in tokens:
+            result = session.process_token(token)
+            visible_parts.append(result.visible_text)
+
+        final = session.finalize()
+        visible_parts.append(final.visible_text)
+
+        # Mirror scheduler aggregation: prepend any parser-provided prefix
+        # to the accumulated visible_text before exposing as output_text.
+        prefix = getattr(final, "output_text_prefix", "")
+        output_text = prefix + "".join(visible_parts)
+
+        thinking, content = extract_thinking(output_text)
+        assert thinking == "Let me think about this"
+        assert content == "Four"