From ee6b835de01b0f6cca23260e7f8191dd551537a2 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Fri, 27 Dec 2024 13:44:33 +1000 Subject: [PATCH] add * before keyword args --- .../src/autogen_core/models/_model_client.py | 6 ++++-- .../packages/autogen-core/tests/test_tool_agent.py | 6 ++++-- .../autogen_ext/models/openai/_openai_client.py | 14 ++++++++------ .../replay/_replay_chat_completion_client.py | 6 ++++-- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py index 0cba5b07904e..7a0762a46e91 100644 --- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py +++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py @@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader): async def create( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], # None means do not override the default # A value means to override the client default - often specified in the constructor @@ -41,6 +42,7 @@ async def create( def create_stream( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], # None means do not override the default # A value means to override the client default - often specified in the constructor @@ -56,10 +58,10 @@ def actual_usage(self) -> RequestUsage: ... def total_usage(self) -> RequestUsage: ... @abstractmethod - def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ... + def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ... @abstractmethod - def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ... + def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ... @property @abstractmethod diff --git a/python/packages/autogen-core/tests/test_tool_agent.py b/python/packages/autogen-core/tests/test_tool_agent.py index d0d6dec8915b..c93815d5dfd9 100644 --- a/python/packages/autogen-core/tests/test_tool_agent.py +++ b/python/packages/autogen-core/tests/test_tool_agent.py @@ -92,6 +92,7 @@ class MockChatCompletionClient(ChatCompletionClient): async def create( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, @@ -116,6 +117,7 @@ async def create( def create_stream( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, @@ -129,10 +131,10 @@ def actual_usage(self) -> RequestUsage: def total_usage(self) -> RequestUsage: return RequestUsage(prompt_tokens=0, completion_tokens=0) - def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: return 0 - def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: return 0 @property diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index a70ee087aa95..2128378d98e4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -355,6 +355,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): def __init__( self, client: Union[AsyncOpenAI, AsyncAzureOpenAI], + *, create_args: Dict[str, Any], model_capabilities: Optional[ModelCapabilities] = None, ): @@ -389,6 +390,7 @@ def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient: async def create( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, @@ -581,11 +583,11 @@ async def create( async def create_stream( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, cancellation_token: Optional[CancellationToken] = None, - *, max_consecutive_empty_chunk_tolerance: int = 0, ) -> AsyncGenerator[Union[str, CreateResult], None]: """ @@ -800,7 +802,7 @@ def actual_usage(self) -> RequestUsage: def total_usage(self) -> RequestUsage: return self._total_usage - def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: model = self._create_args["model"] try: encoding = tiktoken.encoding_for_model(model) @@ -889,9 +891,9 @@ def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | To num_tokens += 12 return num_tokens - def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: token_limit = _model_info.get_token_limit(self._create_args["model"]) - return token_limit - self.count_tokens(messages, tools) + return token_limit - self.count_tokens(messages, tools=tools) @property def capabilities(self) -> ModelCapabilities: @@ -974,7 +976,7 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]): client = _openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) self._raw_config: Dict[str, Any] = copied_args - super().__init__(client, create_args, model_capabilities) + super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities) def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -1059,7 +1061,7 @@ def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]): client = _azure_openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) self._raw_config: Dict[str, Any] = copied_args - super().__init__(client, create_args, model_capabilities) + super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities) def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() diff --git a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py index 5b7ac7095c2b..1167be59b8ad 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py @@ -128,6 +128,7 @@ def __init__( async def create( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, @@ -155,6 +156,7 @@ async def create( async def create_stream( self, messages: Sequence[LLMMessage], + *, tools: Sequence[Tool | ToolSchema] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, @@ -191,11 +193,11 @@ def actual_usage(self) -> RequestUsage: def total_usage(self) -> RequestUsage: return self._total_usage - def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: _, token_count = self._tokenize(messages) return token_count - def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: return max( 0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens )