Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions omlx/adapter/harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,22 +351,23 @@ def current_recipient(self) -> str | None:
def parse_tool_calls_from_tokens(
token_ids: list[int],
prepend_start: bool = True,
) -> tuple[str, list[dict[str, str]]]:
) -> tuple[str, str, list[dict[str, str]]]:
"""
Parse tool calls from complete token sequence (non-streaming).
Parse a complete Harmony token sequence (non-streaming).

Args:
token_ids: Model output token ID list
prepend_start: Whether to prepend "<|start|>assistant" tokens.
Set to False if token_ids already includes start tokens.

Returns:
(output_text, tool_calls)
- output_text: Text from final channel
(output_text, analysis_text, tool_calls)
- output_text: Text from the final channel
- analysis_text: Chain-of-thought text from the analysis channel
- tool_calls: [{"name": "...", "arguments": "..."}]
"""
if not token_ids:
return "", []
return "", "", []

try:
encoding = load_harmony_encoding("HarmonyGptOss")
Expand Down Expand Up @@ -399,6 +400,7 @@ def parse_tool_calls_from_tokens(
)

output_text = ""
analysis_text = ""
tool_calls = []

for msg in messages:
Expand All @@ -414,6 +416,13 @@ def parse_tool_calls_from_tokens(
if isinstance(text, str):
output_text += text

elif msg.channel == "analysis":
# Extract chain-of-thought text from analysis channel
for content in msg_content:
text = getattr(content, "text", None)
if isinstance(text, str):
analysis_text += text

elif msg.recipient and msg.recipient.startswith("functions."):
# Extract tool calls from commentary channel
name = msg.recipient[10:] # Remove "functions." prefix
Expand All @@ -424,8 +433,8 @@ def parse_tool_calls_from_tokens(
arguments += text
tool_calls.append({"name": name, "arguments": arguments})

return output_text, tool_calls
return output_text, analysis_text, tool_calls

except Exception as e:
logger.warning(f"Error parsing tool calls from tokens: {e}")
return "", []
return "", "", []
10 changes: 9 additions & 1 deletion omlx/adapter/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class OutputParserFinalizeResult:

stream_text: str = ""
visible_text: str = ""
output_text_prefix: str = ""
tool_calls: list[dict[str, str]] = field(default_factory=list)
finish_reason: str | None = None

Expand Down Expand Up @@ -123,12 +124,19 @@ def finalize(self) -> OutputParserFinalizeResult:
if self._parser.current_channel == "final":
visible_text += final_text

_, tool_calls = parse_tool_calls_from_tokens(self._raw_token_ids)
_, analysis_text, tool_calls = parse_tool_calls_from_tokens(
self._raw_token_ids
)
finish_reason = "tool_calls" if tool_calls else None

output_text_prefix = (
f"<think>\n{analysis_text}\n</think>\n" if analysis_text else ""
)

return OutputParserFinalizeResult(
stream_text=stream_text,
visible_text=visible_text,
output_text_prefix=output_text_prefix,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
Expand Down
4 changes: 4 additions & 0 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3164,6 +3164,10 @@ def _process_batch_responses(
output.new_text += final_result.stream_text
if final_result.visible_text:
request.output_text += final_result.visible_text
if final_result.output_text_prefix:
request.output_text = (
final_result.output_text_prefix + request.output_text
)
if final_result.tool_calls:
output.tool_calls = final_result.tool_calls
if final_result.finish_reason:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_extracts_tool_call(self, encoding):
allowed_special="all",
)
# prepend_start=True adds <|start|>assistant
output_text, tool_calls = parse_tool_calls_from_tokens(
output_text, analysis_text, tool_calls = parse_tool_calls_from_tokens(
tokens, prepend_start=True
)
assert isinstance(tool_calls, list)
Expand All @@ -305,7 +305,7 @@ def test_extracts_final_text(self, encoding):
"<|channel|>final<|message|>Hello world<|end|>",
allowed_special="all",
)
output_text, tool_calls = parse_tool_calls_from_tokens(
output_text, analysis_text, tool_calls = parse_tool_calls_from_tokens(
tokens, prepend_start=True
)
assert "Hello world" in output_text
36 changes: 36 additions & 0 deletions tests/test_output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,39 @@ def test_harmony_wrapper_regression(self):
assert "<think>\n" in "".join(stream)
assert "</think>\n" in "".join(stream)
assert "".join(visible) == "Answer"

def test_harmony_non_streaming_preserves_reasoning(self):
    """Analysis-channel reasoning must survive into non-streaming output_text.

    Feeds a Harmony sequence holding an analysis message followed by a
    final message through a parser session token by token, rebuilds
    output_text from the per-token visible text plus the finalize result,
    and checks that extract_thinking splits it back into the reasoning
    and the answer.
    """
    from omlx.api.thinking import extract_thinking

    harmony_encoding = load_harmony_encoding("HarmonyGptOss")
    tokenizer = HarmonyTokenizer(harmony_encoding)
    session = detect_output_parser(
        "gpt-oss-20b",
        tokenizer,
        {"model_type": "gpt_oss"},
    ).create_session(tokenizer)

    raw_sequence = (
        "<|channel|>analysis<|message|>Let me think about this<|end|>"
        "<|start|>assistant<|channel|>final<|message|>Four<|return|>"
    )
    token_ids = harmony_encoding.encode(raw_sequence, allowed_special="all")

    # Collect the visible text emitted for every token, then the tail
    # that finalize() contributes.
    visible_chunks = [
        session.process_token(token_id).visible_text for token_id in token_ids
    ]
    finalized = session.finalize()
    visible_chunks.append(finalized.visible_text)

    # Same aggregation the scheduler performs: a parser-supplied prefix
    # goes in front of the accumulated visible text to form output_text.
    reasoning_prefix = getattr(finalized, "output_text_prefix", "")
    output_text = reasoning_prefix + "".join(visible_chunks)

    thinking, content = extract_thinking(output_text)
    assert thinking == "Let me think about this"
    assert content == "Four"