Skip to content

Commit 7e6a524

Browse files
authored
feat(py): Add collection of usage info for the Ollama API (#2361)
1 parent c91b361 commit 7e6a524

File tree

1 file changed

+40
-6
lines changed
  • py/plugins/ollama/src/genkit/plugins/ollama

1 file changed

+40
-6
lines changed

py/plugins/ollama/src/genkit/plugins/ollama/models.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44
import mimetypes
55
from typing import Literal
66

7+
import ollama as ollama_api
78
from pydantic import BaseModel, Field, HttpUrl
89

9-
import ollama as ollama_api
10+
from genkit.ai.model import get_basic_usage_stats
1011
from genkit.core.action import ActionRunContext
1112
from genkit.core.typing import (
1213
GenerateRequest,
1314
GenerateResponse,
1415
GenerateResponseChunk,
1516
GenerationCommonConfig,
17+
GenerationUsage,
1618
Media,
1719
MediaPart,
1820
Message,
@@ -61,12 +63,12 @@ async def generate(
6163
content = [TextPart(text='Failed to get response from Ollama API')]
6264

6365
if self.model_definition.api_type == OllamaAPITypes.CHAT:
64-
chat_response = await self._chat_with_ollama(
66+
api_response = await self._chat_with_ollama(
6567
request=request, ctx=ctx
6668
)
67-
if chat_response:
69+
if api_response:
6870
content = self._build_multimodal_chat_response(
69-
chat_response=chat_response,
71+
chat_response=api_response,
7072
)
7173
elif self.model_definition.api_type == OllamaAPITypes.GENERATE:
7274
api_response = await self._generate_ollama_response(
@@ -75,16 +77,32 @@ async def generate(
7577
if api_response:
7678
content = [TextPart(text=api_response.response)]
7779
else:
78-
LOG.error(f'Unresolved API type: {self.model_definition.api_type}')
80+
raise ValueError(
81+
f'Unresolved API type: {self.model_definition.api_type}'
82+
)
7983

8084
if self.is_streaming_request(ctx=ctx):
8185
content = []
8286

87+
response_message = Message(
88+
role=Role.MODEL,
89+
content=content,
90+
)
91+
92+
basic_generation_usage = get_basic_usage_stats(
93+
input_=request.messages,
94+
response=response_message,
95+
)
96+
8397
return GenerateResponse(
8498
message=Message(
8599
role=Role.MODEL,
86100
content=content,
87-
)
101+
),
102+
usage=self.get_usage_info(
103+
basic_generation_usage=basic_generation_usage,
104+
api_response=api_response,
105+
),
88106
)
89107

90108
async def _chat_with_ollama(
@@ -277,3 +295,19 @@ def _to_ollama_role(
277295
@staticmethod
def is_streaming_request(ctx: ActionRunContext | None) -> bool:
    """Return whether the current action invocation requests streaming.

    Args:
        ctx: The action run context for this invocation, or ``None`` when
            no context was supplied.

    Returns:
        ``True`` if a context is present and it is streaming, else ``False``.
    """
    # The original `ctx and ctx.is_streaming` returned `None` (not `False`)
    # when ctx is None, violating the declared `-> bool` return type.
    # Coerce explicitly so callers always receive a real boolean.
    return bool(ctx and ctx.is_streaming)
298+
@staticmethod
def get_usage_info(
    basic_generation_usage: GenerationUsage,
    api_response: ollama_api.GenerateResponse | ollama_api.ChatResponse,
) -> GenerationUsage:
    """Merge token counts reported by the Ollama API into the usage stats.

    Args:
        basic_generation_usage: Usage stats derived from the request and
            response messages; its token fields are filled in place.
        api_response: The raw Ollama API response carrying the
            ``prompt_eval_count`` / ``eval_count`` counters, if any.

    Returns:
        The same ``GenerationUsage`` object, with token counts populated
        when an API response is available.
    """
    if not api_response:
        # No API response to read counters from; leave the stats as-is.
        return basic_generation_usage

    # Ollama may omit either counter, so default missing values to zero.
    prompt_tokens = api_response.prompt_eval_count or 0
    completion_tokens = api_response.eval_count or 0

    basic_generation_usage.input_tokens = prompt_tokens
    basic_generation_usage.output_tokens = completion_tokens
    basic_generation_usage.total_tokens = prompt_tokens + completion_tokens
    return basic_generation_usage

0 commit comments

Comments
 (0)