4
4
import mimetypes
5
5
from typing import Literal
6
6
7
+ import ollama as ollama_api
7
8
from pydantic import BaseModel , Field , HttpUrl
8
9
9
- import ollama as ollama_api
10
+ from genkit . ai . model import get_basic_usage_stats
10
11
from genkit .core .action import ActionRunContext
11
12
from genkit .core .typing import (
12
13
GenerateRequest ,
13
14
GenerateResponse ,
14
15
GenerateResponseChunk ,
15
16
GenerationCommonConfig ,
17
+ GenerationUsage ,
16
18
Media ,
17
19
MediaPart ,
18
20
Message ,
@@ -61,12 +63,12 @@ async def generate(
61
63
content = [TextPart (text = 'Failed to get response from Ollama API' )]
62
64
63
65
if self .model_definition .api_type == OllamaAPITypes .CHAT :
64
- chat_response = await self ._chat_with_ollama (
66
+ api_response = await self ._chat_with_ollama (
65
67
request = request , ctx = ctx
66
68
)
67
- if chat_response :
69
+ if api_response :
68
70
content = self ._build_multimodal_chat_response (
69
- chat_response = chat_response ,
71
+ chat_response = api_response ,
70
72
)
71
73
elif self .model_definition .api_type == OllamaAPITypes .GENERATE :
72
74
api_response = await self ._generate_ollama_response (
@@ -75,16 +77,32 @@ async def generate(
75
77
if api_response :
76
78
content = [TextPart (text = api_response .response )]
77
79
else :
78
- LOG .error (f'Unresolved API type: { self .model_definition .api_type } ' )
80
+ raise ValueError (
81
+ f'Unresolved API type: { self .model_definition .api_type } '
82
+ )
79
83
80
84
if self .is_streaming_request (ctx = ctx ):
81
85
content = []
82
86
87
+ response_message = Message (
88
+ role = Role .MODEL ,
89
+ content = content ,
90
+ )
91
+
92
+ basic_generation_usage = get_basic_usage_stats (
93
+ input_ = request .messages ,
94
+ response = response_message ,
95
+ )
96
+
83
97
return GenerateResponse (
84
98
message = Message (
85
99
role = Role .MODEL ,
86
100
content = content ,
87
- )
101
+ ),
102
+ usage = self .get_usage_info (
103
+ basic_generation_usage = basic_generation_usage ,
104
+ api_response = api_response ,
105
+ ),
88
106
)
89
107
90
108
async def _chat_with_ollama (
@@ -277,3 +295,19 @@ def _to_ollama_role(
277
295
@staticmethod
278
296
def is_streaming_request (ctx : ActionRunContext | None ) -> bool :
279
297
return ctx and ctx .is_streaming
298
+
299
+ @staticmethod
300
+ def get_usage_info (
301
+ basic_generation_usage : GenerationUsage ,
302
+ api_response : ollama_api .GenerateResponse | ollama_api .ChatResponse ,
303
+ ) -> GenerationUsage :
304
+ if api_response :
305
+ basic_generation_usage .input_tokens = (
306
+ api_response .prompt_eval_count or 0
307
+ )
308
+ basic_generation_usage .output_tokens = api_response .eval_count or 0
309
+ basic_generation_usage .total_tokens = (
310
+ basic_generation_usage .input_tokens
311
+ + basic_generation_usage .output_tokens
312
+ )
313
+ return basic_generation_usage
0 commit comments