 from __future__ import annotations

 import time
-from typing import Any, Callable, Dict, List, Optional, overload, Literal
+from typing import Any, Dict, List, Literal, overload

 import tinker
-from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.completion import Completion
 from openai import AsyncOpenAI
+from openai._streaming import AsyncStream
 from openai.resources.chat import AsyncChat as OpenAIAsyncChat
 from openai.resources.chat.completions import AsyncCompletions as OpenAIAsyncChatCompletions
 from openai.resources.completions import AsyncCompletions as OpenAIAsyncCompletions
-from openai._streaming import AsyncStream
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.completion import Completion

 from tinker_cookbook import renderers
 from tinker_cookbook.tokenizer_utils import Tokenizer


-GenerationHook = Callable[
-    [List[renderers.Message], tinker.ModelInput, List[int], List[float]], None
-]
-
-
-def convert_oai_messages_to_renderer_messages(
-    messages: List[Dict[str, Any]],
-) -> List[renderers.Message]:
-    out: List[renderers.Message] = []
-    for m in messages:
-        role = str(m.get("role", "user"))
-        content = m.get("content", "")
-        # extract text from list of content parts if necessary
-        if isinstance(content, list):
-            text_parts: List[str] = []
-            for part in content:
-                if isinstance(part, dict):
-                    if "text" in part:
-                        text_parts.append(str(part["text"]))
-                elif isinstance(part, str):
-                    text_parts.append(part)
-            content = "".join(text_parts)
-        else:
-            content = str(content)
-        out.append(renderers.Message(role=role, content=content))
-    return out
-
-
 class TinkerAsyncOpenAIClient(AsyncOpenAI):
     """
     OpenAI-compatible async client that routes calls to a Tinker SamplingClient.
@@ -69,10 +41,6 @@ def __init__(
         self.sampling_client = sampling_client
         self.renderer = renderer
         self.tokenizer = tokenizer
-        self.hook: Optional[GenerationHook] = None
-
-    def set_generation_hook(self, hook: Optional[GenerationHook]) -> None:
-        self.hook = hook

     def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None:
         self.sampling_client = sampling_client
@@ -106,16 +74,18 @@ async def create(self, *args: Any, stream: bool, **kwargs: Any) -> ChatCompletio
     async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:
         model = kwargs.get("model", "tinker")
         messages = kwargs.get("messages", [])
+        if kwargs.get("tools"):
+            raise NotImplementedError("Tool calling is not yet supported by this model's renderer.")
         if kwargs.get("stream", False):
             raise ValueError("stream=True not supported by TinkerAsyncOpenAIClient")
         sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools")}

-        # prepare prompt
-        conv_messages = convert_oai_messages_to_renderer_messages(messages)
         stop = sampling_args.get("stop", self._parent.renderer.get_stop_sequences())
         max_tokens = sampling_args.get("max_tokens") or sampling_args.get("max_completion_tokens")

-        model_input = self._parent.renderer.build_generation_prompt(conv_messages)
+        model_input = self._parent.renderer.build_generation_prompt(messages)
+        prompt_token_ids: List[int] = model_input.to_ints()
+
         sample = await self._parent.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
@@ -128,15 +98,12 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStrea
             ),
         )
         seq = sample.sequences[0]
-        tokens: List[int] = seq.tokens
-        logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
-
-        if self._parent.hook is not None:
-            self._parent.hook(conv_messages, model_input, tokens, logprobs)
+        completion_token_ids: List[int] = seq.tokens
+        logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)

-        # build ChatCompletion via pydantic validation using renderer parsing
-        assistant_message, parse_success = self._parent.renderer.parse_response(tokens)
-        content_text = assistant_message["content"]
+        assistant_message, parse_success = self._parent.renderer.parse_response(
+            completion_token_ids
+        )
         finish_reason = "stop" if parse_success else "length"
         response_dict: Dict[str, Any] = {
             "id": "tinker-chatcmpl",
@@ -146,23 +113,28 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStrea
146113 "choices" : [
147114 {
148115 "index" : 0 ,
149- "message" : { "role" : "assistant" , "content" : content_text } ,
116+ "message" : assistant_message ,
150117 "finish_reason" : finish_reason ,
151118 "logprobs" : {
152119 "content" : [
153- {"token" : f"token_id:{ tid } " , "logprob" : float ( lp ) , "top_logprobs" : []}
154- for tid , lp in zip (tokens , logprobs )
120+ {"token" : f"token_id:{ tid } " , "logprob" : lp , "top_logprobs" : []}
121+ for tid , lp in zip (completion_token_ids , logprobs )
155122 ]
156123 },
157124 }
158125 ],
159126 "usage" : {
160- "prompt_tokens" : model_input . length ,
161- "completion_tokens" : len (tokens ),
162- "total_tokens" : model_input . length + len (tokens ),
127+ "prompt_tokens" : len ( prompt_token_ids ) ,
128+ "completion_tokens" : len (completion_token_ids ),
129+ "total_tokens" : len ( prompt_token_ids ) + len (completion_token_ids ),
163130 },
164131 }
165- return ChatCompletion .model_validate (response_dict )
132+ response = ChatCompletion .model_validate (response_dict )
133+
134+ setattr (response , "prompt_token_ids" , prompt_token_ids )
135+ setattr (response .choices [0 ], "token_ids" , completion_token_ids )
136+
137+ return response
166138
167139
168140class TinkerCompletions (OpenAIAsyncCompletions ):
@@ -190,10 +162,9 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
         prompt = kwargs.get("prompt", "")
         sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")}

-        # Completion-mode: render prompt directly as text chunk
-        model_input = tinker.ModelInput.from_ints(
-            self._parent.tokenizer.encode(prompt, add_special_tokens=True)
-        )
+        prompt_token_ids: List[int] = self._parent.tokenizer.encode(prompt, add_special_tokens=True)
+        model_input = tinker.ModelInput.from_ints(prompt_token_ids)
+
         sample = await self._parent.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
@@ -205,11 +176,11 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
             ),
         )
         seq = sample.sequences[0]
-        tokens: List[int] = seq.tokens
-        logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
+        completion_token_ids: List[int] = seq.tokens
+        logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)

-        text = self._parent.tokenizer.decode(tokens)
-        tokens_str = [f"token_id:{tid}" for tid in tokens]
+        text = self._parent.tokenizer.decode(completion_token_ids)
+        tokens_str = [f"token_id:{tid}" for tid in completion_token_ids]
         response_dict: Dict[str, Any] = {
             "id": "tinker-cmpl",
             "object": "text_completion",
@@ -222,20 +193,24 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
222193 "finish_reason" : "stop" ,
223194 "logprobs" : {
224195 "tokens" : tokens_str ,
225- "token_logprobs" : [ float ( lp ) for lp in logprobs ] ,
196+ "token_logprobs" : logprobs ,
226197 },
227198 }
228199 ],
229200 "usage" : {
230- "prompt_tokens" : model_input . length ,
231- "completion_tokens" : len (tokens ),
232- "total_tokens" : model_input . length + len (tokens ),
201+ "prompt_tokens" : len ( prompt_token_ids ) ,
202+ "completion_tokens" : len (completion_token_ids ),
203+ "total_tokens" : len ( prompt_token_ids ) + len (completion_token_ids ),
233204 },
234205 }
235- final = Completion .model_validate (response_dict )
206+ response = Completion .model_validate (response_dict )
207+
208+ setattr (response .choices [0 ], "prompt_token_ids" , prompt_token_ids )
209+ setattr (response .choices [0 ], "token_ids" , completion_token_ids )
210+
236211 if stream :
237- return TinkerAsyncCompletionStream (final )
238- return final
212+ return TinkerAsyncCompletionStream (response )
213+ return response
239214
240215
241216class TinkerAsyncChat (OpenAIAsyncChat ):
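
Usage sketch (not part of the diff): how a caller might exercise the patched client once a tinker.SamplingClient, a renderer, and a tokenizer have been constructed elsewhere. The constructor keyword names, the api_key placeholder, and the wiring of client.chat.completions to the Tinker-backed completions class are assumptions inferred from the hunks above, not confirmed against the full file.

# Hypothetical usage sketch; constructor keywords and chat routing are
# assumptions based on the diff above, not a confirmed API.
import asyncio


async def demo(sampling_client, renderer, tokenizer) -> None:
    client = TinkerAsyncOpenAIClient(
        api_key="unused",  # assumption: the AsyncOpenAI base class still requires a key
        sampling_client=sampling_client,
        renderer=renderer,
        tokenizer=tokenizer,
    )
    response = await client.chat.completions.create(
        model="tinker",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=64,
    )
    # The patch attaches raw token ids alongside the standard OpenAI fields.
    print(response.choices[0].message.content)
    print(getattr(response, "prompt_token_ids", None))
    print(getattr(response.choices[0], "token_ids", None))


# asyncio.run(demo(sampling_client, renderer, tokenizer))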