
Commit d177465

verifiers_rl recipe updates for verifiers v0.1.8 release (#133)
1 parent 461bf9b commit d177465

4 files changed, +198 -170 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ trackio = [
     "trackio<1.0.0",
 ]
 verifiers = [
-    "verifiers",
+    "verifiers>=0.1.8.post0",
     "openai",
 ]
 all = [
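
A quick way to confirm that an installed environment actually satisfies the new floor on the verifiers dependency (a minimal sketch, not part of the commit; it assumes the distribution is named "verifiers" as in the extra above and that the packaging library is available):

import importlib.metadata

from packaging.version import Version

# Hypothetical check against the pin introduced above (verifiers>=0.1.8.post0).
installed = Version(importlib.metadata.version("verifiers"))
assert installed >= Version("0.1.8.post0"), f"verifiers {installed} is older than 0.1.8.post0"
print(f"verifiers {installed} satisfies the pin")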

tinker_cookbook/recipes/verifiers_rl/tinker_openai.py

Lines changed: 44 additions & 69 deletions
@@ -11,49 +11,21 @@
 from __future__ import annotations
 
 import time
-from typing import Any, Callable, Dict, List, Optional, overload, Literal
+from typing import Any, Dict, List, Literal, overload
 
 import tinker
-from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.completion import Completion
 from openai import AsyncOpenAI
+from openai._streaming import AsyncStream
 from openai.resources.chat import AsyncChat as OpenAIAsyncChat
 from openai.resources.chat.completions import AsyncCompletions as OpenAIAsyncChatCompletions
 from openai.resources.completions import AsyncCompletions as OpenAIAsyncCompletions
-from openai._streaming import AsyncStream
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.completion import Completion
 
 from tinker_cookbook import renderers
 from tinker_cookbook.tokenizer_utils import Tokenizer
 
 
-GenerationHook = Callable[
-    [List[renderers.Message], tinker.ModelInput, List[int], List[float]], None
-]
-
-
-def convert_oai_messages_to_renderer_messages(
-    messages: List[Dict[str, Any]],
-) -> List[renderers.Message]:
-    out: List[renderers.Message] = []
-    for m in messages:
-        role = str(m.get("role", "user"))
-        content = m.get("content", "")
-        # extract text from list of content parts if necessary
-        if isinstance(content, list):
-            text_parts: List[str] = []
-            for part in content:
-                if isinstance(part, dict):
-                    if "text" in part:
-                        text_parts.append(str(part["text"]))
-                elif isinstance(part, str):
-                    text_parts.append(part)
-            content = "".join(text_parts)
-        else:
-            content = str(content)
-        out.append(renderers.Message(role=role, content=content))
-    return out
-
-
 class TinkerAsyncOpenAIClient(AsyncOpenAI):
     """
     OpenAI-compatible async client that routes calls to a Tinker SamplingClient.
@@ -69,10 +41,6 @@ def __init__(
         self.sampling_client = sampling_client
         self.renderer = renderer
         self.tokenizer = tokenizer
-        self.hook: Optional[GenerationHook] = None
-
-    def set_generation_hook(self, hook: Optional[GenerationHook]) -> None:
-        self.hook = hook
 
     def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None:
         self.sampling_client = sampling_client
@@ -106,16 +74,18 @@ async def create(self, *args: Any, stream: bool, **kwargs: Any) -> ChatCompletion
     async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:
         model = kwargs.get("model", "tinker")
         messages = kwargs.get("messages", [])
+        if kwargs.get("tools"):
+            raise NotImplementedError("Tool calling is not yet supported by this model's renderer.")
         if kwargs.get("stream", False):
             raise ValueError("stream=True not supported by TinkerAsyncOpenAIClient")
         sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools")}
 
-        # prepare prompt
-        conv_messages = convert_oai_messages_to_renderer_messages(messages)
         stop = sampling_args.get("stop", self._parent.renderer.get_stop_sequences())
         max_tokens = sampling_args.get("max_tokens") or sampling_args.get("max_completion_tokens")
 
-        model_input = self._parent.renderer.build_generation_prompt(conv_messages)
+        model_input = self._parent.renderer.build_generation_prompt(messages)
+        prompt_token_ids: List[int] = model_input.to_ints()
+
         sample = await self._parent.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
@@ -128,15 +98,12 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:
             ),
         )
         seq = sample.sequences[0]
-        tokens: List[int] = seq.tokens
-        logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
-
-        if self._parent.hook is not None:
-            self._parent.hook(conv_messages, model_input, tokens, logprobs)
+        completion_token_ids: List[int] = seq.tokens
+        logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)
 
-        # build ChatCompletion via pydantic validation using renderer parsing
-        assistant_message, parse_success = self._parent.renderer.parse_response(tokens)
-        content_text = assistant_message["content"]
+        assistant_message, parse_success = self._parent.renderer.parse_response(
+            completion_token_ids
+        )
         finish_reason = "stop" if parse_success else "length"
         response_dict: Dict[str, Any] = {
             "id": "tinker-chatcmpl",
@@ -146,23 +113,28 @@ async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:
             "choices": [
                 {
                     "index": 0,
-                    "message": {"role": "assistant", "content": content_text},
+                    "message": assistant_message,
                     "finish_reason": finish_reason,
                     "logprobs": {
                         "content": [
-                            {"token": f"token_id:{tid}", "logprob": float(lp), "top_logprobs": []}
-                            for tid, lp in zip(tokens, logprobs)
+                            {"token": f"token_id:{tid}", "logprob": lp, "top_logprobs": []}
+                            for tid, lp in zip(completion_token_ids, logprobs)
                         ]
                     },
                 }
             ],
             "usage": {
-                "prompt_tokens": model_input.length,
-                "completion_tokens": len(tokens),
-                "total_tokens": model_input.length + len(tokens),
+                "prompt_tokens": len(prompt_token_ids),
+                "completion_tokens": len(completion_token_ids),
+                "total_tokens": len(prompt_token_ids) + len(completion_token_ids),
             },
         }
-        return ChatCompletion.model_validate(response_dict)
+        response = ChatCompletion.model_validate(response_dict)
+
+        setattr(response, "prompt_token_ids", prompt_token_ids)
+        setattr(response.choices[0], "token_ids", completion_token_ids)
+
+        return response
 
 
 class TinkerCompletions(OpenAIAsyncCompletions):
@@ -190,10 +162,9 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
         prompt = kwargs.get("prompt", "")
         sampling_args = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")}
 
-        # Completion-mode: render prompt directly as text chunk
-        model_input = tinker.ModelInput.from_ints(
-            self._parent.tokenizer.encode(prompt, add_special_tokens=True)
-        )
+        prompt_token_ids: List[int] = self._parent.tokenizer.encode(prompt, add_special_tokens=True)
+        model_input = tinker.ModelInput.from_ints(prompt_token_ids)
+
         sample = await self._parent.sampling_client.sample_async(
             prompt=model_input,
             num_samples=1,
@@ -205,11 +176,11 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
             ),
         )
         seq = sample.sequences[0]
-        tokens: List[int] = seq.tokens
-        logprobs: List[float] = seq.logprobs or [0.0] * len(tokens)
+        completion_token_ids: List[int] = seq.tokens
+        logprobs: List[float] = seq.logprobs or [0.0] * len(completion_token_ids)
 
-        text = self._parent.tokenizer.decode(tokens)
-        tokens_str = [f"token_id:{tid}" for tid in tokens]
+        text = self._parent.tokenizer.decode(completion_token_ids)
+        tokens_str = [f"token_id:{tid}" for tid in completion_token_ids]
         response_dict: Dict[str, Any] = {
             "id": "tinker-cmpl",
             "object": "text_completion",
@@ -222,20 +193,24 @@ async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Co
                     "finish_reason": "stop",
                     "logprobs": {
                         "tokens": tokens_str,
-                        "token_logprobs": [float(lp) for lp in logprobs],
+                        "token_logprobs": logprobs,
                     },
                 }
             ],
             "usage": {
-                "prompt_tokens": model_input.length,
-                "completion_tokens": len(tokens),
-                "total_tokens": model_input.length + len(tokens),
+                "prompt_tokens": len(prompt_token_ids),
+                "completion_tokens": len(completion_token_ids),
+                "total_tokens": len(prompt_token_ids) + len(completion_token_ids),
             },
         }
-        final = Completion.model_validate(response_dict)
+        response = Completion.model_validate(response_dict)
+
+        setattr(response.choices[0], "prompt_token_ids", prompt_token_ids)
+        setattr(response.choices[0], "token_ids", completion_token_ids)
+
         if stream:
-            return TinkerAsyncCompletionStream(final)
-        return final
+            return TinkerAsyncCompletionStream(response)
+        return response
 
 
 class TinkerAsyncChat(OpenAIAsyncChat):
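
With these changes the OpenAI-compatible client reports token-level information directly on the response objects: the chat response carries prompt_token_ids and each choice carries token_ids, which downstream verifiers RL code can read instead of the generation hook this commit removes. The sketch below is illustrative only; it assumes sampling_client, renderer, and tokenizer have already been built with the usual tinker / tinker_cookbook utilities and passes them to the constructor by the keyword names the diff assigns in __init__.

import asyncio

from tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient

async def main(sampling_client, renderer, tokenizer) -> None:
    # Route OpenAI-style chat calls to the Tinker sampling client.
    client = TinkerAsyncOpenAIClient(
        sampling_client=sampling_client,
        renderer=renderer,
        tokenizer=tokenizer,
    )

    response = await client.chat.completions.create(
        model="tinker",
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=32,
    )

    # Attributes attached by this commit for RL consumers:
    prompt_ids = response.prompt_token_ids          # token ids of the rendered prompt
    completion_ids = response.choices[0].token_ids  # token ids of the sampled completion
    print(response.choices[0].message.content, len(prompt_ids), len(completion_ids))

# asyncio.run(main(sampling_client, renderer, tokenizer)) once those objects exist.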
