add litellm model #2356

Open · wants to merge 17 commits into base: main
Changes from 1 commit
add changes in LiteLLM model for tool calling
kausmeows committed Mar 11, 2025
commit 1a13ee81698c2fe2868b568647a40536b6da69c1
11 changes: 7 additions & 4 deletions cookbook/models/litellm/tool_use.py
@@ -1,14 +1,17 @@
 from agno.agent import Agent
 from agno.models.litellm import LiteLLMSDK
-from agno.tools.duckduckgo import DuckDuckGoTools
+from agno.tools.yfinance import YFinanceTools
 
 openai_agent = Agent(
     model=LiteLLMSDK(
-        id="gpt-4o",
+        # id="gpt-4o",
+        id="huggingface/mistralai/Mistral-7B-Instruct-v0.2",
+        top_p=0.95,
         name="LiteLLM",
     ),
     markdown=True,
-    tools=[DuckDuckGoTools()],
+    tools=[YFinanceTools()],
 )
 
-openai_agent.print_response("What's the age of Elon Musk")
+# Ask a question that would likely trigger tool use
+openai_agent.print_response("How is TSLA stock doing right now?")
261 changes: 167 additions & 94 deletions libs/agno/agno/models/litellm/litellm_chat.py
@@ -4,6 +4,7 @@
 from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
 
 import litellm
+from pydantic import BaseModel
 
 from agno.models.base import Model
 from agno.models.message import Message
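
The new pydantic import supports the structured-output check added to parse_provider_response further down, which tests whether response_format is a BaseModel subclass. A hypothetical schema, for illustration only:

from pydantic import BaseModel

# Hypothetical structured-output schema; any BaseModel subclass would pass
# the issubclass(response_format, BaseModel) guard used later in this diff.
class StockQuote(BaseModel):
    symbol: str
    price: float

assert issubclass(StockQuote, BaseModel)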
@@ -39,44 +40,117 @@ class LiteLLMSDK(Model):
     max_tokens: Optional[int] = None
     temperature: float = 0.7
     top_p: float = 1.0
 
     # Additional parameters
     request_params: Optional[Dict[str, Any]] = None
 
     def __post_init__(self):
-        """
-        Initialize the model after the dataclass initialization.
-
-        For Huggingface, we need to use the model name directly
-        LiteLLM expects Huggingface models to be formatted as: "huggingface/mistralai/Mistral-7B-Instruct-v0.2"
-        But internally it needs to use just "mistralai/Mistral-7B-Instruct-v0.2"
-        """
+        """Initialize the model after the dataclass initialization."""
         super().__post_init__()
         # Handle Huggingface models
         if self.id.startswith("huggingface/"):
-            # Extract the actual model name without the "huggingface/" prefix
-            self.model_name = self.id.replace("huggingface/", "")
-
-            logger.info(f"Using Huggingface model: {self.model_name}")
+            # Keep the full model name for LiteLLM routing
+            self.model_name = self.id
+            logger.info(
+                f"Using Huggingface model via LiteLLM: {self.model_name}")
         else:
             self.model_name = self.id
 
         # Set up API key from environment variable if not already set
         if not self.api_key:
             self.api_key = getenv("LITELLM_API_KEY")
             if not self.api_key:
-                logger.warning("LITELLM_API_KEY not set. Please set the LITELLM_API_KEY environment variable.")
+                logger.warning(
+                    "LITELLM_API_KEY not set. Please set the LITELLM_API_KEY environment variable.")
 
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert the model to a dictionary."""
+        model_dict = {
+            "id": self.id,
+            "name": self.name,
+            "provider": self.provider,
+            "api_key": self.api_key,
+            "api_base": self.api_base,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "request_params": self.request_params,
+        }
+        # Add tools if present
+        if self._tools is not None:
+            model_dict["tools"] = self._tools
+        # Remove None values
+        return {k: v for k, v in model_dict.items() if v is not None}
+
+    @staticmethod
+    def parse_tool_calls(tool_calls_data: List[Any]) -> List[Dict[str, Any]]:
+        """Build tool calls from streamed tool call data."""
+        tool_calls: List[Dict[str, Any]] = []
+        for tool_call in tool_calls_data:
+            tool_call_entry = {
+                "id": tool_call.id if hasattr(tool_call, 'id') else None,
+                "type": "function",
+                "function": {
+                    "name": tool_call.function.name if hasattr(tool_call.function, 'name') else "",
+                    "arguments": tool_call.function.arguments if hasattr(tool_call.function, 'arguments') else ""
+                }
+            }
+            tool_calls.append(tool_call_entry)
+        return tool_calls
 
-    @property
-    def request_kwargs(self) -> Dict[str, Any]:
+    def _format_message(self, message: Message) -> Dict[str, Any]:
         """
-        Returns keyword arguments for API requests.
+        Format a message into the format expected by LiteLLM.
+
+        Args:
+            message (Message): The message to format.
 
         Returns:
-            Dict[str, Any]: The API kwargs for the model.
+            Dict[str, Any]: The formatted message.
         """
-        base_params: Dict[str, Any] = {}
+        _message: Dict[str, Any] = {
+            "role": message.role,
+            "content": message.content,
+        }
+
+        # Handle images if present
+        if message.role == "user" and message.images is not None and len(message.images) > 0:
+            content_parts = []
+
+            # Add text content if it exists
+            if message.content:
+                content_parts.append({"type": "text", "text": message.content})
+
+            # Add image content
+            for image in message.images:
+                if image.url is not None:
+                    content_parts.append(
+                        {"type": "image_url", "image_url": {"url": image.url}})
+                elif image.filepath is not None:
+                    content_parts.append({"type": "image_url", "image_url": {
+                        "url": f"file://{image.filepath}"}})
+                elif image.content is not None and isinstance(image.content, bytes):
+                    import base64
+                    base64_image = base64.b64encode(
+                        image.content).decode("utf-8")
+                    content_parts.append(
+                        {"type": "image_url", "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"}}
+                    )
+
+            # Replace content with content parts
+            if content_parts:
+                _message["content"] = content_parts
+
+        return _message
+
+    @property
+    def request_kwargs(self) -> Dict[str, Any]:
+        """Returns keyword arguments for API requests."""
+        base_params = {
+            "model": self.model_name if hasattr(self, 'model_name') else self.id,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+        }
 
+        # Use the correct model identifier based on provider
+        if self.id.startswith("huggingface/"):
+            # For Huggingface, we need to keep the "huggingface/" prefix
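
The new parse_tool_calls helper above normalizes provider tool-call objects into plain dicts. A minimal sketch of its behavior, with SimpleNamespace standing in for LiteLLM's response types and a hypothetical tool name:

from types import SimpleNamespace

from agno.models.litellm import LiteLLMSDK  # as imported in the cookbook example

tool_call = SimpleNamespace(
    id="call_abc123",
    function=SimpleNamespace(
        name="get_current_stock_price",   # hypothetical tool name
        arguments='{"symbol": "TSLA"}',
    ),
)

# Produces: [{"id": "call_abc123", "type": "function",
#             "function": {"name": "get_current_stock_price",
#                          "arguments": '{"symbol": "TSLA"}'}}]
parsed = LiteLLMSDK.parse_tool_calls([tool_call])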
@@ -95,73 +169,71 @@ def request_kwargs(self) -> Dict[str, Any]:
         else:
             base_params["top_p"] = self.top_p
 
-        # Add optional parameters if they are set
+        # Add optional parameters
         if self.max_tokens:
             base_params["max_tokens"] = self.max_tokens
 
         if self.api_key:
             base_params["api_key"] = self.api_key
 
         if self.api_base:
             base_params["api_base"] = self.api_base
 
-        # Create request_kwargs dict with non-None values
-        request_kwargs = {k: v for k, v in base_params.items() if v is not None}
-
-        # Add additional request params if provided
-        if self.request_params:
-            request_kwargs.update(self.request_params)
-
-        # Debug log (without sensitive info)
-        debug_kwargs = request_kwargs.copy()
-        if "api_key" in debug_kwargs:
-            debug_kwargs["api_key"] = "***REDACTED***"
-        logger.debug(f"LiteLLM request parameters: {debug_kwargs}")
-
-        return request_kwargs
+        # Add tools with proper formatting for OpenAI-style APIs
+        if self._tools is not None and len(self._tools) > 0:
+            tools_list = []
+            for tool in self._tools:
+                if isinstance(tool, dict):
+                    tool_dict = tool
+                else:
+                    # Assuming tool has a to_dict method that returns the function definition
+                    tool_dict = {
+                        "type": "function",
+                        "function": {
+                            "name": tool.name,  # Make sure the tool has a name attribute
+                            "description": tool.description,  # And a description
+                            "parameters": tool.parameters  # And parameters schema
+                        }
+                    }
+                tools_list.append(tool_dict)
+
+            base_params["tools"] = tools_list
+
+            # Set tool_choice
+            if hasattr(self, 'tool_choice') and self.tool_choice is not None:
+                base_params["tool_choice"] = self.tool_choice
+            else:
+                # Default to "auto" when tools are present
+                base_params["tool_choice"] = "auto"
+
+        # Add additional request params
+        if self.request_params:
+            base_params.update(self.request_params)
+
+        return base_params
 
-    def _format_message(self, message: Message) -> Dict[str, Any]:
-        """
-        Format a message into the format expected by LiteLLM.
-
-        Args:
-            message (Message): The message to format.
-
-        Returns:
-            Dict[str, Any]: The formatted message.
-        """
-        _message: Dict[str, Any] = {
+    def _format_message(self, message: Message) -> Dict[str, Any]:
+        """Format a message for the LiteLLM API."""
+        formatted = {
             "role": message.role,
-            "content": message.content,
+            "content": message.content
         }
 
-        # Handle images if present
-        if message.role == "user" and message.images is not None and len(message.images) > 0:
-            content_parts = []
-
-            # Add text content if it exists
-            if message.content:
-                content_parts.append({"type": "text", "text": message.content})
-
-            # Add image content
-            for image in message.images:
-                if image.url is not None:
-                    content_parts.append({"type": "image_url", "image_url": {"url": image.url}})
-                elif image.filepath is not None:
-                    content_parts.append({"type": "image_url", "image_url": {"url": f"file://{image.filepath}"}})
-                elif image.content is not None and isinstance(image.content, bytes):
-                    import base64
-
-                    base64_image = base64.b64encode(image.content).decode("utf-8")
-                    content_parts.append(
-                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
-                    )
-
-            # Replace content with content parts
-            if content_parts:
-                _message["content"] = content_parts
+        # Handle tool calls in assistant messages
+        if message.role == "assistant" and message.tool_calls:
+            formatted["tool_calls"] = [{
+                "id": tc.get("id", f"call_{i}"),
+                "type": "function",
+                "function": {
+                    "name": tc["function"]["name"],
+                    "arguments": tc["function"]["arguments"]
+                }
+            } for i, tc in enumerate(message.tool_calls)]
 
-        return _message
+        # Handle tool responses in tool messages
+        if message.role == "tool":
+            formatted["tool_call_id"] = message.tool_call_id
+            formatted["name"] = message.name
+
+        return formatted
 
     def invoke(self, messages: List[Message]) -> Mapping[str, Any]:
         """
@@ -256,50 +328,47 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
             yield chunk
 
     def parse_provider_response(self, response: Any) -> ModelResponse:
-        """
-        Parse the provider response.
-
-        Args:
-            response (Any): The response from the provider.
-
-        Returns:
-            ModelResponse: The model response.
-        """
+        """Parse the provider response."""
         model_response = ModelResponse()
 
         # Get response message
         response_message = response.choices[0].message
 
         # Set role if available
         if hasattr(response_message, "role"):
             model_response.role = response_message.role
 
         # Set content if available
         if hasattr(response_message, "content") and response_message.content is not None:
             model_response.content = response_message.content
 
-        # Handle tool calls if present
+        # Handle tool calls
         if hasattr(response_message, "tool_calls") and response_message.tool_calls:
-            if model_response.tool_calls is None:
-                model_response.tool_calls = []
-
-            for tool_call in response_message.tool_calls:
-                if tool_call.type == "function":
-                    function_def = {
-                        "name": tool_call.function.name,
-                        "arguments": tool_call.function.arguments,
-                    }
-                    model_response.tool_calls.append({"type": "function", "function": function_def})
+            model_response.tool_calls = self.parse_tool_calls(
+                response_message.tool_calls)
 
-        # Get response usage
+        # Handle usage stats
         if hasattr(response, "usage"):
             model_response.response_usage = {
                 "input_tokens": response.usage.prompt_tokens,
                 "output_tokens": response.usage.completion_tokens,
-                "total_tokens": response.usage.total_tokens,
+                "total_tokens": response.usage.total_tokens
             }
 
-        # Store the raw response
-        model_response.raw = response
+        # Parse structured outputs if enabled
+        try:
+            if (
+                self.response_format is not None
+                and self.structured_outputs
+                and issubclass(self.response_format, BaseModel)
+            ):
+                parsed_object = response_message.content
+                if parsed_object is not None:
+                    model_response.parsed = parsed_object
+        except Exception as e:
+            logger.warning(f"Error retrieving structured outputs: {e}")
 
+        model_response.raw = response
         return model_response
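
For reference, a rough sketch (untested, using SimpleNamespace in place of LiteLLM's response classes) of the minimal response shape parse_provider_response consumes:

from types import SimpleNamespace

mock_response = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(
        role="assistant",
        content="TSLA is trading at $242.84.",
        tool_calls=None,
    ))],
    usage=SimpleNamespace(prompt_tokens=42, completion_tokens=12, total_tokens=54),
)

# The parser reads choices[0].message for role/content/tool_calls, and
# usage for the prompt/completion/total token counts.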

@@ -338,7 +407,11 @@ def parse_provider_response_delta(self, response_delta: Any) -> ModelResponse:
                     function_def["arguments"] = tool_call.function.arguments
 
                 if function_def:
-                    model_response.tool_calls.append({"type": "function", "function": function_def})
+                    model_response.tool_calls.append({
+                        "id": tool_call.id if hasattr(tool_call, 'id') else None,
+                        "type": "function",
+                        "function": function_def
+                    })
 
         # Store the raw response
         model_response.raw = response_delta
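
The added "id" field matters for streaming, where OpenAI-style providers deliver tool-call arguments in fragments keyed by that id. A hedged sketch of where these deltas come from:

import litellm

# stream=True yields chunks whose choices[0].delta carries content and/or
# tool_calls fragments; callers accumulate them across chunks.
stream = litellm.completion(
    model="gpt-4o",   # any LiteLLM-routable model id
    messages=[{"role": "user", "content": "How is TSLA stock doing right now?"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if getattr(delta, "content", None):
        print(delta.content, end="")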