11import uuid
22from concurrent .futures import ThreadPoolExecutor
33
4- import pytest
5- from agents import ModelResponse , Usage
6- from openai .types .responses import ResponseOutputMessage , ResponseOutputText
74from temporalio .client import Client
85from temporalio .worker import Worker
6+ from temporalio .contrib .openai_agents .testing import (
7+ AgentEnvironment ,
8+ ResponseBuilders ,
9+ TestModel ,
10+ )
911
1012from openai_agents .basic .workflows .hello_world_workflow import HelloWorldAgent
11- from tests .openai_agents .conftest import sequential_test_model
1213
1314
14- @pytest .fixture
15- def test_model ():
16- return sequential_test_model (
17- [
18- ModelResponse (
19- output = [
20- ResponseOutputMessage (
21- id = "1" ,
22- content = [
23- ResponseOutputText (
24- annotations = [],
25- text = "This is a haiku (not really)" ,
26- type = "output_text" ,
27- )
28- ],
29- role = "assistant" ,
30- status = "completed" ,
31- type = "message" ,
32- )
33- ],
34- usage = Usage (
35- requests = 1 , input_tokens = 1 , output_tokens = 1 , total_tokens = 1
36- ),
37- response_id = "1" ,
38- )
39- ]
40- )
41-
42-
43- @pytest .fixture
44- def test_model ():
15+ def haiku_test_model ():
4516 return TestModel .returning_responses (
4617 [ResponseBuilders .output_message ("This is a haiku (not really)" )]
4718 )
@@ -50,17 +21,19 @@ def test_model():
5021async def test_execute_workflow (client : Client ):
5122 task_queue_name = str (uuid .uuid4 ())
5223
53- async with Worker (
54- client ,
55- task_queue = task_queue_name ,
56- workflows = [HelloWorldAgent ],
57- activity_executor = ThreadPoolExecutor (5 ),
58- ):
59- result = await client .execute_workflow (
60- HelloWorldAgent .run ,
61- "Write a recursive haiku about recursive haikus." ,
62- id = str (uuid .uuid4 ()),
24+ async with AgentEnvironment (model = haiku_test_model ()) as agent_env :
25+ client = agent_env .applied_on_client (client )
26+ async with Worker (
27+ client ,
6328 task_queue = task_queue_name ,
64- )
65- assert isinstance (result , str )
66- assert len (result ) > 0
29+ workflows = [HelloWorldAgent ],
30+ activity_executor = ThreadPoolExecutor (5 ),
31+ ):
32+ result = await client .execute_workflow (
33+ HelloWorldAgent .run ,
34+ "Write a recursive haiku about recursive haikus." ,
35+ id = str (uuid .uuid4 ()),
36+ task_queue = task_queue_name ,
37+ )
38+ assert isinstance (result , str )
39+ assert len (result ) > 0
0 commit comments