Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cohere vision #2342

Merged
merged 7 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
18 changes: 18 additions & 0 deletions cookbook/models/cohere/image_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from agno.agent import Agent
from agno.media import Image
from agno.models.cohere import Cohere

# Agent backed by the Cohere model id used for the vision examples; replies are rendered as markdown.
agent = Agent(model=Cohere(id="c4ai-aya-vision-8b"), markdown=True)

# The image is passed by URL.
golden_gate_image = Image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg")

# Ask about the image and stream the answer as it is generated.
agent.print_response("Tell me about this image.", images=[golden_gate_image], stream=True)
24 changes: 24 additions & 0 deletions cookbook/models/cohere/image_agent_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from pathlib import Path

from agno.agent import Agent
from agno.media import Image
from agno.models.cohere.chat import Cohere

# Agent backed by the Cohere model id used for the vision examples; replies are rendered as markdown.
agent = Agent(model=Cohere(id="c4ai-aya-vision-8b"), markdown=True)

# The sample image is expected to sit next to this script.
image_path = Path(__file__).parent / "sample.jpg"

# Load the raw bytes up front so the image is passed by content rather than by path.
image_bytes = image_path.read_bytes()

agent.print_response("Tell me about this image.", images=[Image(content=image_bytes)], stream=True)
20 changes: 20 additions & 0 deletions cookbook/models/cohere/image_agent_local_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pathlib import Path

from agno.agent import Agent
from agno.media import Image
from agno.models.cohere.chat import Cohere

# Agent backed by the Cohere model id used for the vision examples; replies are rendered as markdown.
agent = Agent(model=Cohere(id="c4ai-aya-vision-8b"), markdown=True)

# The sample image is expected to sit next to this script.
image_path = Path(__file__).parent / "sample.jpg"

# Pass the image by file path and stream the answer.
local_image = Image(filepath=image_path)
agent.print_response("Tell me about this image.", images=[local_image], stream=True)
107 changes: 79 additions & 28 deletions libs/agno/agno/models/cohere/chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from os import getenv
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

from agno.exceptions import ModelProviderError
from agno.media import Image
from agno.models.base import MessageData, Model
from agno.models.message import Message
from agno.models.response import ModelResponse
Expand All @@ -17,6 +19,78 @@
raise ImportError("`cohere` not installed. Please install using `pip install cohere`")


def _format_images_for_message(message: Message, images: Sequence[Image]) -> List[Dict[str, Any]]:
"""
Format an image into the format expected by WatsonX.
"""

# Create a default message content with text
message_content_with_image: List[Dict[str, Any]] = [{"type": "text", "text": message.content}]

# Add images to the message content
for image in images:
try:
if image.content is not None:
image_content = image.content
elif image.url is not None:
image_content = image.image_url_content
elif image.filepath is not None:
if isinstance(image.filepath, Path):
image_content = image.filepath.read_bytes()
else:
with open(image.filepath, "rb") as f:
image_content = f.read()
else:
logger.warning(f"Unsupported image format: {image}")
continue

if image_content is not None:
import base64

base64_image = base64.b64encode(image_content).decode("utf-8")
image_url = f"data:image/jpeg;base64,{base64_image}"
image_payload = {"type": "image_url", "image_url": {"url": image_url}}
message_content_with_image.append(image_payload)

except Exception as e:
logger.error(f"Failed to process image: {str(e)}")

# Update the message content with the images
return message_content_with_image


def _format_messages(messages: List[Message]) -> List[Dict[str, Any]]:
    """
    Convert agent messages into the list of dicts sent to the Cohere API.

    Fields whose value is None are left out of each dict. When a message has images and
    its content is a string, the content is replaced by a list of text and image parts,
    provided at least one image could be formatted.

    Args:
        messages (List[Message]): The messages to convert.

    Returns:
        List[Dict[str, Any]]: One dict per message, in the same order.
    """
    payload = []
    for msg in messages:
        fields = {
            "role": msg.role,
            "content": msg.content,
            "name": msg.name,
            "tool_call_id": msg.tool_call_id,
            "tool_calls": msg.tool_calls,
        }

        # Images are only attached to plain-string content; other content is sent unchanged.
        has_images = msg.images is not None and len(msg.images) > 0
        if has_images and isinstance(msg.content, str):
            parts = _format_images_for_message(message=msg, images=msg.images)
            # A single part means only the text survived, so keep the plain string.
            if len(parts) > 1:
                fields["content"] = parts

        payload.append({key: value for key, value in fields.items() if value is not None})
    return payload


@dataclass
class Cohere(Model):
id: str = "command-r-plus"
Expand Down Expand Up @@ -116,29 +190,6 @@ def request_kwargs(self) -> Dict[str, Any]:
_request_params.update(self.request_params)
return _request_params

def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
    """
    Convert agent messages into the list of dicts sent to the Cohere API.

    Args:
        messages (List[Message]): The messages to convert.

    Returns:
        List[Dict[str, Any]]: One dict per message, in the same order, without
            the fields whose value is None.
    """

    def _as_dict(msg: Message) -> Dict[str, Any]:
        # Collect every candidate field first, then drop the ones that are None.
        fields = {
            "role": msg.role,
            "content": msg.content,
            "name": msg.name,
            "tool_call_id": msg.tool_call_id,
            "tool_calls": msg.tool_calls,
        }
        return {key: value for key, value in fields.items() if value is not None}

    return [_as_dict(msg) for msg in messages]

def invoke(self, messages: List[Message]) -> ChatResponse:
"""
Invoke a non-streamed chat response from the Cohere API.
Expand All @@ -153,7 +204,7 @@ def invoke(self, messages: List[Message]) -> ChatResponse:
request_kwargs = self.request_kwargs

try:
return self.get_client().chat(model=self.id, messages=self._format_messages(messages), **request_kwargs) # type: ignore
return self.get_client().chat(model=self.id, messages=_format_messages(messages), **request_kwargs) # type: ignore
except Exception as e:
logger.error(f"Unexpected error calling Cohere API: {str(e)}")
raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
Expand All @@ -173,7 +224,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[StreamedChatRespons
try:
return self.get_client().chat_stream(
model=self.id,
messages=self._format_messages(messages), # type: ignore
messages=_format_messages(messages), # type: ignore
**request_kwargs,
)
except Exception as e:
Expand All @@ -195,7 +246,7 @@ async def ainvoke(self, messages: List[Message]) -> ChatResponse:
try:
return await self.get_async_client().chat(
model=self.id,
messages=self._format_messages(messages), # type: ignore
messages=_format_messages(messages), # type: ignore
**request_kwargs,
)
except Exception as e:
Expand All @@ -217,7 +268,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[Streame
try:
async for response in self.get_async_client().chat_stream(
model=self.id,
messages=self._format_messages(messages), # type: ignore
messages=_format_messages(messages), # type: ignore
**request_kwargs,
):
yield response
Expand Down
79 changes: 79 additions & 0 deletions libs/agno/tests/integration/models/cohere/test_multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pathlib import Path

import pytest

from agno.agent.agent import Agent
from agno.media import Image
from agno.models.cohere.chat import Cohere
from agno.tools.duckduckgo import DuckDuckGoTools


def test_image_input():
    """An image passed by URL is described, and a later text-only run still succeeds."""
    vision_agent = Agent(
        model=Cohere(id="c4ai-aya-vision-8b"),
        add_history_to_messages=True,
        markdown=True,
        telemetry=False,
        monitoring=False,
    )
    bridge_image = Image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg")

    first_reply = vision_agent.run("Tell me about this image.", images=[bridge_image])
    assert "golden" in first_reply.content.lower()

    # Smoke check: a second run on the same agent must not break.
    follow_up = vision_agent.run("Where can I find more information?")
    roles = [message.role for message in follow_up.messages]
    assert roles == ["system", "user", "assistant", "user", "assistant"]


def test_image_input_bytes():
    """An image passed as raw bytes is described by the model."""
    vision_agent = Agent(model=Cohere(id="c4ai-aya-vision-8b"), telemetry=False, monitoring=False)

    # Load the shared sample image, one directory above this test module, as bytes.
    image_path = Path(__file__).parent / "../sample_image.jpg"
    image_bytes = image_path.read_bytes()

    reply = vision_agent.run("Tell me about this image.", images=[Image(content=image_bytes)])

    answer = reply.content.lower()
    assert "golden" in answer
    assert "bridge" in answer


def test_image_input_local_file():
    """An image passed by local file path is described by the model."""
    vision_agent = Agent(model=Cohere(id="c4ai-aya-vision-8b"), telemetry=False, monitoring=False)

    # The shared sample image lives one directory above this test module.
    image_path = Path(__file__).parent / "../sample_image.jpg"

    reply = vision_agent.run("Tell me about this image.", images=[Image(filepath=image_path)])

    answer = reply.content.lower()
    assert "golden" in answer
    assert "bridge" in answer


@pytest.mark.skip(reason="Image with tool call is not supported yet.")
def test_image_input_with_tool_call():
    """Image input combined with a search tool; skipped until that combination is supported."""
    search_agent = Agent(
        model=Cohere(id="c4ai-aya-vision-8b"),
        tools=[DuckDuckGoTools()],
        markdown=True,
        telemetry=False,
        monitoring=False,
    )
    bridge_image = Image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg")

    reply = search_agent.run(
        "Tell me about this image and give me the latest news about it.",
        images=[bridge_image],
    )

    assert "golden" in reply.content.lower()
Empty file.