Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cohere vision #2342

Merged
merged 7 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
6 changes: 4 additions & 2 deletions cookbook/models/aws/bedrock/image_agent_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from agno.media import Image
from agno.models.aws import AwsBedrock
from agno.tools.duckduckgo import DuckDuckGoTools
from agno.utils.media import download_image

agent = Agent(
model=AwsBedrock(id="amazon.nova-pro-v1:0"),
Expand All @@ -13,9 +14,10 @@

image_path = Path(__file__).parent.joinpath("sample.jpg")

download_image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg", save_path=str(image_path))

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image and give me the latest news about it.",
Expand Down
6 changes: 4 additions & 2 deletions cookbook/models/azure/ai_foundry/image_agent_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from agno.agent import Agent
from agno.media import Image
from agno.models.azure import AzureAIFoundry
from agno.utils.media import download_image

agent = Agent(
model=AzureAIFoundry(id="Llama-3.2-11B-Vision-Instruct"),
Expand All @@ -11,9 +12,10 @@

image_path = Path(__file__).parent.joinpath("sample.jpg")

download_image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg", save_path=str(image_path))

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image.",
Expand Down
18 changes: 18 additions & 0 deletions cookbook/models/cohere/image_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from agno.agent import Agent
from agno.media import Image
from agno.models.cohere import Cohere

agent = Agent(
model=Cohere(id="c4ai-aya-vision-8b"),
markdown=True,
)

agent.print_response(
"Tell me about this image.",
images=[
Image(
url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg"
)
],
stream=True,
)
26 changes: 26 additions & 0 deletions cookbook/models/cohere/image_agent_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path

from agno.agent import Agent
from agno.media import Image
from agno.models.cohere.chat import Cohere
from agno.utils.media import download_image

agent = Agent(
model=Cohere(id="c4ai-aya-vision-8b"),
markdown=True,
)

image_path = Path(__file__).parent.joinpath("sample.jpg")

download_image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg", save_path=str(image_path))

# Read the image file content as bytes
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image.",
images=[
Image(content=image_bytes),
],
stream=True,
)
23 changes: 23 additions & 0 deletions cookbook/models/cohere/image_agent_local_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path

from agno.agent import Agent
from agno.media import Image
from agno.models.cohere.chat import Cohere
from agno.utils.media import download_image

agent = Agent(
model=Cohere(id="c4ai-aya-vision-8b"),
markdown=True,
)

image_path = Path(__file__).parent.joinpath("sample.jpg")

download_image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg", save_path=str(image_path))

agent.print_response(
"Tell me about this image.",
images=[
Image(filepath=image_path),
],
stream=True,
)
4 changes: 1 addition & 3 deletions cookbook/models/ibm/watsonx/image_agent_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from agno.agent import Agent
from agno.media import Image
from agno.models.ibm import WatsonX
from agno.tools.duckduckgo import DuckDuckGoTools

agent = Agent(
model=WatsonX(id="meta-llama/llama-3-2-11b-vision-instruct"),
Expand All @@ -13,8 +12,7 @@
image_path = Path(__file__).parent.joinpath("sample.jpg")

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image and and give me the latest news about it.",
Expand Down
6 changes: 4 additions & 2 deletions cookbook/models/openai/image_agent_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from agno.media import Image
from agno.models.openai import OpenAIChat
from agno.tools.duckduckgo import DuckDuckGoTools
from agno.utils.media import download_image

agent = Agent(
model=OpenAIChat(id="gpt-4o"),
Expand All @@ -13,9 +14,10 @@

image_path = Path(__file__).parent.joinpath("sample.jpg")

download_image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg", save_path=str(image_path))

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image and give me the latest news about it.",
Expand Down
3 changes: 1 addition & 2 deletions cookbook/models/together/image_agent_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
image_path = Path(__file__).parent.joinpath("sample.jpg")

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image",
Expand Down
6 changes: 4 additions & 2 deletions cookbook/models/xai/image_agent_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from agno.media import Image
from agno.models.xai import xAI
from agno.tools.duckduckgo import DuckDuckGoTools
from agno.utils.media import download_image

agent = Agent(
model=xAI(id="grok-2-vision-latest"),
Expand All @@ -13,9 +14,10 @@

image_path = Path(__file__).parent.joinpath("sample.jpg")

download_image(url="https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg", save_path=str(image_path))

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

agent.print_response(
"Tell me about this image and give me the latest news about it.",
Expand Down
107 changes: 79 additions & 28 deletions libs/agno/agno/models/cohere/chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from os import getenv
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

from agno.exceptions import ModelProviderError
from agno.media import Image
from agno.models.base import MessageData, Model
from agno.models.message import Message
from agno.models.response import ModelResponse
Expand All @@ -17,6 +19,78 @@
raise ImportError("`cohere` not installed. Please install using `pip install cohere`")


def _format_images_for_message(message: Message, images: Sequence[Image]) -> List[Dict[str, Any]]:
"""
Format an image into the format expected by WatsonX.
"""

# Create a default message content with text
message_content_with_image: List[Dict[str, Any]] = [{"type": "text", "text": message.content}]

# Add images to the message content
for image in images:
try:
if image.content is not None:
image_content = image.content
elif image.url is not None:
image_content = image.image_url_content
elif image.filepath is not None:
if isinstance(image.filepath, Path):
image_content = image.filepath.read_bytes()
else:
with open(image.filepath, "rb") as f:
image_content = f.read()
else:
logger.warning(f"Unsupported image format: {image}")
continue

if image_content is not None:
import base64

base64_image = base64.b64encode(image_content).decode("utf-8")
image_url = f"data:image/jpeg;base64,{base64_image}"
image_payload = {"type": "image_url", "image_url": {"url": image_url}}
message_content_with_image.append(image_payload)

except Exception as e:
logger.error(f"Failed to process image: {str(e)}")

# Update the message content with the images
return message_content_with_image


def _format_messages(messages: List[Message]) -> List[Dict[str, Any]]:
"""
Format messages for the Cohere API.

Args:
messages (List[Message]): The list of messages.

Returns:
List[Dict[str, Any]]: The formatted messages.
"""
formatted_messages = []
for message in messages:
message_dict = {
"role": message.role,
"content": message.content,
"name": message.name,
"tool_call_id": message.tool_call_id,
"tool_calls": message.tool_calls,
}

if message.images is not None and len(message.images) > 0:
# Ignore non-string message content
if isinstance(message.content, str):
message_content_with_image = _format_images_for_message(message=message, images=message.images)
if len(message_content_with_image) > 1:
message_dict["content"] = message_content_with_image

message_dict = {k: v for k, v in message_dict.items() if v is not None}
formatted_messages.append(message_dict)
return formatted_messages


@dataclass
class Cohere(Model):
id: str = "command-r-plus"
Expand Down Expand Up @@ -116,29 +190,6 @@ def request_kwargs(self) -> Dict[str, Any]:
_request_params.update(self.request_params)
return _request_params

def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
"""
Format messages for the Cohere API.

Args:
messages (List[Message]): The list of messages.

Returns:
List[Dict[str, Any]]: The formatted messages.
"""
formatted_messages = []
for message in messages:
message_dict = {
"role": message.role,
"content": message.content,
"name": message.name,
"tool_call_id": message.tool_call_id,
"tool_calls": message.tool_calls,
}
message_dict = {k: v for k, v in message_dict.items() if v is not None}
formatted_messages.append(message_dict)
return formatted_messages

def invoke(self, messages: List[Message]) -> ChatResponse:
"""
Invoke a non-streamed chat response from the Cohere API.
Expand All @@ -153,7 +204,7 @@ def invoke(self, messages: List[Message]) -> ChatResponse:
request_kwargs = self.request_kwargs

try:
return self.get_client().chat(model=self.id, messages=self._format_messages(messages), **request_kwargs) # type: ignore
return self.get_client().chat(model=self.id, messages=_format_messages(messages), **request_kwargs) # type: ignore
except Exception as e:
logger.error(f"Unexpected error calling Cohere API: {str(e)}")
raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
Expand All @@ -173,7 +224,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[StreamedChatRespons
try:
return self.get_client().chat_stream(
model=self.id,
messages=self._format_messages(messages), # type: ignore
messages=_format_messages(messages), # type: ignore
**request_kwargs,
)
except Exception as e:
Expand All @@ -195,7 +246,7 @@ async def ainvoke(self, messages: List[Message]) -> ChatResponse:
try:
return await self.get_async_client().chat(
model=self.id,
messages=self._format_messages(messages), # type: ignore
messages=_format_messages(messages), # type: ignore
**request_kwargs,
)
except Exception as e:
Expand All @@ -217,7 +268,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[Streame
try:
async for response in self.get_async_client().chat_stream(
model=self.id,
messages=self._format_messages(messages), # type: ignore
messages=_format_messages(messages), # type: ignore
**request_kwargs,
):
yield response
Expand Down
2 changes: 1 addition & 1 deletion libs/agno/agno/utils/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests


def download_image(url, save_path):
def download_image(url: str, save_path: str) -> bool:
"""
Downloads an image from the specified URL and saves it to the given local path.
Parameters:
Expand Down
2 changes: 1 addition & 1 deletion libs/agno/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ postgres = ["psycopg-binary", "psycopg"]
# Dependencies for Vector databases
pgvector = ["pgvector"]
chromadb = ["chromadb"]
lancedb = ["lancedb", "tantivy"]
lancedb = ["lancedb==0.20.0", "tantivy"]
qdrant = ["qdrant-client"]
cassandra = ["cassio"]
mongodb = ["pymongo[srv]"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ def test_image_input_bytes():
"""
agent = Agent(model=AwsBedrock(id="amazon.nova-pro-v1:0"), markdown=True, telemetry=False, monitoring=False)

image_path = Path(__file__).parent.joinpath("../../sample_image.jpg")
image_path = Path(__file__).parent.parent.parent.joinpath("sample_image.jpg")

# Read the image file content as bytes
with open(image_path, "rb") as img_file:
image_bytes = img_file.read()
image_bytes = image_path.read_bytes()

response = agent.run(
"Tell me about this image.",
Expand Down
Loading