Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, get_args

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall
from haystack.dataclasses.streaming_chunk import (
AsyncStreamingCallbackT,
FinishReason,
Expand Down Expand Up @@ -202,11 +203,20 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
)
)

# Extract reasoning from content if present, even with tool calls
reasoning_content = None
if chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
raw_content = chat_response.message.content[0].text
reasoning_content, _ = _extract_reasoning_from_response(raw_content)

# Create message with tool plan as text and tool calls in the format Haystack expects
tool_plan = chat_response.message.tool_plan or ""
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls)
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content)
elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
message = ChatMessage.from_assistant(chat_response.message.content[0].text)
raw_content = chat_response.message.content[0].text
# Extract reasoning content if present
reasoning_content, cleaned_content = _extract_reasoning_from_response(raw_content)
message = ChatMessage.from_assistant(cleaned_content, reasoning=reasoning_content)
else:
# Handle the case where neither tool_calls nor content exists
logger.warning(f"Received empty response from Cohere API: {chat_response.message}")
Expand Down Expand Up @@ -350,6 +360,125 @@ def _convert_cohere_chunk_to_streaming_chunk(
)


def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[ReasoningContent], str]:
"""
Extract reasoning content from Cohere's response if present.

Cohere's reasoning-capable models (like Command A Reasoning) may include reasoning content
in various formats. This function attempts to identify and extract such content.

:param response_text: The raw response text from Cohere
:returns: A tuple of (ReasoningContent or None, cleaned_response_text)
"""
if not response_text or not isinstance(response_text, str):
return None, response_text

# Check for reasoning markers that Cohere might use

# Pattern 1: Look for thinking/reasoning tags
thinking_patterns = [
r"<thinking>(.*?)</thinking>",
r"<reasoning>(.*?)</reasoning>",
r"## Reasoning\s*\n(.*?)(?=\n## |$)",
r"## Thinking\s*\n(.*?)(?=\n## |$)",
]

for pattern in thinking_patterns:
match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE)
if match:
reasoning_text = match.group(1).strip()
cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
# Apply minimum length threshold for tag-based reasoning
min_reasoning_length = 30
if len(reasoning_text) > min_reasoning_length:
return ReasoningContent(reasoning_text=reasoning_text), cleaned_content
else:
# Content too short, but still clean the tags
return None, cleaned_content

# Pattern 2: Look for step-by-step reasoning at start
lines = response_text.split("\n")
reasoning_lines = []
content_lines = []
found_separator = False

for i, line in enumerate(lines):
stripped_line = line.strip()
# Look for reasoning indicators at the beginning of lines (more precise)
if (
stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve"))
or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning"))
or (
len(stripped_line) > 0
and stripped_line.endswith(":")
and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower())
)
):
# Look for a clear separator to determine where reasoning ends
reasoning_end = len(lines) # Default to end of text
for j in range(i + 1, len(lines)):
next_line = lines[j].strip()
if next_line.startswith(
("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer")
):
reasoning_end = j
break

reasoning_lines = lines[:reasoning_end]
content_lines = lines[reasoning_end:]
found_separator = True
break
# Stop looking after first few lines
max_lines_to_check = 10
if i > max_lines_to_check:
break

if found_separator and reasoning_lines:
reasoning_text = "\n".join(reasoning_lines).strip()
cleaned_content = "\n".join(content_lines).strip()
min_reasoning_length = 30
if len(reasoning_text) > min_reasoning_length: # Minimum threshold
return ReasoningContent(reasoning_text=reasoning_text), cleaned_content

# No reasoning detected
return None, response_text


def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: List[StreamingChunk]) -> ChatMessage:
    """
    Assemble streaming chunks into a ChatMessage, splitting out reasoning when it is detected.

    The chunks are first combined by the core `_convert_streaming_chunks_to_chat_message`
    utility. The text of the resulting message is then passed through
    `_extract_reasoning_from_response`.

    :param chunks: The streaming chunks to combine into a single message.
    :returns: The message built by the core utility when it has no text or no reasoning is
        found; otherwise a new assistant message carrying the cleaned text, the extracted
        reasoning, and the tool calls and meta of the message built by the core utility.
    """
    assembled = _convert_streaming_chunks_to_chat_message(chunks=chunks)

    # Nothing to inspect when the message carries no text (e.g. only tool calls)
    full_text = assembled.text
    if not full_text:
        return assembled

    reasoning, remaining_text = _extract_reasoning_from_response(full_text)
    if reasoning is None:
        # Keep the message produced by the core utility as-is
        return assembled

    # Rebuild the message so the reasoning lives in its own field
    return ChatMessage.from_assistant(
        text=remaining_text,
        reasoning=reasoning,
        tool_calls=assembled.tool_calls,
        meta=assembled.meta,
    )


def _parse_streaming_response(
response: Iterator[StreamedChatResponseV2],
model: str,
Expand Down Expand Up @@ -381,7 +510,7 @@ def _parse_streaming_response(
chunks.append(streaming_chunk)
streaming_callback(streaming_chunk)

return _convert_streaming_chunks_to_chat_message(chunks=chunks)
return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks)


async def _parse_async_streaming_response(
Expand Down Expand Up @@ -409,7 +538,7 @@ async def _parse_async_streaming_response(
chunks.append(streaming_chunk)
await streaming_callback(streaming_chunk)

return _convert_streaming_chunks_to_chat_message(chunks=chunks)
return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks)


@component
Expand Down
Loading
Loading