diff --git a/omlx/adapter/harmony.py b/omlx/adapter/harmony.py
index 8588243e..ca1c796b 100644
--- a/omlx/adapter/harmony.py
+++ b/omlx/adapter/harmony.py
@@ -351,9 +351,9 @@ def current_recipient(self) -> str | None:
def parse_tool_calls_from_tokens(
token_ids: list[int],
prepend_start: bool = True,
-) -> tuple[str, list[dict[str, str]]]:
+) -> tuple[str, str, list[dict[str, str]]]:
"""
- Parse tool calls from complete token sequence (non-streaming).
+ Parse a complete Harmony token sequence (non-streaming).
Args:
token_ids: Model output token ID list
@@ -361,12 +361,13 @@ def parse_tool_calls_from_tokens(
Set to False if token_ids already includes start tokens.
Returns:
- (output_text, tool_calls)
- - output_text: Text from final channel
+ (output_text, analysis_text, tool_calls)
+ - output_text: Text from the final channel
+ - analysis_text: Chain-of-thought text from the analysis channel
- tool_calls: [{"name": "...", "arguments": "..."}]
"""
if not token_ids:
- return "", []
+ return "", "", []
try:
encoding = load_harmony_encoding("HarmonyGptOss")
@@ -399,6 +400,7 @@ def parse_tool_calls_from_tokens(
)
output_text = ""
+ analysis_text = ""
tool_calls = []
for msg in messages:
@@ -414,6 +416,13 @@ def parse_tool_calls_from_tokens(
if isinstance(text, str):
output_text += text
+ elif msg.channel == "analysis":
+ # Extract chain-of-thought text from analysis channel
+ for content in msg_content:
+ text = getattr(content, "text", None)
+ if isinstance(text, str):
+ analysis_text += text
+
elif msg.recipient and msg.recipient.startswith("functions."):
# Extract tool calls from commentary channel
name = msg.recipient[10:] # Remove "functions." prefix
@@ -424,8 +433,8 @@ def parse_tool_calls_from_tokens(
arguments += text
tool_calls.append({"name": name, "arguments": arguments})
- return output_text, tool_calls
+ return output_text, analysis_text, tool_calls
except Exception as e:
logger.warning(f"Error parsing tool calls from tokens: {e}")
- return "", []
+ return "", "", []
diff --git a/omlx/adapter/output_parser.py b/omlx/adapter/output_parser.py
index b03bc83c..a9f8acd4 100644
--- a/omlx/adapter/output_parser.py
+++ b/omlx/adapter/output_parser.py
@@ -37,6 +37,7 @@ class OutputParserFinalizeResult:
stream_text: str = ""
visible_text: str = ""
+ output_text_prefix: str = ""
tool_calls: list[dict[str, str]] = field(default_factory=list)
finish_reason: str | None = None
@@ -123,12 +124,19 @@ def finalize(self) -> OutputParserFinalizeResult:
if self._parser.current_channel == "final":
visible_text += final_text
- _, tool_calls = parse_tool_calls_from_tokens(self._raw_token_ids)
+ _, analysis_text, tool_calls = parse_tool_calls_from_tokens(
+ self._raw_token_ids
+ )
finish_reason = "tool_calls" if tool_calls else None
+ output_text_prefix = (
+ f"<think>\n{analysis_text}\n</think>\n\n" if analysis_text else ""
+ )
+
return OutputParserFinalizeResult(
stream_text=stream_text,
visible_text=visible_text,
+ output_text_prefix=output_text_prefix,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
diff --git a/omlx/scheduler.py b/omlx/scheduler.py
index 30e02875..8ccfcf7a 100644
--- a/omlx/scheduler.py
+++ b/omlx/scheduler.py
@@ -3164,6 +3164,10 @@ def _process_batch_responses(
output.new_text += final_result.stream_text
if final_result.visible_text:
request.output_text += final_result.visible_text
+ if final_result.output_text_prefix:
+ request.output_text = (
+ final_result.output_text_prefix + request.output_text
+ )
if final_result.tool_calls:
output.tool_calls = final_result.tool_calls
if final_result.finish_reason:
diff --git a/tests/test_harmony.py b/tests/test_harmony.py
index 134de9c1..cd99dc5c 100644
--- a/tests/test_harmony.py
+++ b/tests/test_harmony.py
@@ -294,7 +294,7 @@ def test_extracts_tool_call(self, encoding):
allowed_special="all",
)
# prepend_start=True adds <|start|>assistant
- output_text, tool_calls = parse_tool_calls_from_tokens(
+ output_text, analysis_text, tool_calls = parse_tool_calls_from_tokens(
tokens, prepend_start=True
)
assert isinstance(tool_calls, list)
@@ -305,7 +305,7 @@ def test_extracts_final_text(self, encoding):
"<|channel|>final<|message|>Hello world<|end|>",
allowed_special="all",
)
- output_text, tool_calls = parse_tool_calls_from_tokens(
+ output_text, analysis_text, tool_calls = parse_tool_calls_from_tokens(
tokens, prepend_start=True
)
assert "Hello world" in output_text
diff --git a/tests/test_output_parser.py b/tests/test_output_parser.py
index a089d291..8a00dced 100644
--- a/tests/test_output_parser.py
+++ b/tests/test_output_parser.py
@@ -190,3 +190,39 @@ def test_harmony_wrapper_regression(self):
assert "<think>\n" in "".join(stream)
assert "</think>\n" in "".join(stream)
assert "".join(visible) == "Answer"
+
+ def test_harmony_non_streaming_preserves_reasoning(self):
+ """Non-streaming output_text retains analysis-channel reasoning."""
+ from omlx.api.thinking import extract_thinking
+
+ encoding = load_harmony_encoding("HarmonyGptOss")
+ tokenizer = HarmonyTokenizer(encoding)
+ factory = detect_output_parser(
+ "gpt-oss-20b",
+ tokenizer,
+ {"model_type": "gpt_oss"},
+ )
+ session = factory.create_session(tokenizer)
+
+ tokens = encoding.encode(
+ "<|channel|>analysis<|message|>Let me think about this<|end|>"
+ "<|start|>assistant<|channel|>final<|message|>Four<|return|>",
+ allowed_special="all",
+ )
+
+ visible_parts = []
+ for token in tokens:
+ result = session.process_token(token)
+ visible_parts.append(result.visible_text)
+
+ final = session.finalize()
+ visible_parts.append(final.visible_text)
+
+ # Mirror scheduler aggregation: prepend any parser-provided prefix
+ # to the accumulated visible_text before exposing as output_text.
+ prefix = getattr(final, "output_text_prefix", "")
+ output_text = prefix + "".join(visible_parts)
+
+ thinking, content = extract_thinking(output_text)
+ assert thinking == "Let me think about this"
+ assert content == "Four"