Skip to content

Commit 57935db

Browse files
authored
feat: support structured outputs in TogetherAIChatGenerator (#2534)
* Support for structured outputs * Add an example * Fix linting
1 parent 20a1f47 commit 57935db

File tree

3 files changed

+154
-4
lines changed

3 files changed

+154
-4
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
# This example demonstrates how to use the TogetherAIChatGenerator component
7+
# with structured outputs.
8+
# To run this example, you will need to
9+
# set `TOGETHER_API_KEY` environment variable
10+
11+
from haystack.dataclasses import ChatMessage
12+
from pydantic import BaseModel
13+
14+
from haystack_integrations.components.generators.togetherai import TogetherAIChatGenerator
15+
16+
17+
class NobelPrizeInfo(BaseModel):
18+
recipient_name: str
19+
award_year: int
20+
category: str
21+
achievement_description: str
22+
nationality: str
23+
24+
25+
chat_messages = [
26+
ChatMessage.from_user(
27+
"In 2021, American scientist David Julius received the Nobel Prize in"
28+
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
29+
" senses temperature and touch."
30+
)
31+
]
32+
component = TogetherAIChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
33+
results = component.run(chat_messages)
34+
35+
# print(results)

integrations/togetherai/src/haystack_integrations/components/generators/togetherai/chat/chat_generator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from haystack.tools import ToolsType
1111
from haystack.utils import serialize_callable
1212
from haystack.utils.auth import Secret
13+
from openai.lib._pydantic import to_strict_json_schema
14+
from pydantic import BaseModel
1315

1416
logger = logging.getLogger(__name__)
1517

@@ -98,6 +100,13 @@ def __init__(
98100
events as they become available, with the stream terminated by a data: [DONE] message.
99101
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
100102
- `random_seed`: The seed to use for random sampling.
103+
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
104+
If provided, the output will always be validated against this
105+
format (unless the model returns a tool call).
106+
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
107+
Notes:
108+
- For structured outputs with streaming,
109+
the `response_format` must be a JSON schema and not a Pydantic model.
101110
:param tools:
102111
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
103112
Each tool should have a unique name.
@@ -131,6 +140,21 @@ def to_dict(self) -> dict[str, Any]:
131140
The serialized component as a dictionary.
132141
"""
133142
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
143+
generation_kwargs = self.generation_kwargs.copy()
144+
response_format = generation_kwargs.get("response_format")
145+
# If the response format is a Pydantic model, it's converted to openai's json schema format
146+
# If it's already a json schema, it's left as is
147+
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
148+
json_schema = {
149+
"type": "json_schema",
150+
"json_schema": {
151+
"name": response_format.__name__,
152+
"strict": True,
153+
"schema": to_strict_json_schema(response_format),
154+
},
155+
}
156+
157+
generation_kwargs["response_format"] = json_schema
134158

135159
# if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
136160
# which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in
@@ -141,7 +165,7 @@ def to_dict(self) -> dict[str, Any]:
141165
model=self.model,
142166
streaming_callback=callback_name,
143167
api_base_url=self.api_base_url,
144-
generation_kwargs=self.generation_kwargs,
168+
generation_kwargs=generation_kwargs,
145169
api_key=self.api_key.to_dict(),
146170
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
147171
timeout=self.timeout,

integrations/togetherai/tests/test_togetherai_chat_generator.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
from datetime import datetime
34
from unittest.mock import patch
@@ -16,10 +17,22 @@
1617
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
1718
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
1819
from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
20+
from pydantic import BaseModel
1921

2022
from haystack_integrations.components.generators.togetherai.chat.chat_generator import TogetherAIChatGenerator
2123

2224

25+
class CalendarEvent(BaseModel):
26+
event_name: str
27+
event_date: str
28+
event_location: str
29+
30+
31+
@pytest.fixture
32+
def calendar_event_model():
33+
return CalendarEvent
34+
35+
2336
class CollectorCallback:
2437
"""
2538
Callback to collect streaming chunks for testing purposes.
@@ -161,14 +174,18 @@ def test_to_dict_default(self, monkeypatch):
161174
for key, value in expected_params.items():
162175
assert data["init_parameters"][key] == value
163176

164-
def test_to_dict_with_parameters(self, monkeypatch):
177+
def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
165178
monkeypatch.setenv("ENV_VAR", "test-api-key")
166179
component = TogetherAIChatGenerator(
167180
api_key=Secret.from_env_var("ENV_VAR"),
168181
model="openai/gpt-oss-20b",
169182
streaming_callback=print_streaming_chunk,
170183
api_base_url="test-base-url",
171-
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
184+
generation_kwargs={
185+
"max_tokens": 10,
186+
"some_test_param": "test-params",
187+
"response_format": calendar_event_model,
188+
},
172189
timeout=10,
173190
max_retries=10,
174191
tools=None,
@@ -186,7 +203,28 @@ def test_to_dict_with_parameters(self, monkeypatch):
186203
"model": "openai/gpt-oss-20b",
187204
"api_base_url": "test-base-url",
188205
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
189-
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
206+
"generation_kwargs": {
207+
"max_tokens": 10,
208+
"some_test_param": "test-params",
209+
"response_format": {
210+
"type": "json_schema",
211+
"json_schema": {
212+
"name": "CalendarEvent",
213+
"strict": True,
214+
"schema": {
215+
"properties": {
216+
"event_name": {"title": "Event Name", "type": "string"},
217+
"event_date": {"title": "Event Date", "type": "string"},
218+
"event_location": {"title": "Event Location", "type": "string"},
219+
},
220+
"required": ["event_name", "event_date", "event_location"],
221+
"title": "CalendarEvent",
222+
"type": "object",
223+
"additionalProperties": False,
224+
},
225+
},
226+
},
227+
},
190228
"timeout": 10,
191229
"max_retries": 10,
192230
"tools": None,
@@ -394,6 +432,59 @@ def test_live_run_with_tools_and_response(self, tools):
394432
assert "paris" in final_message.text.lower()
395433
assert "berlin" in final_message.text.lower()
396434

435+
@pytest.mark.skipif(
    not os.environ.get("TOGETHER_API_KEY", None),
    reason="Export an env var called TOGETHER_API_KEY containing the Together AI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_response_format_json_schema(self):
    # Structured-output live test: pass a raw JSON schema dict as
    # `response_format` and verify the reply is JSON conforming to it.
    capital_city_schema = {
        "title": "CapitalCity",
        "type": "object",
        "properties": {
            "city": {"title": "City", "type": "string"},
            "country": {"title": "Country", "type": "string"},
        },
        "required": ["city", "country"],
        "additionalProperties": False,
    }
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "CapitalCity",
            "strict": True,
            "schema": capital_city_schema,
        },
    }

    generator = TogetherAIChatGenerator(generation_kwargs={"response_format": response_format})
    run_result = generator.run([ChatMessage.from_user("What's the capital of France?")])

    replies = run_result["replies"]
    assert len(replies) == 1
    reply: ChatMessage = replies[0]
    parsed = json.loads(reply.text)
    assert "Paris" in parsed["city"]
    assert isinstance(parsed["country"], str)
    assert "France" in parsed["country"]
    assert reply.meta["finish_reason"] == "stop"
470+
@pytest.mark.skipif(
    not os.environ.get("TOGETHER_API_KEY", None),
    reason="Export an env var called TOGETHER_API_KEY containing the Together AI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_response_format_pydantic_model(self, calendar_event_model):
    # Structured-output live test: pass a Pydantic model class as
    # `response_format` and verify the reply parses into its fields.
    chat_messages = [
        # Fixed prompt typo: "October12th" -> "October 12th".
        ChatMessage.from_user("The marketing summit takes place on October 12th at the Hilton Hotel downtown.")
    ]
    component = TogetherAIChatGenerator(generation_kwargs={"response_format": calendar_event_model})
    results = component.run(chat_messages)
    assert len(results["replies"]) == 1
    message: ChatMessage = results["replies"][0]
    msg = json.loads(message.text)
    assert "Marketing Summit" in msg["event_name"]
    assert isinstance(msg["event_date"], str)
    assert isinstance(msg["event_location"], str)
487+
397488
@pytest.mark.skipif(
398489
not os.environ.get("TOGETHER_API_KEY", None),
399490
reason="Export an env var called TOGETHER_API_KEY containing the Together AI API key to run this test.",

0 commit comments

Comments
 (0)