1+ import json
12import os
23from datetime import datetime
34from unittest .mock import patch
1617from openai .types .chat .chat_completion_chunk import Choice as ChoiceChunk
1718from openai .types .chat .chat_completion_chunk import ChoiceDelta , ChoiceDeltaToolCall , ChoiceDeltaToolCallFunction
1819from openai .types .completion_usage import CompletionTokensDetails , CompletionUsage , PromptTokensDetails
20+ from pydantic import BaseModel
1921
2022from haystack_integrations .components .generators .togetherai .chat .chat_generator import TogetherAIChatGenerator
2123
2224
25+ class CalendarEvent (BaseModel ):
26+ event_name : str
27+ event_date : str
28+ event_location : str
29+
30+
31+ @pytest .fixture
32+ def calendar_event_model ():
33+ return CalendarEvent
34+
35+
2336class CollectorCallback :
2437 """
2538 Callback to collect streaming chunks for testing purposes.
@@ -161,14 +174,18 @@ def test_to_dict_default(self, monkeypatch):
161174 for key , value in expected_params .items ():
162175 assert data ["init_parameters" ][key ] == value
163176
164- def test_to_dict_with_parameters (self , monkeypatch ):
177+ def test_to_dict_with_parameters (self , monkeypatch , calendar_event_model ):
165178 monkeypatch .setenv ("ENV_VAR" , "test-api-key" )
166179 component = TogetherAIChatGenerator (
167180 api_key = Secret .from_env_var ("ENV_VAR" ),
168181 model = "openai/gpt-oss-20b" ,
169182 streaming_callback = print_streaming_chunk ,
170183 api_base_url = "test-base-url" ,
171- generation_kwargs = {"max_tokens" : 10 , "some_test_param" : "test-params" },
184+ generation_kwargs = {
185+ "max_tokens" : 10 ,
186+ "some_test_param" : "test-params" ,
187+ "response_format" : calendar_event_model ,
188+ },
172189 timeout = 10 ,
173190 max_retries = 10 ,
174191 tools = None ,
@@ -186,7 +203,28 @@ def test_to_dict_with_parameters(self, monkeypatch):
186203 "model" : "openai/gpt-oss-20b" ,
187204 "api_base_url" : "test-base-url" ,
188205 "streaming_callback" : "haystack.components.generators.utils.print_streaming_chunk" ,
189- "generation_kwargs" : {"max_tokens" : 10 , "some_test_param" : "test-params" },
206+ "generation_kwargs" : {
207+ "max_tokens" : 10 ,
208+ "some_test_param" : "test-params" ,
209+ "response_format" : {
210+ "type" : "json_schema" ,
211+ "json_schema" : {
212+ "name" : "CalendarEvent" ,
213+ "strict" : True ,
214+ "schema" : {
215+ "properties" : {
216+ "event_name" : {"title" : "Event Name" , "type" : "string" },
217+ "event_date" : {"title" : "Event Date" , "type" : "string" },
218+ "event_location" : {"title" : "Event Location" , "type" : "string" },
219+ },
220+ "required" : ["event_name" , "event_date" , "event_location" ],
221+ "title" : "CalendarEvent" ,
222+ "type" : "object" ,
223+ "additionalProperties" : False ,
224+ },
225+ },
226+ },
227+ },
190228 "timeout" : 10 ,
191229 "max_retries" : 10 ,
192230 "tools" : None ,
@@ -394,6 +432,59 @@ def test_live_run_with_tools_and_response(self, tools):
394432 assert "paris" in final_message .text .lower ()
395433 assert "berlin" in final_message .text .lower ()
396434
435+ @pytest .mark .skipif (
436+ not os .environ .get ("TOGETHER_API_KEY" , None ),
437+ reason = "Export an env var called TOGETHER_API_KEY containing the Together AI API key to run this test." ,
438+ )
439+ @pytest .mark .integration
440+ def test_live_run_with_response_format_json_schema (self ):
441+ response_schema = {
442+ "type" : "json_schema" ,
443+ "json_schema" : {
444+ "name" : "CapitalCity" ,
445+ "strict" : True ,
446+ "schema" : {
447+ "title" : "CapitalCity" ,
448+ "type" : "object" ,
449+ "properties" : {
450+ "city" : {"title" : "City" , "type" : "string" },
451+ "country" : {"title" : "Country" , "type" : "string" },
452+ },
453+ "required" : ["city" , "country" ],
454+ "additionalProperties" : False ,
455+ },
456+ },
457+ }
458+
459+ chat_messages = [ChatMessage .from_user ("What's the capital of France?" )]
460+ comp = TogetherAIChatGenerator (generation_kwargs = {"response_format" : response_schema })
461+ results = comp .run (chat_messages )
462+ assert len (results ["replies" ]) == 1
463+ message : ChatMessage = results ["replies" ][0 ]
464+ msg = json .loads (message .text )
465+ assert "Paris" in msg ["city" ]
466+ assert isinstance (msg ["country" ], str )
467+ assert "France" in msg ["country" ]
468+ assert message .meta ["finish_reason" ] == "stop"
469+
470+ @pytest .mark .skipif (
471+ not os .environ .get ("TOGETHER_API_KEY" , None ),
472+ reason = "Export an env var called TOGETHER_API_KEY containing the Together AI API key to run this test." ,
473+ )
474+ @pytest .mark .integration
475+ def test_live_run_with_response_format_pydantic_model (self , calendar_event_model ):
476+ chat_messages = [
477+ ChatMessage .from_user ("The marketing summit takes place on October12th at the Hilton Hotel downtown." )
478+ ]
479+ component = TogetherAIChatGenerator (generation_kwargs = {"response_format" : calendar_event_model })
480+ results = component .run (chat_messages )
481+ assert len (results ["replies" ]) == 1
482+ message : ChatMessage = results ["replies" ][0 ]
483+ msg = json .loads (message .text )
484+ assert "Marketing Summit" in msg ["event_name" ]
485+ assert isinstance (msg ["event_date" ], str )
486+ assert isinstance (msg ["event_location" ], str )
487+
397488 @pytest .mark .skipif (
398489 not os .environ .get ("TOGETHER_API_KEY" , None ),
399490 reason = "Export an env var called TOGETHER_API_KEY containing the Together AI API key to run this test." ,
0 commit comments