diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py b/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py index 03c268af46..2bb930eb8b 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py @@ -133,6 +133,13 @@ def _convert_tools_to_ollama(tools: list[types.Tool]) -> list[ollama_sdk.Tool]: return ollama_tools +def _convert_tool_call_to_part(tc: OllamaMessage.ToolCall) -> types.Part: + part = types.Part.from_function_call(name=tc.function.name, args=dict(tc.function.arguments)) + if part.function_call: + part.function_call.id = str(uuid.uuid4()) + return part + + class KAgentOllamaLlm(KAgentTLSMixin, BaseLlm): """Ollama model via the native Ollama SDK. @@ -190,6 +197,7 @@ async def generate_content_async( try: if stream: aggregated_text = "" + tool_calls = [] response: AsyncIterator[ollama_sdk.ChatResponse] = await self._client.chat( model=llm_request.model or self.model, messages=messages, @@ -198,6 +206,7 @@ async def generate_content_async( stream=True, ) async for chunk in response: + tool_calls.extend(chunk.message.tool_calls or []) if chunk.message.content: aggregated_text += chunk.message.content yield LlmResponse( @@ -211,13 +220,7 @@ async def generate_content_async( final_parts = [] if aggregated_text: final_parts.append(types.Part.from_text(text=aggregated_text)) - for tc in chunk.message.tool_calls or []: - part = types.Part.from_function_call( - name=tc.function.name, args=dict(tc.function.arguments) - ) - if part.function_call: - part.function_call.id = str(uuid.uuid4()) - final_parts.append(part) + final_parts.extend(_convert_tool_call_to_part(tc) for tc in tool_calls) finish_reason = _done_reason_to_finish_reason(chunk.done_reason) if chunk.done_reason else None usage_metadata = None if chunk.prompt_eval_count is not None or chunk.eval_count is not None: @@ -245,10 +248,7 @@ async def generate_content_async( if response.message.content: parts.append(types.Part.from_text(text=response.message.content)) for tc in response.message.tool_calls or []: - part = types.Part.from_function_call(name=tc.function.name, args=dict(tc.function.arguments)) - if part.function_call: - part.function_call.id = str(uuid.uuid4()) - parts.append(part) + parts.append(_convert_tool_call_to_part(tc)) finish_reason = _done_reason_to_finish_reason(response.done_reason) if response.done_reason else None usage_metadata = None if response.prompt_eval_count is not None or response.eval_count is not None: diff --git a/python/packages/kagent-adk/tests/unittests/models/test_ollama.py b/python/packages/kagent-adk/tests/unittests/models/test_ollama.py index 9ae97f2a3b..e5fe86720a 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_ollama.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_ollama.py @@ -92,6 +92,51 @@ async def test_generate_content_forwards_ollama_options(self): assert mock_client.chat.call_args.kwargs["options"] == opts + @pytest.mark.asyncio + async def test_generate_content_streaming_accumulates_tool_calls_before_done_chunk(self): + llm = KAgentOllamaLlm(model="llama3.2:latest") + + tool_call = mock.MagicMock() + tool_call.function.name = "get_weather" + tool_call.function.arguments = {"city": "Tokyo"} + + tool_chunk = mock.MagicMock() + tool_chunk.message.content = "" + tool_chunk.message.tool_calls = [tool_call] + tool_chunk.done = False + + done_chunk = mock.MagicMock() + done_chunk.message.content = "" + done_chunk.message.tool_calls = None + done_chunk.done = True + done_chunk.done_reason = "stop" + done_chunk.prompt_eval_count = 10 + done_chunk.eval_count = 0 + + async def chunks(): + yield tool_chunk + yield done_chunk + + mock_client = mock.AsyncMock() + mock_client.chat = mock.AsyncMock(return_value=chunks()) + + request = mock.MagicMock() + request.model = "llama3.2:latest" + request.contents = [] + request.config = None + + with mock.patch.object(type(llm), "_client", new_callable=lambda: property(lambda self: mock_client)): + responses = [r async for r in llm.generate_content_async(request, stream=True)] + + assert len(responses) == 1 + final_response = responses[0] + assert final_response.partial is False + assert final_response.turn_complete is True + assert len(final_response.content.parts) == 1 + function_call = final_response.content.parts[0].function_call + assert function_call.name == "get_weather" + assert dict(function_call.args) == {"city": "Tokyo"} + class TestConvertContentToOllamaMessages: def test_image_inline_data_included(self):