From 6263510c6986126751952f41c7baacb70b04eefd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 21 Jul 2025 21:44:24 +0000 Subject: [PATCH 01/11] Add optional `id` field to toolsets --- docs/tools.md | 3 +- docs/toolsets.md | 15 +- pydantic_ai_slim/pydantic_ai/_output.py | 4 + pydantic_ai_slim/pydantic_ai/ag_ui.py | 3 +- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- pydantic_ai_slim/pydantic_ai/ext/aci.py | 6 +- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 4 +- pydantic_ai_slim/pydantic_ai/mcp.py | 154 ++++++++++++++++-- .../pydantic_ai/toolsets/abstract.py | 18 +- .../pydantic_ai/toolsets/combined.py | 10 +- .../pydantic_ai/toolsets/deferred.py | 13 +- .../pydantic_ai/toolsets/function.py | 14 +- .../pydantic_ai/toolsets/prefixed.py | 4 + .../pydantic_ai/toolsets/wrapper.py | 8 + tests/test_examples.py | 4 + tests/test_mcp.py | 2 +- tests/test_tools.py | 11 +- 17 files changed, 243 insertions(+), 32 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index 4b40e7881..eddffb703 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -770,7 +770,7 @@ from pydantic_ai.ext.langchain import LangChainToolset toolkit = SlackToolkit() -toolset = LangChainToolset(toolkit.get_tools()) +toolset = LangChainToolset(toolkit.get_tools(), id='slack') agent = Agent('openai:gpt-4o', toolsets=[toolset]) # ... @@ -823,6 +823,7 @@ toolset = ACIToolset( 'OPEN_WEATHER_MAP__FORECAST', ], linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), + id='open_weather_map', ) agent = Agent('openai:gpt-4o', toolsets=[toolset]) diff --git a/docs/toolsets.md b/docs/toolsets.md index 5caac22c0..add7009f9 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -84,7 +84,10 @@ def temperature_fahrenheit(city: str) -> float: return 69.8 -weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit]) +weather_toolset = FunctionToolset( + tools=[temperature_celsius, temperature_fahrenheit], + id='weather', # (1)! +) @weather_toolset.tool @@ -95,10 +98,10 @@ def conditions(ctx: RunContext, city: str) -> str: return "It's raining" -datetime_toolset = FunctionToolset() +datetime_toolset = FunctionToolset(id='datetime') datetime_toolset.add_function(lambda: datetime.now(), name='now') -test_model = TestModel() # (1)! +test_model = TestModel() # (2)! agent = Agent(test_model) result = agent.run_sync('What tools are available?', toolsets=[weather_toolset]) @@ -110,7 +113,8 @@ print([t.name for t in test_model.last_model_request_parameters.function_tools]) #> ['now'] ``` -1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +1. `FunctionToolset` supports an optional `id` argument that can help to identify the toolset in error messages. A toolset also needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. +2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. _(This example is complete, it can be run "as is")_ @@ -609,7 +613,7 @@ from pydantic_ai.ext.langchain import LangChainToolset toolkit = SlackToolkit() -toolset = LangChainToolset(toolkit.get_tools()) +toolset = LangChainToolset(toolkit.get_tools(), id='slack') agent = Agent('openai:gpt-4o', toolsets=[toolset]) # ... @@ -634,6 +638,7 @@ toolset = ACIToolset( 'OPEN_WEATHER_MAP__FORECAST', ], linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), + id='open_weather_map', ) agent = Agent('openai:gpt-4o', toolsets=[toolset]) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 67849b0c7..71e7251fd 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -977,6 +977,10 @@ def __init__( self.max_retries = max_retries self.output_validators = output_validators or [] + @property + def id(self) -> str | None: + return 'output' + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { tool_def.name: ToolsetTool( diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index a43d8bda4..13201abb0 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -273,7 +273,8 @@ async def run( parameters_json_schema=tool.parameters, ) for tool in run_input.tools - ] + ], + id='ag_ui_frontend', ) toolsets = [*toolsets, toolset] if toolsets else [toolset] diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9f6348c6c..c3deb779c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -420,7 +420,7 @@ def __init__( if self._output_toolset: self._output_toolset.max_retries = self._max_result_retries - self._function_toolset = FunctionToolset(tools, max_retries=retries) + self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent') self._user_toolsets = toolsets or () self.history_processors = history_processors or [] diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index 6cd43402a..ef686d134 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -71,5 +71,7 @@ def implementation(*args: Any, **kwargs: Any) -> str: class ACIToolset(FunctionToolset): """A toolset that wraps ACI.dev tools.""" - def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str): - super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions]) + def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, id: str | None = None): + super().__init__( + [tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions], id=id + ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 3fb407938..3782c0b9d 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -65,5 +65,5 @@ def proxy(*args: Any, **kwargs: Any) -> str: class LangChainToolset(FunctionToolset): """A toolset that wraps LangChain tools.""" - def __init__(self, tools: list[LangChainTool]): - super().__init__([tool_from_langchain(tool) for tool in tools]) + def __init__(self, tools: list[LangChainTool], id: str | None = None): + super().__init__([tool_from_langchain(tool) for tool in tools], id=id) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 2ca7950b3..efdf2eb40 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -61,10 +61,12 @@ class MCPServer(AbstractToolset[Any], ABC): timeout: float = 5 process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True - max_retries: int = 1 sampling_model: models.Model | None = None + max_retries: int = 1 # } end of "abstract fields" + _id: str | None = field(init=False, default=None) + _enter_lock: Lock = field(compare=False) _running_count: int _exit_stack: AsyncExitStack | None @@ -73,7 +75,29 @@ class MCPServer(AbstractToolset[Any], ABC): _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] - def __post_init__(self): + def __init__( + self, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + id: str | None = None, + ): + self.tool_prefix = tool_prefix + self.log_level = log_level + self.log_handler = log_handler + self.timeout = timeout + self.process_tool_call = process_tool_call + self.allow_sampling = allow_sampling + self.sampling_model = sampling_model + self.max_retries = max_retries + + self._id = id or tool_prefix + self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None @@ -93,7 +117,11 @@ async def client_streams( yield @property - def name(self) -> str: + def id(self) -> str | None: + return self._id + + @property + def label(self) -> str: return repr(self) @property @@ -294,7 +322,7 @@ def _map_tool_result_part( assert_never(part) -@dataclass +@dataclass(init=False) class MCPServerStdio(MCPServer): """Runs an MCP server in a subprocess and communicates with it over stdin/stdout. @@ -378,11 +406,61 @@ async def main(): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + max_retries: int = 1 """The maximum number of times to retry a tool call.""" - sampling_model: models.Model | None = None - """The model to use for sampling.""" + def __init__( + self, + command: str, + args: Sequence[str], + env: dict[str, str] | None = None, + cwd: str | Path | None = None, + id: str | None = None, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + ): + """Build a new MCP server. + + Args: + command: The command to run. + args: The arguments to pass to the command. + env: The environment variables to set in the subprocess. + cwd: The working directory to use when spawning the process. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. + tool_prefix: A prefix to add to all tools that are registered with the server. + log_level: The log level to set when connecting to the server, if any. + log_handler: A handler for logging messages from the server. + timeout: The timeout in seconds to wait for the client to initialize. + process_tool_call: Hook to customize tool calling and optionally pass extra metadata. + allow_sampling: Whether to allow MCP sampling through this client. + sampling_model: The model to use for sampling. + max_retries: The maximum number of times to retry a tool call. + """ + self.command = command + self.args = args + self.env = env + self.cwd = cwd + + super().__init__( + tool_prefix, + log_level, + log_handler, + timeout, + process_tool_call, + allow_sampling, + sampling_model, + max_retries, + id, + ) @asynccontextmanager async def client_streams( @@ -398,7 +476,10 @@ async def client_streams( yield read_stream, write_stream def __repr__(self) -> str: - return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})' + if self.id: + return f'{self.__class__.__name__} {self.id!r}' + else: + return f'{self.__class__.__name__}(command={self.command!r}, args={self.args!r})' @dataclass @@ -479,11 +560,61 @@ class _MCPServerHTTP(MCPServer): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + max_retries: int = 1 """The maximum number of times to retry a tool call.""" - sampling_model: models.Model | None = None - """The model to use for sampling.""" + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + http_client: httpx.AsyncClient | None = None, + sse_read_timeout: float = 5 * 60, + id: str | None = None, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + ): + """Build a new MCP server. + + Args: + url: The URL of the endpoint on the MCP server. + headers: Optional HTTP headers to be sent with each request to the endpoint. + http_client: An `httpx.AsyncClient` to use with the endpoint. + sse_read_timeout: Maximum time in seconds to wait for new SSE messages before timing out. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. + tool_prefix: A prefix to add to all tools that are registered with the server. + log_level: The log level to set when connecting to the server, if any. + log_handler: A handler for logging messages from the server. + timeout: The timeout in seconds to wait for the client to initialize. + process_tool_call: Hook to customize tool calling and optionally pass extra metadata. + allow_sampling: Whether to allow MCP sampling through this client. + sampling_model: The model to use for sampling. + max_retries: The maximum number of times to retry a tool call. + """ + self.url = url + self.headers = headers + self.http_client = http_client + self.sse_read_timeout = sse_read_timeout + + super().__init__( + tool_prefix, + log_level, + log_handler, + timeout, + process_tool_call, + allow_sampling, + sampling_model, + max_retries, + id, + ) @property @abstractmethod @@ -546,7 +677,10 @@ def httpx_client_factory( yield read_stream, write_stream def __repr__(self) -> str: # pragma: no cover - return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})' + if self.id: + return f'{self.__class__.__name__} {self.id!r}' + else: + return f'{self.__class__.__name__}(url={self.url!r})' @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py index 455336418..d73119e58 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -70,9 +70,23 @@ class AbstractToolset(ABC, Generic[AgentDepsT]): """ @property - def name(self) -> str: + @abstractmethod + def id(self) -> str | None: + """An ID for the toolset that is unique among all toolsets registered with the same agent. + + If you're implementing a concrete implementation that users can instantiate more than once, you should let them optionally pass a custom ID to the constructor and return that here. + + A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. + """ + raise NotImplementedError() + + @property + def label(self) -> str: """The name of the toolset for use in error messages.""" - return self.__class__.__name__.replace('Toolset', ' toolset') + label = self.__class__.__name__ + if self.id: + label += f' {self.id!r}' + return label @property def tool_name_conflict_hint(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 4b1511fae..750c54b8e 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -40,6 +40,14 @@ def __post_init__(self): self._entered_count = 0 self._exit_stack = None + @property + def id(self) -> str | None: + return None + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})' + async def __aenter__(self) -> Self: async with self._enter_lock: if self._entered_count == 0: @@ -64,7 +72,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ for name, tool in tools.items(): if existing_tools := all_tools.get(name): raise UserError( - f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + f'{toolset.label} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}' ) all_tools[name] = _CombinedToolsetTool( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index 3ad2e976b..a67c3b0ad 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import Any from pydantic_core import SchemaValidator, core_schema @@ -12,7 +12,7 @@ TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema()) -@dataclass +@dataclass(init=False) class DeferredToolset(AbstractToolset[AgentDepsT]): """A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called. @@ -20,6 +20,15 @@ class DeferredToolset(AbstractToolset[AgentDepsT]): """ tool_defs: list[ToolDefinition] + _id: str | None = field(init=False, default=None) + + def __init__(self, tool_defs: list[ToolDefinition], id: str | None = None): + self._id = id + self.tool_defs = tool_defs + + @property + def id(self) -> str | None: + return self._id async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 63f44a1f0..81c667a9e 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -35,14 +35,22 @@ class FunctionToolset(AbstractToolset[AgentDepsT]): max_retries: int = field(default=1) tools: dict[str, Tool[Any]] = field(default_factory=dict) + _id: str | None = field(init=False, default=None) - def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + def __init__( + self, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], + max_retries: int = 1, + id: str | None = None, + ): """Build a new function toolset. Args: tools: The tools to add to the toolset. max_retries: The maximum number of retries for each tool during a run. + id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. """ + self._id = id self.max_retries = max_retries self.tools = {} for tool in tools: @@ -51,6 +59,10 @@ def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, else: self.add_function(tool) + @property + def id(self) -> str | None: + return self._id + @overload def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py index be70ed4f0..a430b0dab 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -17,6 +17,10 @@ class PrefixedToolset(WrapperToolset[AgentDepsT]): prefix: str + @property + def tool_name_conflict_hint(self) -> str: + return 'Change the `prefix` attribute to avoid name conflicts.' + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { new_name: replace( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 8440f1c46..6d5a409a5 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -18,6 +18,14 @@ class WrapperToolset(AbstractToolset[AgentDepsT]): wrapped: AbstractToolset[AgentDepsT] + @property + def id(self) -> str | None: + return None + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({self.wrapped.label})' + async def __aenter__(self) -> Self: await self.wrapped.__aenter__() return self diff --git a/tests/test_examples.py b/tests/test_examples.py index 4b6bc27bc..380c56d66 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -263,6 +263,10 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): + @property + def id(self) -> str | None: + return None + async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 94528b40a..b5ad7e1ae 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -237,7 +237,7 @@ def get_none() -> None: # pragma: no cover with pytest.raises( UserError, match=re.escape( - "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server']) defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'get_none'. Set the `tool_prefix` attribute to avoid name conflicts." ), ): await agent.run('Get me a conflict') diff --git a/tests/test_tools.py b/tests/test_tools.py index c72cd1e08..e5dac6d3f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,4 +1,5 @@ import json +import re from dataclasses import dataclass, replace from typing import Annotated, Any, Callable, Literal, Union @@ -586,7 +587,9 @@ def test_tool_return_conflict(): # this raises an error with pytest.raises( UserError, - match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + match=re.escape( + "FunctionToolset 'agent' defines a tool whose name conflicts with existing tool from OutputToolset 'output': 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts." + ), ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')).run_sync( '', deps=0 @@ -596,7 +599,9 @@ def test_tool_return_conflict(): def test_tool_name_conflict_hint(): with pytest.raises( UserError, - match="Prefixed toolset defines a tool whose name conflicts with existing tool from Function toolset: 'foo_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + match=re.escape( + "PrefixedToolset(FunctionToolset 'tool') defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'foo_tool'. Change the `prefix` attribute to avoid name conflicts." + ), ): def tool(x: int) -> int: @@ -605,7 +610,7 @@ def tool(x: int) -> int: def foo_tool(x: str) -> str: return x + 'foo' # pragma: no cover - function_toolset = FunctionToolset([tool]) + function_toolset = FunctionToolset([tool], id='tool') prefixed_toolset = PrefixedToolset(function_toolset, 'foo') Agent('test', tools=[foo_tool], toolsets=[prefixed_toolset]).run_sync('') From d02361a3a453dc784f61ae5680b22b006d828590 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 21 Jul 2025 22:05:36 +0000 Subject: [PATCH 02/11] WIP: temporalize_agent --- pydantic_ai_slim/pydantic_ai/agent.py | 8 ++ pydantic_ai_slim/pydantic_ai/mcp.py | 2 +- .../pydantic_ai/temporal/__init__.py | 104 +++++++++++++++ .../pydantic_ai/temporal/agent.py | 66 ++++++++++ .../pydantic_ai/temporal/function_toolset.py | 100 +++++++++++++++ .../pydantic_ai/temporal/mcp_server.py | 121 ++++++++++++++++++ .../pydantic_ai/temporal/model.py | 116 +++++++++++++++++ pydantic_ai_slim/pyproject.toml | 2 + pyproject.toml | 2 +- temporal.py | 99 ++++++++++++++ uv.lock | 38 +++++- 11 files changed, 653 insertions(+), 5 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/agent.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/model.py create mode 100644 temporal.py diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c3deb779c..ad3012d60 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1698,6 +1698,14 @@ def _get_toolset( return CombinedToolset(all_toolsets) + @property + def toolset(self) -> AbstractToolset[AgentDepsT]: + """The complete toolset that will be available to the model during an agent run. + + This will include function tools registered directly to the agent, output tools, and user-provided toolsets including MCP servers. + """ + return self._get_toolset() + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index efdf2eb40..2bbcb7e75 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -126,7 +126,7 @@ def label(self) -> str: @property def tool_name_conflict_hint(self) -> str: - return 'Consider setting `tool_prefix` to avoid name conflicts.' + return 'Set the `tool_prefix` attribute to avoid name conflicts.' async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py new file mode 100644 index 000000000..21f0ca90d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Callable + +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from pydantic_ai._run_context import AgentDepsT, RunContext + + +class _TemporalRunContext(RunContext[AgentDepsT]): + _data: dict[str, Any] + + def __init__(self, **kwargs: Any): + self._data = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + data = super().__getattribute__('_data') + if name in data: + return data[name] + raise e # TODO: Explain how to make a new run context attribute available + + @classmethod + def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: + return { + 'deps': ctx.deps, + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + } + + @classmethod + def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: + return cls(**ctx) + + +@dataclass +class TemporalSettings: + """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" + + # Temporal settings + task_queue: str | None = None + schedule_to_close_timeout: timedelta | None = None + schedule_to_start_timeout: timedelta | None = None + start_to_close_timeout: timedelta | None = None + heartbeat_timeout: timedelta | None = None + retry_policy: RetryPolicy | None = None + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL + activity_id: str | None = None + versioning_intent: VersioningIntent | None = None + summary: str | None = None + priority: Priority = Priority.default + + # Pydantic AI specific + tool_settings: dict[str, dict[str, TemporalSettings]] | None = None + + def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: + if self.tool_settings is None: + return self + return self.tool_settings.get(toolset_id, {}).get(tool_id, self) + + serialize_run_context: Callable[[RunContext], Any] = _TemporalRunContext.serialize_run_context + deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context + + @property + def execute_activity_kwargs(self) -> dict[str, Any]: + return { + 'task_queue': self.task_queue, + 'schedule_to_close_timeout': self.schedule_to_close_timeout, + 'schedule_to_start_timeout': self.schedule_to_start_timeout, + 'start_to_close_timeout': self.start_to_close_timeout, + 'heartbeat_timeout': self.heartbeat_timeout, + 'retry_policy': self.retry_policy, + 'cancellation_type': self.cancellation_type, + 'activity_id': self.activity_id, + 'versioning_intent': self.versioning_intent, + 'summary': self.summary, + 'priority': self.priority, + } + + +def initialize_temporal(): + """Explicitly import types without which Temporal will not be able to serialize/deserialize `ModelMessage`s.""" + from pydantic_ai.messages import ( # noqa F401 + ModelResponse, # pyright: ignore[reportUnusedImport] + ImageUrl, # pyright: ignore[reportUnusedImport] + AudioUrl, # pyright: ignore[reportUnusedImport] + DocumentUrl, # pyright: ignore[reportUnusedImport] + VideoUrl, # pyright: ignore[reportUnusedImport] + BinaryContent, # pyright: ignore[reportUnusedImport] + UserContent, # pyright: ignore[reportUnusedImport] + ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/agent.py b/pydantic_ai_slim/pydantic_ai/temporal/agent.py new file mode 100644 index 000000000..b8a0fc538 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/agent.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai.agent import Agent +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from ..models import Model +from . import TemporalSettings +from .function_toolset import temporalize_function_toolset +from .mcp_server import temporalize_mcp_server +from .model import temporalize_model + + +def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + settings: The temporal settings to use. + """ + if isinstance(toolset, FunctionToolset): + return temporalize_function_toolset(toolset, settings) + elif isinstance(toolset, MCPServer): + return temporalize_mcp_server(toolset, settings) + else: + return [] + + +def temporalize_agent( + agent: Agent, + settings: TemporalSettings | None = None, + temporalize_toolset_func: Callable[ + [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] + ] = temporalize_toolset, +) -> list[Callable[..., Any]]: + """Temporalize an agent. + + Args: + agent: The agent to temporalize. + settings: The temporal settings to use. + temporalize_toolset_func: The function to use to temporalize the toolsets. + """ + if existing_activities := getattr(agent, '__temporal_activities', None): + return existing_activities + + settings = settings or TemporalSettings() + + # TODO: Doesn't consider model/toolsets passed at iter time. + + activities: list[Callable[..., Any]] = [] + if isinstance(agent.model, Model): + activities.extend(temporalize_model(agent.model, settings)) + + def temporalize_toolset(toolset: AbstractToolset) -> None: + activities.extend(temporalize_toolset_func(toolset, settings)) + + agent.toolset.apply(temporalize_toolset) + + setattr(agent, '__temporal_activities', activities) + return activities + + +# TODO: untemporalize_agent diff --git a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py new file mode 100644 index 000000000..842e86aaf --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow + +from pydantic_ai.toolsets.function import FunctionToolset + +from .._run_context import RunContext +from ..toolsets import ToolsetTool +from . import TemporalSettings + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _CallToolParams: + name: str + tool_args: dict[str, Any] + serialized_run_context: Any + + +def temporalize_function_toolset( + toolset: FunctionToolset, + settings: TemporalSettings | None = None, +) -> list[Callable[..., Any]]: + """Temporalize a function toolset. + + Args: + toolset: The function toolset to temporalize. + settings: The temporal settings to use. + """ + if activities := getattr(toolset, '__temporal_activities', None): + return activities + + id = toolset.id + if not id: + raise ValueError( + "A function toolset needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the toolset's activities within the workflow." + ) + + settings = settings or TemporalSettings() + + original_call_tool = toolset.call_tool + + @activity.defn(name=f'function_toolset__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> Any: + name = params.name + ctx = settings.for_tool(id, name).deserialize_run_context(params.serialized_run_context) + tool = (await toolset.get_tools(ctx))[name] + return await original_call_tool(name, params.tool_args, ctx, tool) + + async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + tool_settings = settings.for_tool(id, name) + serialized_run_context = tool_settings.serialize_run_context(ctx) + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=call_tool_activity, + arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), + **tool_settings.execute_activity_kwargs, + ) + + toolset.call_tool = call_tool + + activities = [call_tool_activity] + setattr(toolset, '__temporal_activities', activities) + return activities + + +# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]): +# def __init__( +# self, +# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], +# max_retries: int = 1, +# temporal_settings: TemporalSettings | None = None, +# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None, +# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None, +# ): +# super().__init__(tools, max_retries) +# self.temporal_settings = temporal_settings or TemporalSettings() +# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context +# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context + +# @activity.defn(name='function_toolset_call_tool') +# async def call_tool_activity(params: FunctionCallToolParams) -> Any: +# ctx = self.deserialize_run_context(params.serialized_run_context) +# tool = (await self.get_tools(ctx))[params.name] +# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool) + +# self.call_tool_activity = call_tool_activity + +# async def call_tool( +# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] +# ) -> Any: +# serialized_run_context = self.serialize_run_context(ctx) +# return await workflow.execute_activity( +# activity=self.call_tool_activity, +# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), +# **self.temporal_settings.__dict__, +# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py new file mode 100644 index 000000000..8d3be7624 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from mcp import types as mcp_types +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow + +from pydantic_ai.mcp import MCPServer, ToolResult + +from . import TemporalSettings + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _CallToolParams: + name: str + tool_args: dict[str, Any] + metadata: dict[str, Any] | None = None + + +def temporalize_mcp_server( + server: MCPServer, + settings: TemporalSettings | None = None, +) -> list[Callable[..., Any]]: + """Temporalize an MCP server. + + Args: + server: The MCP server to temporalize. + settings: The temporal settings to use. + """ + if activities := getattr(server, '__temporal_activities', None): + return activities + + id = server.id + if not id: + raise ValueError( + "An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the server's activities within the workflow." + ) + + settings = settings or TemporalSettings() + + original_list_tools = server.list_tools + original_direct_call_tool = server.direct_call_tool + + @activity.defn(name=f'mcp_server__{id}__list_tools') + async def list_tools_activity() -> list[mcp_types.Tool]: + return await original_list_tools() + + @activity.defn(name=f'mcp_server__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> ToolResult: + return await original_direct_call_tool(params.name, params.tool_args, params.metadata) + + async def list_tools() -> list[mcp_types.Tool]: + return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] + activity=list_tools_activity, + **settings.execute_activity_kwargs, + ) + + async def direct_call_tool( + name: str, + args: dict[str, Any], + metadata: dict[str, Any] | None = None, + ) -> ToolResult: + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=call_tool_activity, + arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), + **settings.for_tool(id, name).execute_activity_kwargs, + ) + + server.list_tools = list_tools + server.direct_call_tool = direct_call_tool + + activities = [list_tools_activity, call_tool_activity] + setattr(server, '__temporal_activities', activities) + return activities + + +# class TemporalMCPServer(WrapperToolset[Any]): +# temporal_settings: TemporalSettings + +# @property +# def wrapped_server(self) -> MCPServer: +# assert isinstance(self.wrapped, MCPServer) +# return self.wrapped + +# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None): +# assert isinstance(self.wrapped, MCPServer) +# super().__init__(wrapped) +# self.temporal_settings = temporal_settings or TemporalSettings() + +# @activity.defn(name='mcp_server_list_tools') +# async def list_tools_activity() -> list[mcp_types.Tool]: +# return await self.wrapped_server.list_tools() + +# self.list_tools_activity = list_tools_activity + +# @activity.defn(name='mcp_server_call_tool') +# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult: +# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata) + +# self.call_tool_activity = call_tool_activity + +# async def list_tools(self) -> list[mcp_types.Tool]: +# return await workflow.execute_activity( +# activity=self.list_tools_activity, +# **self.temporal_settings.__dict__, +# ) + +# async def direct_call_tool( +# self, +# name: str, +# args: dict[str, Any], +# metadata: dict[str, Any] | None = None, +# ) -> ToolResult: +# return await workflow.execute_activity( +# activity=self.call_tool_activity, +# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata), +# **self.temporal_settings.__dict__, +# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/model.py b/pydantic_ai_slim/pydantic_ai/temporal/model.py new file mode 100644 index 000000000..7f8a25e71 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/model.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, Callable + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow + +from ..messages import ( + ModelMessage, + ModelResponse, +) +from ..models import Model, ModelRequestParameters, StreamedResponse +from ..settings import ModelSettings +from . import TemporalSettings + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _RequestParams: + messages: list[ModelMessage] + model_settings: ModelSettings | None + model_request_parameters: ModelRequestParameters + + +def temporalize_model(model: Model, settings: TemporalSettings | None = None) -> list[Callable[..., Any]]: + """Temporalize a model. + + Args: + model: The model to temporalize. + settings: The temporal settings to use. + """ + if activities := getattr(model, '__temporal_activities', None): + return activities + + settings = settings or TemporalSettings() + + original_request = model.request + + @activity.defn(name='model_request') + async def request_activity(params: _RequestParams) -> ModelResponse: + return await original_request(params.messages, params.model_settings, params.model_request_parameters) + + async def request( + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=request_activity, + arg=_RequestParams( + messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters + ), + **settings.execute_activity_kwargs, + ) + + @asynccontextmanager + async def request_stream( + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[StreamedResponse]: + raise NotImplementedError('Cannot stream with temporal yet') + yield + + model.request = request + model.request_stream = request_stream + + activities = [request_activity] + setattr(model, '__temporal_activities', activities) + return activities + + +# @dataclass +# class TemporalModel(WrapperModel): +# temporal_settings: TemporalSettings + +# def __init__( +# self, +# wrapped: Model | KnownModelName, +# temporal_settings: TemporalSettings | None = None, +# ) -> None: +# super().__init__(wrapped) +# self.temporal_settings = temporal_settings or TemporalSettings() + +# @activity.defn +# async def request_activity(params: ModelRequestParams) -> ModelResponse: +# return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) + +# self.request_activity = request_activity + +# async def request( +# self, +# messages: list[ModelMessage], +# model_settings: ModelSettings | None, +# model_request_parameters: ModelRequestParameters, +# ) -> ModelResponse: +# return await workflow.execute_activity( +# activity=self.request_activity, +# arg=ModelRequestParams( +# messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters +# ), +# **self.temporal_settings.__dict__, +# ) + +# @asynccontextmanager +# async def request_stream( +# self, +# messages: list[ModelMessage], +# model_settings: ModelSettings | None, +# model_request_parameters: ModelRequestParameters, +# ) -> AsyncIterator[StreamedResponse]: +# raise NotImplementedError('Cannot stream with temporal yet') +# yield diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 3dfc4a766..10d36d6b7 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -84,6 +84,8 @@ evals = ["pydantic-evals=={{ version }}"] a2a = ["fasta2a>=0.4.1"] # AG-UI ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] +# Temporal +temporal = ["temporalio>=1.13.0"] [dependency-groups] dev = [ diff --git a/pyproject.toml b/pyproject.toml index 841f186ef..fdd67f183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui,temporal]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/temporal.py b/temporal.py new file mode 100644 index 000000000..340019b8c --- /dev/null +++ b/temporal.py @@ -0,0 +1,99 @@ +# /// script +# dependencies = [ +# "temporalio", +# "logfire", +# ] +# /// +import asyncio +import random +from datetime import timedelta + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.worker import Worker + +with workflow.unsafe.imports_passed_through(): + from pydantic_ai import Agent + from pydantic_ai.mcp import MCPServerStdio + from pydantic_ai.models.openai import OpenAIModel + from pydantic_ai.temporal import ( + TemporalSettings, + initialize_temporal, + ) + from pydantic_ai.temporal.agent import temporalize_agent + from pydantic_ai.toolsets.function import FunctionToolset + + initialize_temporal() + + def get_uv_index(location: str) -> int: + return 3 + + toolset = FunctionToolset(tools=[get_uv_index], id='uv_index') + mcp_server = MCPServerStdio( + 'python', + ['-m', 'tests.mcp_server'], + timeout=20, + id='test', + ) + + model = OpenAIModel('gpt-4o') + my_agent = Agent(model=model, instructions='be helpful', toolsets=[toolset, mcp_server]) + + temporal_settings = TemporalSettings( + start_to_close_timeout=timedelta(seconds=60), + tool_settings={ + 'uv_index': { + 'get_uv_index': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + }, + }, + ) + activities = temporalize_agent(my_agent, temporal_settings) + + +def init_runtime_with_telemetry() -> Runtime: + # import logfire + + # logfire.configure(send_to_logfire=True, service_version='0.0.1', console=False) + # logfire.instrument_pydantic_ai() + # logfire.instrument_httpx(capture_all=True) + + # Setup SDK metrics to OTel endpoint + return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + + +# Basic workflow that logs and invokes an activity +@workflow.defn +class MyAgentWorkflow: + @workflow.run + async def run(self, prompt: str) -> str: + return (await my_agent.run(prompt)).output + + +async def main(): + client = await Client.connect( + 'localhost:7233', + interceptors=[TracingInterceptor()], + data_converter=pydantic_data_converter, + runtime=init_runtime_with_telemetry(), + ) + + async with Worker( + client, + task_queue='my-agent-task-queue', + workflows=[MyAgentWorkflow], + activities=activities, + ): + output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] + MyAgentWorkflow.run, + 'what is 2 plus the UV Index in Mexico City? and what is the product name?', + id=f'my-agent-workflow-id-{random.random()}', + task_queue='my-agent-task-queue', + ) + print(output) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/uv.lock b/uv.lock index a47bd1445..a44ee9b56 100644 --- a/uv.lock +++ b/uv.lock @@ -3000,7 +3000,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "temporal", "vertexai"] }, ] [package.optional-dependencies] @@ -3039,7 +3039,7 @@ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "temporal", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["a2a", "examples", "logfire"] @@ -3163,6 +3163,9 @@ openai = [ tavily = [ { name = "tavily-python" }, ] +temporal = [ + { name = "temporalio" }, +] vertexai = [ { name = "google-auth" }, { name = "requests" }, @@ -3219,9 +3222,10 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, + { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.13.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "temporal", "vertexai"] [package.metadata.requires-dev] dev = [ @@ -3931,6 +3935,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/cd/71088461d7720128c78802289b3b36298f42745e5f8c334b0ffc157b881e/tavily_python-0.5.1-py3-none-any.whl", hash = "sha256:169601f703c55cf338758dcacfa7102473b479a9271d65a3af6fc3668990f757", size = 43767, upload-time = "2025-02-07T00:22:04.99Z" }, ] +[[package]] +name = "temporalio" +version = "1.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, + { name = "python-dateutil", marker = "python_full_version < '3.11'" }, + { name = "types-protobuf" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/a3/a76477b523937f47a21941188c16b3c6b1eef6baadc7c8efeea497d909de/temporalio-1.13.0.tar.gz", hash = "sha256:5a979eee5433da6ab5d8a2bcde25a1e7d454e91920acb0bf7ca93d415750828b", size = 1558745, upload-time = "2025-06-20T19:57:26.944Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/f4/a5a74284c671bd50ce7353ad1dad7dab1a795f891458454049e95bc5378f/temporalio-1.13.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:7ee14cab581352e77171d1e4ce01a899231abfe75c5f7233e3e260f361a344cc", size = 12086961, upload-time = "2025-06-20T19:57:15.25Z" }, + { url = "https://files.pythonhosted.org/packages/1f/b7/5dc6e34f4e9a3da8b75cb3fe0d32edca1d9201d598c38d022501d38650a9/temporalio-1.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:575a0c57dbb089298b4775f3aca86ebaf8d58d5ba155e7fc5509877c25e6bb44", size = 11745239, upload-time = "2025-06-20T19:57:17.934Z" }, + { url = "https://files.pythonhosted.org/packages/04/30/4b9b15af87c181fd9364b61971faa0faa07d199320d7ff1712b5d51b5bbb/temporalio-1.13.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf099a27f22c0dbc22f3d86dba76d59be5da812ff044ba3fa183e3e14bd5e9a", size = 12119197, upload-time = "2025-06-20T19:57:20.509Z" }, + { url = "https://files.pythonhosted.org/packages/46/9f/a5b627d773974c654b6cd22ed3937e7e2471023af244ea417f0e917e617b/temporalio-1.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7e20c711f41c66877b9d54ab33c79a14ccaac9ed498a174274f6129110f4d84", size = 12413459, upload-time = "2025-06-20T19:57:22.816Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/efb6957212eb8c8dfff26c7c2c6ddf745aa5990a3b722cff17c8feaa66fc/temporalio-1.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:9286cb84c1e078b2bcc6e8c6bd0be878d8ed395be991ac0d7cff555e3a82ac0b", size = 12440644, upload-time = "2025-06-20T19:57:25.175Z" }, +] + [[package]] name = "tenacity" version = "8.5.0" @@ -4121,6 +4144,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/63/2463d89481e811f007b0e1cd0a91e52e141b47f9de724d20db7b861dcfec/types_certifi-2021.10.8.3-py3-none-any.whl", hash = "sha256:b2d1e325e69f71f7c78e5943d410e650b4707bb0ef32e4ddf3da37f54176e88a", size = 2136, upload-time = "2022-06-09T15:19:03.127Z" }, ] +[[package]] +name = "types-protobuf" +version = "6.30.2.20250516" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/6c/5cf088aaa3927d1cc39910f60f220f5ff573ab1a6485b2836e8b26beb58c/types_protobuf-6.30.2.20250516.tar.gz", hash = "sha256:aecd1881770a9bb225ede66872ef7f0da4505edd0b193108edd9892e48d49a41", size = 62254, upload-time = "2025-05-16T03:06:50.794Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/66/06a9c161f5dd5deb4f5c016ba29106a8f1903eb9a1ba77d407dd6588fecb/types_protobuf-6.30.2.20250516-py3-none-any.whl", hash = "sha256:8c226d05b5e8b2623111765fa32d6e648bbc24832b4c2fddf0fa340ba5d5b722", size = 76480, upload-time = "2025-05-16T03:06:49.444Z" }, +] + [[package]] name = "types-requests" version = "2.31.0.6" From f65770936648fc365165e96411b3b51156e04431 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 22 Jul 2025 23:29:23 +0000 Subject: [PATCH 03/11] Add Agent event_stream_handler --- pydantic_ai_slim/pydantic_ai/agent.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index ad3012d60..a888204f2 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,12 +5,12 @@ import json import warnings from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeAlias, cast, final, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema @@ -96,6 +96,14 @@ RunOutputDataT = TypeVar('RunOutputDataT') """Type variable for the result data of a run where `output_type` was customized on the run call.""" +EventStreamHandler: TypeAlias = Callable[ + [ + RunContext[AgentDepsT], + AsyncIterable[_messages.AgentStreamEvent | _messages.HandleResponseEvent], + ], + Awaitable[None], +] + @final @dataclasses.dataclass(init=False) @@ -168,6 +176,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) @@ -197,6 +206,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> None: ... @overload @@ -257,6 +267,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> None: ... def __init__( @@ -283,6 +294,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Any, ): """Create an agent. @@ -331,6 +343,7 @@ def __init__( history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. + event_stream_handler: TODO: Optional handler for events from the agent stream. """ if model is None or defer_model_check: self.model = model @@ -425,6 +438,8 @@ def __init__( self.history_processors = history_processors or [] + self._event_stream_handler = event_stream_handler + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( @@ -559,8 +574,12 @@ async def main(): usage=usage, toolsets=toolsets, ) as agent_run: - async for _ in agent_run: - pass + async for node in agent_run: + if self._event_stream_handler is not None and ( + self.is_model_request_node(node) or self.is_call_tools_node(node) + ): + async with node.stream(agent_run.ctx) as stream: + await self._event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result From 2f0489481e492e90d5d097469b32b9314145333e Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 22 Jul 2025 23:33:00 +0000 Subject: [PATCH 04/11] Pass run_context to Model.request_stream for Temporal --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 27 ++++++++++--------- .../pydantic_ai/models/__init__.py | 7 ++--- .../pydantic_ai/models/anthropic.py | 2 ++ .../pydantic_ai/models/bedrock.py | 2 ++ .../pydantic_ai/models/fallback.py | 6 +++-- .../pydantic_ai/models/function.py | 7 ++--- pydantic_ai_slim/pydantic_ai/models/gemini.py | 5 ++-- pydantic_ai_slim/pydantic_ai/models/google.py | 2 ++ pydantic_ai_slim/pydantic_ai/models/groq.py | 7 ++--- .../pydantic_ai/models/huggingface.py | 9 ++++--- .../pydantic_ai/models/instrumented.py | 4 ++- .../pydantic_ai/models/mcp_sampling.py | 4 ++- .../pydantic_ai/models/mistral.py | 5 ++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 10 ++++--- pydantic_ai_slim/pydantic_ai/models/test.py | 2 ++ .../pydantic_ai/models/wrapper.py | 6 ++++- tests/models/test_instrumented.py | 2 ++ 17 files changed, 68 insertions(+), 39 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 312a8a2fc..0dea48dab 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -325,13 +325,9 @@ async def _stream( ) -> AsyncIterator[models.StreamedResponse]: assert not self._did_stream, 'stream() should only be called once per node' - model_settings, model_request_parameters = await self._prepare_request(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) + model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx) async with ctx.deps.model.request_stream( - message_history, model_settings, model_request_parameters + message_history, model_settings, model_request_parameters, run_context ) as streamed_response: self._did_stream = True ctx.state.usage.requests += 1 @@ -351,11 +347,7 @@ async def _make_request( if self._result is not None: return self._result # pragma: no cover - model_settings, model_request_parameters = await self._prepare_request(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) + model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.incr(_usage.Usage()) @@ -363,7 +355,7 @@ async def _make_request( async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> tuple[ModelSettings | None, models.ModelRequestParameters]: + ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]: ctx.state.message_history.append(self.request) # Check usage @@ -373,9 +365,18 @@ async def _prepare_request( # Increment run_step ctx.state.run_step += 1 + run_context = build_run_context(ctx) + model_settings = merge_model_settings(ctx.deps.model_settings, None) + model_request_parameters = await _prepare_request_parameters(ctx) - return model_settings, model_request_parameters + model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) + + message_history = await _process_message_history( + ctx.state.message_history, ctx.deps.history_processors, run_context + ) + + return model_settings, model_request_parameters, message_history, run_context def _finish_handling( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 6193d1d41..f998d9007 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -13,19 +13,19 @@ from dataclasses import dataclass, field, replace from datetime import datetime from functools import cache, cached_property -from typing import Generic, TypeVar, overload +from typing import Any, Generic, TypeVar, overload import httpx from typing_extensions import Literal, TypeAliasType, TypedDict -from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec - from .. import _utils from .._output import OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager +from .._run_context import RunContext from ..exceptions import UserError from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl from ..output import OutputMode +from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings from ..tools import ToolDefinition @@ -379,6 +379,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" # This method is not required, but you need to implement it if you want to support streamed responses diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index a62741568..3006cfde6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -11,6 +11,7 @@ from typing_extensions import assert_never from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( BinaryContent, @@ -171,6 +172,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._messages_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index f16f9d111..3fde8d252 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -15,6 +15,7 @@ from typing_extensions import ParamSpec, assert_never from pydantic_ai import _utils, usage +from pydantic_ai._run_context import RunContext from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -264,6 +265,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index 4455defce..498d8e1bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -3,10 +3,11 @@ from collections.abc import AsyncIterator from contextlib import AsyncExitStack, asynccontextmanager, suppress from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable from opentelemetry.trace import get_current_span +from pydantic_ai._run_context import RunContext from pydantic_ai.models.instrumented import InstrumentedModel from ..exceptions import FallbackExceptionGroup, ModelHTTPError @@ -83,6 +84,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Try each model in sequence until one succeeds.""" exceptions: list[Exception] = [] @@ -92,7 +94,7 @@ async def request_stream( async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( - model.request_stream(messages, model_settings, customized_model_request_parameters) + model.request_stream(messages, model_settings, customized_model_request_parameters, run_context) ) except Exception as exc: if self._fallback_on(exc): diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index e8476b554..0fcf9e819 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -7,13 +7,12 @@ from dataclasses import dataclass, field from datetime import datetime from itertools import chain -from typing import Callable, Union +from typing import Any, Callable, Union from typing_extensions import TypeAlias, assert_never, overload -from pydantic_ai.profiles import ModelProfileSpec - from .. import _utils, usage +from .._run_context import RunContext from .._utils import PeekableAsyncStream from ..messages import ( AudioUrl, @@ -32,6 +31,7 @@ UserContent, UserPromptPart, ) +from ..profiles import ModelProfileSpec from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse @@ -147,6 +147,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 7c371a943..feb6cf10e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -13,10 +13,9 @@ from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse from typing_extensions import NotRequired, TypedDict, assert_never -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import OutputObjectDefinition +from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( BinaryContent, @@ -36,6 +35,7 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -164,6 +164,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() async with self._make_request( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 9ec1260d4..7eac84113 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -12,6 +12,7 @@ from .. import UnexpectedModelBehavior, _utils, usage from .._output import OutputObjectDefinition +from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( BinaryContent, @@ -187,6 +188,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 92376b44d..23a1fe18c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -5,13 +5,13 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( BinaryContent, @@ -166,6 +166,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 41d53ca62..fa6b2b57d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -5,14 +5,13 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( AudioUrl, @@ -33,6 +32,7 @@ UserPromptPart, VideoUrl, ) +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests @@ -146,6 +146,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 233020f6f..49a97bd49 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -18,6 +18,7 @@ from opentelemetry.util.types import AttributeValue from pydantic import TypeAdapter +from .._run_context import RunContext from ..messages import ( ModelMessage, ModelRequest, @@ -222,12 +223,13 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: with self._instrument(messages, model_settings, model_request_parameters) as finish: response_stream: StreamedResponse | None = None try: async with super().request_stream( - messages, model_settings, model_request_parameters + messages, model_settings, model_request_parameters, run_context ) as response_stream: yield response_stream finally: diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index ebfaac92d..a4f649786 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -3,9 +3,10 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from .. import _mcp, exceptions, usage +from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..settings import ModelSettings from . import Model, ModelRequestParameters, StreamedResponse @@ -76,6 +77,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: raise NotImplementedError('MCP Sampling does not support streaming') yield diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 46b627826..b475e3363 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -11,9 +11,9 @@ from httpx import Timeout from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( BinaryContent, @@ -172,6 +172,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 61d280fb2..0df5055e2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -11,12 +11,10 @@ from pydantic import ValidationError from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking -from pydantic_ai.profiles.openai import OpenAIModelProfile -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( AudioUrl, @@ -38,6 +36,8 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..profiles.openai import OpenAIModelProfile +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -244,6 +244,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( @@ -659,6 +660,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._responses_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a80d551ff..d1b015078 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -12,6 +12,7 @@ from typing_extensions import assert_never from .. import _utils +from .._run_context import RunContext from ..messages import ( ModelMessage, ModelRequest, @@ -118,6 +119,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: self.last_model_request_parameters = model_request_parameters diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index cc91f9c72..9818ad603 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -6,6 +6,7 @@ from functools import cached_property from typing import Any +from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..profiles import ModelProfile from ..settings import ModelSettings @@ -35,8 +36,11 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream: + async with self.wrapped.request_stream( + messages, model_settings, model_request_parameters, run_context + ) as response_stream: yield response_stream def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index b952bf716..831f6f339 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -11,6 +11,7 @@ from opentelemetry._events import NoOpEventLoggerProvider from opentelemetry.trace import NoOpTracerProvider +from pydantic_ai._run_context import RunContext from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -89,6 +90,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext | None = None, ) -> AsyncIterator[StreamedResponse]: yield MyResponseStream() From 5f6cfa77c8652a4046645e04023f39a3c5a87580 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 22 Jul 2025 23:36:54 +0000 Subject: [PATCH 05/11] Streaming with Temporal --- .../pydantic_ai/temporal/__init__.py | 16 +- .../pydantic_ai/temporal/agent.py | 4 +- .../pydantic_ai/temporal/function_toolset.py | 35 +--- .../pydantic_ai/temporal/mcp_server.py | 48 +---- .../pydantic_ai/temporal/model.py | 171 +++++++++++++----- temporal.py | 46 +++-- 6 files changed, 169 insertions(+), 151 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py index 21f0ca90d..d6a060860 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -11,10 +11,8 @@ class _TemporalRunContext(RunContext[AgentDepsT]): - _data: dict[str, Any] - def __init__(self, **kwargs: Any): - self._data = kwargs + self.__dict__ = kwargs setattr( self, '__dataclass_fields__', @@ -25,10 +23,12 @@ def __getattribute__(self, name: str) -> Any: try: return super().__getattribute__(name) except AttributeError as e: - data = super().__getattribute__('_data') - if name in data: - return data[name] - raise e # TODO: Explain how to make a new run context attribute available + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + ) + else: + raise e @classmethod def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: @@ -75,7 +75,7 @@ def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context @property - def execute_activity_kwargs(self) -> dict[str, Any]: + def execute_activity_options(self) -> dict[str, Any]: return { 'task_queue': self.task_queue, 'schedule_to_close_timeout': self.schedule_to_close_timeout, diff --git a/pydantic_ai_slim/pydantic_ai/temporal/agent.py b/pydantic_ai_slim/pydantic_ai/temporal/agent.py index b8a0fc538..5f59d51fc 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/agent.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/agent.py @@ -30,7 +30,7 @@ def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | N def temporalize_agent( - agent: Agent, + agent: Agent[Any, Any], settings: TemporalSettings | None = None, temporalize_toolset_func: Callable[ [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] @@ -52,7 +52,7 @@ def temporalize_agent( activities: list[Callable[..., Any]] = [] if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings)) + activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] def temporalize_toolset(toolset: AbstractToolset) -> None: activities.extend(temporalize_toolset_func(toolset, settings)) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py index 842e86aaf..8b34a8fb5 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py @@ -57,7 +57,7 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), - **tool_settings.execute_activity_kwargs, + **tool_settings.execute_activity_options, ) toolset.call_tool = call_tool @@ -65,36 +65,3 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: activities = [call_tool_activity] setattr(toolset, '__temporal_activities', activities) return activities - - -# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]): -# def __init__( -# self, -# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], -# max_retries: int = 1, -# temporal_settings: TemporalSettings | None = None, -# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None, -# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None, -# ): -# super().__init__(tools, max_retries) -# self.temporal_settings = temporal_settings or TemporalSettings() -# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context -# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context - -# @activity.defn(name='function_toolset_call_tool') -# async def call_tool_activity(params: FunctionCallToolParams) -> Any: -# ctx = self.deserialize_run_context(params.serialized_run_context) -# tool = (await self.get_tools(ctx))[params.name] -# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool) - -# self.call_tool_activity = call_tool_activity - -# async def call_tool( -# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] -# ) -> Any: -# serialized_run_context = self.serialize_run_context(ctx) -# return await workflow.execute_activity( -# activity=self.call_tool_activity, -# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), -# **self.temporal_settings.__dict__, -# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py index 8d3be7624..6a7248468 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py @@ -55,7 +55,7 @@ async def call_tool_activity(params: _CallToolParams) -> ToolResult: async def list_tools() -> list[mcp_types.Tool]: return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] activity=list_tools_activity, - **settings.execute_activity_kwargs, + **settings.execute_activity_options, ) async def direct_call_tool( @@ -66,7 +66,7 @@ async def direct_call_tool( return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), - **settings.for_tool(id, name).execute_activity_kwargs, + **settings.for_tool(id, name).execute_activity_options, ) server.list_tools = list_tools @@ -75,47 +75,3 @@ async def direct_call_tool( activities = [list_tools_activity, call_tool_activity] setattr(server, '__temporal_activities', activities) return activities - - -# class TemporalMCPServer(WrapperToolset[Any]): -# temporal_settings: TemporalSettings - -# @property -# def wrapped_server(self) -> MCPServer: -# assert isinstance(self.wrapped, MCPServer) -# return self.wrapped - -# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None): -# assert isinstance(self.wrapped, MCPServer) -# super().__init__(wrapped) -# self.temporal_settings = temporal_settings or TemporalSettings() - -# @activity.defn(name='mcp_server_list_tools') -# async def list_tools_activity() -> list[mcp_types.Tool]: -# return await self.wrapped_server.list_tools() - -# self.list_tools_activity = list_tools_activity - -# @activity.defn(name='mcp_server_call_tool') -# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult: -# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata) - -# self.call_tool_activity = call_tool_activity - -# async def list_tools(self) -> list[mcp_types.Tool]: -# return await workflow.execute_activity( -# activity=self.list_tools_activity, -# **self.temporal_settings.__dict__, -# ) - -# async def direct_call_tool( -# self, -# name: str, -# args: dict[str, Any], -# metadata: dict[str, Any] | None = None, -# ) -> ToolResult: -# return await workflow.execute_activity( -# activity=self.call_tool_activity, -# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata), -# **self.temporal_settings.__dict__, -# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/model.py b/pydantic_ai_slim/pydantic_ai/temporal/model.py index 7f8a25e71..32b4967c0 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/model.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/model.py @@ -3,17 +3,27 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from datetime import datetime from typing import Any, Callable from pydantic import ConfigDict, with_config from temporalio import activity, workflow +from .._run_context import RunContext +from ..agent import EventStreamHandler +from ..exceptions import UserError from ..messages import ( + FinalResultEvent, ModelMessage, ModelResponse, + ModelResponseStreamEvent, + PartStartEvent, + TextPart, + ToolCallPart, ) from ..models import Model, ModelRequestParameters, StreamedResponse from ..settings import ModelSettings +from ..usage import Usage from . import TemporalSettings @@ -23,14 +33,48 @@ class _RequestParams: messages: list[ModelMessage] model_settings: ModelSettings | None model_request_parameters: ModelRequestParameters + serialized_run_context: Any -def temporalize_model(model: Model, settings: TemporalSettings | None = None) -> list[Callable[..., Any]]: +class _TemporalStreamedResponse(StreamedResponse): + def __init__(self, response: ModelResponse): + self.response = response + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + return + # noinspection PyUnreachableCode + yield + + def get(self) -> ModelResponse: + """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + return self.response + + def usage(self) -> Usage: + """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" + return self.response.usage + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self.response.model_name or '' + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self.response.timestamp + + +def temporalize_model( # noqa: C901 + model: Model, + settings: TemporalSettings | None = None, + event_stream_handler: EventStreamHandler | None = None, +) -> list[Callable[..., Any]]: """Temporalize a model. Args: model: The model to temporalize. settings: The temporal settings to use. + event_stream_handler: The event stream handler to use. """ if activities := getattr(model, '__temporal_activities', None): return activities @@ -38,11 +82,64 @@ def temporalize_model(model: Model, settings: TemporalSettings | None = None) -> settings = settings or TemporalSettings() original_request = model.request + original_request_stream = model.request_stream @activity.defn(name='model_request') async def request_activity(params: _RequestParams) -> ModelResponse: return await original_request(params.messages, params.model_settings, params.model_request_parameters) + @activity.defn(name='model_request_stream') + async def request_stream_activity(params: _RequestParams) -> ModelResponse: + run_context = settings.deserialize_run_context(params.serialized_run_context) + async with original_request_stream( + params.messages, params.model_settings, params.model_request_parameters, run_context + ) as streamed_response: + tool_defs = { + tool_def.name: tool_def + for tool_def in [ + *params.model_request_parameters.output_tools, + *params.model_request_parameters.function_tools, + ] + } + + async def aiter(): + def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, PartStartEvent): + new_part = e.part + if ( + isinstance(new_part, TextPart) and params.model_request_parameters.allow_text_output + ): # pragma: no branch + return FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, ToolCallPart) and (tool_def := tool_defs.get(new_part.tool_name)): + if tool_def.kind == 'output': + return FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return FinalResultEvent(tool_name=None, tool_call_id=None) + + # TODO: usage_checking_stream = _get_usage_checking_stream_response( + # self._raw_stream_response, self._usage_limits, self.usage + # ) + async for event in streamed_response: + yield event + if (final_result_event := _get_final_result_event(event)) is not None: + yield final_result_event + break + + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in streamed_response: + yield event + + assert event_stream_handler is not None + await event_stream_handler(run_context, aiter()) + + async for _ in streamed_response: + pass + return streamed_response.get() + async def request( messages: list[ModelMessage], model_settings: ModelSettings | None, @@ -51,9 +148,12 @@ async def request( return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=request_activity, arg=_RequestParams( - messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters + messages=messages, + model_settings=model_settings, + model_request_parameters=model_request_parameters, + serialized_run_context=None, ), - **settings.execute_activity_kwargs, + **settings.execute_activity_options, ) @asynccontextmanager @@ -61,56 +161,29 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - raise NotImplementedError('Cannot stream with temporal yet') - yield + if event_stream_handler is None: + raise UserError('Streaming with Temporal requires `Agent` to have an `event_stream_handler`') + if run_context is None: + raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') + + serialized_run_context = settings.serialize_run_context(run_context) + response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=request_stream_activity, + arg=_RequestParams( + messages=messages, + model_settings=model_settings, + model_request_parameters=model_request_parameters, + serialized_run_context=serialized_run_context, + ), + **settings.execute_activity_options, + ) + yield _TemporalStreamedResponse(response) model.request = request model.request_stream = request_stream - activities = [request_activity] + activities = [request_activity, request_stream_activity] setattr(model, '__temporal_activities', activities) return activities - - -# @dataclass -# class TemporalModel(WrapperModel): -# temporal_settings: TemporalSettings - -# def __init__( -# self, -# wrapped: Model | KnownModelName, -# temporal_settings: TemporalSettings | None = None, -# ) -> None: -# super().__init__(wrapped) -# self.temporal_settings = temporal_settings or TemporalSettings() - -# @activity.defn -# async def request_activity(params: ModelRequestParams) -> ModelResponse: -# return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) - -# self.request_activity = request_activity - -# async def request( -# self, -# messages: list[ModelMessage], -# model_settings: ModelSettings | None, -# model_request_parameters: ModelRequestParameters, -# ) -> ModelResponse: -# return await workflow.execute_activity( -# activity=self.request_activity, -# arg=ModelRequestParams( -# messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters -# ), -# **self.temporal_settings.__dict__, -# ) - -# @asynccontextmanager -# async def request_stream( -# self, -# messages: list[ModelMessage], -# model_settings: ModelSettings | None, -# model_request_parameters: ModelRequestParameters, -# ) -> AsyncIterator[StreamedResponse]: -# raise NotImplementedError('Cannot stream with temporal yet') -# yield diff --git a/temporal.py b/temporal.py index 340019b8c..057b427be 100644 --- a/temporal.py +++ b/temporal.py @@ -6,6 +6,7 @@ # /// import asyncio import random +from collections.abc import AsyncIterable from datetime import timedelta from temporalio import workflow @@ -14,11 +15,13 @@ from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.worker import Worker +from typing_extensions import TypedDict with workflow.unsafe.imports_passed_through(): from pydantic_ai import Agent + from pydantic_ai._run_context import RunContext from pydantic_ai.mcp import MCPServerStdio - from pydantic_ai.models.openai import OpenAIModel + from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent from pydantic_ai.temporal import ( TemporalSettings, initialize_temporal, @@ -28,10 +31,13 @@ initialize_temporal() - def get_uv_index(location: str) -> int: - return 3 + class Deps(TypedDict): + country: str - toolset = FunctionToolset(tools=[get_uv_index], id='uv_index') + def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + toolset = FunctionToolset[Deps](tools=[get_country], id='country') mcp_server = MCPServerStdio( 'python', ['-m', 'tests.mcp_server'], @@ -39,14 +45,26 @@ def get_uv_index(location: str) -> int: id='test', ) - model = OpenAIModel('gpt-4o') - my_agent = Agent(model=model, instructions='be helpful', toolsets=[toolset, mcp_server]) + async def event_stream_handler( + ctx: RunContext[Deps], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], + ): + print(f'{ctx.run_step=}') + async for event in stream: + print(event) + + my_agent = Agent( + 'openai:gpt-4o', + toolsets=[toolset, mcp_server], + event_stream_handler=event_stream_handler, + deps_type=Deps, + ) temporal_settings = TemporalSettings( start_to_close_timeout=timedelta(seconds=60), - tool_settings={ - 'uv_index': { - 'get_uv_index': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + tool_settings={ # TODO: Allow default temporal settings to be set for an entire toolset + 'country': { + 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), }, }, ) @@ -68,8 +86,9 @@ def init_runtime_with_telemetry() -> Runtime: @workflow.defn class MyAgentWorkflow: @workflow.run - async def run(self, prompt: str) -> str: - return (await my_agent.run(prompt)).output + async def run(self, prompt: str, deps: Deps) -> str: + result = await my_agent.run(prompt, deps=deps) + return result.output async def main(): @@ -88,7 +107,10 @@ async def main(): ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, - 'what is 2 plus the UV Index in Mexico City? and what is the product name?', + args=[ + 'what is the capital of the capital of the country? and what is the product name?', + Deps(country='Mexico'), + ], id=f'my-agent-workflow-id-{random.random()}', task_queue='my-agent-task-queue', ) From a1e96e6ff69c90df1471df68364e7d8bac30f7f4 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 14:01:58 +0000 Subject: [PATCH 06/11] Fix google types issues by importing only google.genai.Client --- docs/models/google.md | 4 ++-- pydantic_ai_slim/pydantic_ai/models/google.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/providers/google.py | 14 +++++++------- tests/conftest.py | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/models/google.md b/docs/models/google.md index 2cc35d9c0..30816715a 100644 --- a/docs/models/google.md +++ b/docs/models/google.md @@ -104,14 +104,14 @@ You can supply a custom `GoogleProvider` instance using the `provider` argument This is useful if you're using a custom-compatible endpoint with the Google Generative Language API. ```python -from google import genai +from google.genai import Client from google.genai.types import HttpOptions from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider -client = genai.Client( +client = Client( api_key='gemini-custom-api-key', http_options=HttpOptions(base_url='gemini-custom-base-url'), ) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 7eac84113..80ec6af85 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -45,7 +45,7 @@ ) try: - from google import genai + from google.genai import Client from google.genai.types import ( ContentDict, ContentUnionDict, @@ -131,10 +131,10 @@ class GoogleModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - client: genai.Client = field(repr=False) + client: Client = field(repr=False) _model_name: GoogleModelName = field(repr=False) - _provider: Provider[genai.Client] = field(repr=False) + _provider: Provider[Client] = field(repr=False) _url: str | None = field(repr=False) _system: str = field(default='google', repr=False) @@ -142,7 +142,7 @@ def __init__( self, model_name: GoogleModelName, *, - provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla', + provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): diff --git a/pydantic_ai_slim/pydantic_ai/providers/google.py b/pydantic_ai_slim/pydantic_ai/providers/google.py index fc876fcff..70eaa6d86 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google.py @@ -10,8 +10,8 @@ from pydantic_ai.providers import Provider try: - from google import genai from google.auth.credentials import Credentials + from google.genai import Client except ImportError as _import_error: raise ImportError( 'Please install the `google-genai` package to use the Google provider, ' @@ -19,7 +19,7 @@ ) from _import_error -class GoogleProvider(Provider[genai.Client]): +class GoogleProvider(Provider[Client]): """Provider for Google.""" @property @@ -31,7 +31,7 @@ def base_url(self) -> str: return str(self._client._api_client._http_options.base_url) # type: ignore[reportPrivateUsage] @property - def client(self) -> genai.Client: + def client(self) -> Client: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: @@ -50,7 +50,7 @@ def __init__( ) -> None: ... @overload - def __init__(self, *, client: genai.Client) -> None: ... + def __init__(self, *, client: Client) -> None: ... @overload def __init__(self, *, vertexai: bool = False) -> None: ... @@ -62,7 +62,7 @@ def __init__( credentials: Credentials | None = None, project: str | None = None, location: VertexAILocation | Literal['global'] | None = None, - client: genai.Client | None = None, + client: Client | None = None, vertexai: bool | None = None, ) -> None: """Create a new Google provider. @@ -95,13 +95,13 @@ def __init__( 'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`' 'to use the Google Generative Language API.' ) - self._client = genai.Client( + self._client = Client( vertexai=vertexai, api_key=api_key, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: - self._client = genai.Client( + self._client = Client( vertexai=vertexai, project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'), # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149: diff --git a/tests/conftest.py b/tests/conftest.py index 3ae576c63..16c39380b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -383,7 +383,7 @@ async def vertex_provider(): pytest.skip('Requires properly configured local google vertex config to pass') try: - from google import genai + from google.genai import Client from pydantic_ai.providers.google import GoogleProvider except ImportError: # pragma: lax no cover @@ -391,7 +391,7 @@ async def vertex_provider(): project = os.getenv('GOOGLE_PROJECT', 'pydantic-ai') location = os.getenv('GOOGLE_LOCATION', 'us-central1') - client = genai.Client(vertexai=True, project=project, location=location) + client = Client(vertexai=True, project=project, location=location) try: yield GoogleProvider(client=client) From 966e7f8e148d72b633ad09ec3decfc5e95a56400 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 14:03:20 +0000 Subject: [PATCH 07/11] Import TypeAlias from typing_extensions for Python 3.9 --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a888204f2..be9adef30 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -10,11 +10,11 @@ from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeAlias, cast, final, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated +from typing_extensions import Literal, Never, Self, TypeAlias, TypeIs, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop From 090ec23d0e1abe89dc56b772fbd36a4b876b865c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 23:51:09 +0000 Subject: [PATCH 08/11] Start cleaning up temporal integration --- .../pydantic_ai/temporal/__init__.py | 155 +++++++----------- ...nction_toolset.py => _function_toolset.py} | 2 +- .../{mcp_server.py => _mcp_server.py} | 2 +- .../temporal/{model.py => _model.py} | 2 +- .../pydantic_ai/temporal/_run_context.py | 41 +++++ .../pydantic_ai/temporal/_settings.py | 57 +++++++ .../pydantic_ai/temporal/_toolset.py | 26 +++ .../pydantic_ai/temporal/agent.py | 66 -------- temporal.py | 136 ++++++++------- 9 files changed, 262 insertions(+), 225 deletions(-) rename pydantic_ai_slim/pydantic_ai/temporal/{function_toolset.py => _function_toolset.py} (98%) rename pydantic_ai_slim/pydantic_ai/temporal/{mcp_server.py => _mcp_server.py} (98%) rename pydantic_ai_slim/pydantic_ai/temporal/{model.py => _model.py} (99%) create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_run_context.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_settings.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_toolset.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/agent.py diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py index d6a060860..f01028c17 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -1,104 +1,65 @@ from __future__ import annotations -from dataclasses import dataclass -from datetime import timedelta +import contextlib from typing import Any, Callable -from temporalio.common import Priority, RetryPolicy -from temporalio.workflow import ActivityCancellationType, VersioningIntent - -from pydantic_ai._run_context import AgentDepsT, RunContext - - -class _TemporalRunContext(RunContext[AgentDepsT]): - def __init__(self, **kwargs: Any): - self.__dict__ = kwargs - setattr( - self, - '__dataclass_fields__', - {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, - ) - - def __getattribute__(self, name: str) -> Any: - try: - return super().__getattribute__(name) - except AttributeError as e: - if name in RunContext.__dataclass_fields__: - raise AttributeError( - f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' - ) - else: - raise e - - @classmethod - def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: - return { - 'deps': ctx.deps, - 'retries': ctx.retries, - 'tool_call_id': ctx.tool_call_id, - 'tool_name': ctx.tool_name, - 'retry': ctx.retry, - 'run_step': ctx.run_step, - } - - @classmethod - def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: - return cls(**ctx) - - -@dataclass -class TemporalSettings: - """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" - - # Temporal settings - task_queue: str | None = None - schedule_to_close_timeout: timedelta | None = None - schedule_to_start_timeout: timedelta | None = None - start_to_close_timeout: timedelta | None = None - heartbeat_timeout: timedelta | None = None - retry_policy: RetryPolicy | None = None - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL - activity_id: str | None = None - versioning_intent: VersioningIntent | None = None - summary: str | None = None - priority: Priority = Priority.default - - # Pydantic AI specific - tool_settings: dict[str, dict[str, TemporalSettings]] | None = None - - def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: - if self.tool_settings is None: - return self - return self.tool_settings.get(toolset_id, {}).get(tool_id, self) - - serialize_run_context: Callable[[RunContext], Any] = _TemporalRunContext.serialize_run_context - deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context - - @property - def execute_activity_options(self) -> dict[str, Any]: - return { - 'task_queue': self.task_queue, - 'schedule_to_close_timeout': self.schedule_to_close_timeout, - 'schedule_to_start_timeout': self.schedule_to_start_timeout, - 'start_to_close_timeout': self.start_to_close_timeout, - 'heartbeat_timeout': self.heartbeat_timeout, - 'retry_policy': self.retry_policy, - 'cancellation_type': self.cancellation_type, - 'activity_id': self.activity_id, - 'versioning_intent': self.versioning_intent, - 'summary': self.summary, - 'priority': self.priority, - } +from temporalio import workflow + +from pydantic_ai.agent import Agent +from pydantic_ai.toolsets.abstract import AbstractToolset + +from ..models import Model +from ._model import temporalize_model +from ._run_context import TemporalRunContext +from ._settings import TemporalSettings +from ._toolset import temporalize_toolset + +__all__ = [ + 'initialize_temporal', + 'TemporalSettings', + 'TemporalRunContext', +] def initialize_temporal(): - """Explicitly import types without which Temporal will not be able to serialize/deserialize `ModelMessage`s.""" - from pydantic_ai.messages import ( # noqa F401 - ModelResponse, # pyright: ignore[reportUnusedImport] - ImageUrl, # pyright: ignore[reportUnusedImport] - AudioUrl, # pyright: ignore[reportUnusedImport] - DocumentUrl, # pyright: ignore[reportUnusedImport] - VideoUrl, # pyright: ignore[reportUnusedImport] - BinaryContent, # pyright: ignore[reportUnusedImport] - UserContent, # pyright: ignore[reportUnusedImport] - ) + """Initialize Temporal.""" + with workflow.unsafe.imports_passed_through(): + with contextlib.suppress(ModuleNotFoundError): + import pandas # pyright: ignore[reportUnusedImport] # noqa: F401 + + +def temporalize_agent( + agent: Agent[Any, Any], + settings: TemporalSettings | None = None, + temporalize_toolset_func: Callable[ + [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] + ] = temporalize_toolset, +) -> list[Callable[..., Any]]: + """Temporalize an agent. + + Args: + agent: The agent to temporalize. + settings: The temporal settings to use. + temporalize_toolset_func: The function to use to temporalize the toolsets. + """ + if existing_activities := getattr(agent, '__temporal_activities', None): + return existing_activities + + settings = settings or TemporalSettings() + + # TODO: Doesn't consider model/toolsets passed at iter time. + + activities: list[Callable[..., Any]] = [] + if isinstance(agent.model, Model): + activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] + + def temporalize_toolset(toolset: AbstractToolset) -> None: + activities.extend(temporalize_toolset_func(toolset, settings)) + + agent.toolset.apply(temporalize_toolset) + + setattr(agent, '__temporal_activities', activities) + return activities + + +# TODO: untemporalize_agent diff --git a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py similarity index 98% rename from pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py rename to pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py index 8b34a8fb5..2c371a1ff 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py @@ -10,7 +10,7 @@ from .._run_context import RunContext from ..toolsets import ToolsetTool -from . import TemporalSettings +from ._settings import TemporalSettings @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py b/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py similarity index 98% rename from pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py rename to pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py index 6a7248468..6a93e0c34 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py @@ -9,7 +9,7 @@ from pydantic_ai.mcp import MCPServer, ToolResult -from . import TemporalSettings +from ._settings import TemporalSettings @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/temporal/model.py b/pydantic_ai_slim/pydantic_ai/temporal/_model.py similarity index 99% rename from pydantic_ai_slim/pydantic_ai/temporal/model.py rename to pydantic_ai_slim/pydantic_ai/temporal/_model.py index 32b4967c0..cbadf7bd1 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/model.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/_model.py @@ -24,7 +24,7 @@ from ..models import Model, ModelRequestParameters, StreamedResponse from ..settings import ModelSettings from ..usage import Usage -from . import TemporalSettings +from ._settings import TemporalSettings @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py new file mode 100644 index 000000000..8bc7029e6 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any + +from pydantic_ai._run_context import AgentDepsT, RunContext + + +class TemporalRunContext(RunContext[AgentDepsT]): + def __init__(self, **kwargs: Any): + self.__dict__ = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + ) + else: + raise e + + @classmethod + def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: + return { + 'deps': ctx.deps, + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + } + + @classmethod + def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: + return cls(**ctx) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_settings.py b/pydantic_ai_slim/pydantic_ai/temporal/_settings.py new file mode 100644 index 000000000..14c9d595e --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/_settings.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Callable + +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from pydantic_ai._run_context import RunContext + +from ._run_context import TemporalRunContext + + +@dataclass +class TemporalSettings: + """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" + + # Temporal settings + task_queue: str | None = None + schedule_to_close_timeout: timedelta | None = None + schedule_to_start_timeout: timedelta | None = None + start_to_close_timeout: timedelta | None = None + heartbeat_timeout: timedelta | None = None + retry_policy: RetryPolicy | None = None + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL + activity_id: str | None = None + versioning_intent: VersioningIntent | None = None + summary: str | None = None + priority: Priority = Priority.default + + # Pydantic AI specific + tool_settings: dict[str, dict[str, TemporalSettings]] | None = None + + def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: + if self.tool_settings is None: + return self + return self.tool_settings.get(toolset_id, {}).get(tool_id, self) + + serialize_run_context: Callable[[RunContext], Any] = TemporalRunContext.serialize_run_context + deserialize_run_context: Callable[[dict[str, Any]], RunContext] = TemporalRunContext.deserialize_run_context + + @property + def execute_activity_options(self) -> dict[str, Any]: + return { + 'task_queue': self.task_queue, + 'schedule_to_close_timeout': self.schedule_to_close_timeout, + 'schedule_to_start_timeout': self.schedule_to_start_timeout, + 'start_to_close_timeout': self.start_to_close_timeout, + 'heartbeat_timeout': self.heartbeat_timeout, + 'retry_policy': self.retry_policy, + 'cancellation_type': self.cancellation_type, + 'activity_id': self.activity_id, + 'versioning_intent': self.versioning_intent, + 'summary': self.summary, + 'priority': self.priority, + } diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py new file mode 100644 index 000000000..289d90071 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from ._function_toolset import temporalize_function_toolset +from ._mcp_server import temporalize_mcp_server +from ._settings import TemporalSettings + + +def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + settings: The temporal settings to use. + """ + if isinstance(toolset, FunctionToolset): + return temporalize_function_toolset(toolset, settings) + elif isinstance(toolset, MCPServer): + return temporalize_mcp_server(toolset, settings) + else: + return [] diff --git a/pydantic_ai_slim/pydantic_ai/temporal/agent.py b/pydantic_ai_slim/pydantic_ai/temporal/agent.py deleted file mode 100644 index 5f59d51fc..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/agent.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable - -from pydantic_ai.agent import Agent -from pydantic_ai.mcp import MCPServer -from pydantic_ai.toolsets.abstract import AbstractToolset -from pydantic_ai.toolsets.function import FunctionToolset - -from ..models import Model -from . import TemporalSettings -from .function_toolset import temporalize_function_toolset -from .mcp_server import temporalize_mcp_server -from .model import temporalize_model - - -def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: - """Temporalize a toolset. - - Args: - toolset: The toolset to temporalize. - settings: The temporal settings to use. - """ - if isinstance(toolset, FunctionToolset): - return temporalize_function_toolset(toolset, settings) - elif isinstance(toolset, MCPServer): - return temporalize_mcp_server(toolset, settings) - else: - return [] - - -def temporalize_agent( - agent: Agent[Any, Any], - settings: TemporalSettings | None = None, - temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] - ] = temporalize_toolset, -) -> list[Callable[..., Any]]: - """Temporalize an agent. - - Args: - agent: The agent to temporalize. - settings: The temporal settings to use. - temporalize_toolset_func: The function to use to temporalize the toolsets. - """ - if existing_activities := getattr(agent, '__temporal_activities', None): - return existing_activities - - settings = settings or TemporalSettings() - - # TODO: Doesn't consider model/toolsets passed at iter time. - - activities: list[Callable[..., Any]] = [] - if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] - - def temporalize_toolset(toolset: AbstractToolset) -> None: - activities.extend(temporalize_toolset_func(toolset, settings)) - - agent.toolset.apply(temporalize_toolset) - - setattr(agent, '__temporal_activities', activities) - return activities - - -# TODO: untemporalize_agent diff --git a/temporal.py b/temporal.py index 057b427be..8b2baec60 100644 --- a/temporal.py +++ b/temporal.py @@ -9,80 +9,77 @@ from collections.abc import AsyncIterable from datetime import timedelta +import logfire +from opentelemetry import trace from temporalio import workflow from temporalio.client import Client from temporalio.contrib.opentelemetry import TracingInterceptor from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.worker import Worker +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner, SandboxRestrictions from typing_extensions import TypedDict -with workflow.unsafe.imports_passed_through(): - from pydantic_ai import Agent - from pydantic_ai._run_context import RunContext - from pydantic_ai.mcp import MCPServerStdio - from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent - from pydantic_ai.temporal import ( - TemporalSettings, - initialize_temporal, - ) - from pydantic_ai.temporal.agent import temporalize_agent - from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai import Agent, RunContext +from pydantic_ai.mcp import MCPServerStdio +from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent +from pydantic_ai.temporal import ( + TemporalSettings, + initialize_temporal, + temporalize_agent, +) +from pydantic_ai.toolsets import FunctionToolset - initialize_temporal() +initialize_temporal() - class Deps(TypedDict): - country: str - def get_country(ctx: RunContext[Deps]) -> str: - return ctx.deps['country'] +class Deps(TypedDict): + country: str - toolset = FunctionToolset[Deps](tools=[get_country], id='country') - mcp_server = MCPServerStdio( - 'python', - ['-m', 'tests.mcp_server'], - timeout=20, - id='test', - ) - async def event_stream_handler( - ctx: RunContext[Deps], - stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], - ): - print(f'{ctx.run_step=}') - async for event in stream: - print(event) - - my_agent = Agent( - 'openai:gpt-4o', - toolsets=[toolset, mcp_server], - event_stream_handler=event_stream_handler, - deps_type=Deps, - ) +def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] - temporal_settings = TemporalSettings( - start_to_close_timeout=timedelta(seconds=60), - tool_settings={ # TODO: Allow default temporal settings to be set for an entire toolset - 'country': { - 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), - }, - }, - ) - activities = temporalize_agent(my_agent, temporal_settings) +toolset = FunctionToolset[Deps](tools=[get_country], id='country') +mcp_server = MCPServerStdio( + 'python', + ['-m', 'tests.mcp_server'], + timeout=20, + id='test', +) -def init_runtime_with_telemetry() -> Runtime: - # import logfire - # logfire.configure(send_to_logfire=True, service_version='0.0.1', console=False) - # logfire.instrument_pydantic_ai() - # logfire.instrument_httpx(capture_all=True) +async def event_stream_handler( + ctx: RunContext[Deps], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], +): + logfire.info(f'{ctx.run_step=}') + async for event in stream: + logfire.info(f'{event=}') - # Setup SDK metrics to OTel endpoint - return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + +my_agent = Agent( + 'openai:gpt-4o', + toolsets=[toolset, mcp_server], + event_stream_handler=event_stream_handler, + deps_type=Deps, +) + +temporal_settings = TemporalSettings( + start_to_close_timeout=timedelta(seconds=60), + tool_settings={ # TODO: Allow default temporal settings to be set for all activities in a toolset + 'country': { + 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + }, + }, +) +activities = temporalize_agent(my_agent, temporal_settings) + + +TASK_QUEUE = 'pydantic-ai-agent-task-queue' -# Basic workflow that logs and invokes an activity @workflow.defn class MyAgentWorkflow: @workflow.run @@ -92,18 +89,39 @@ async def run(self, prompt: str, deps: Deps) -> str: async def main(): + def init_runtime_with_telemetry() -> Runtime: + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) + + # Setup SDK metrics to OTel endpoint + return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + client = await Client.connect( 'localhost:7233', - interceptors=[TracingInterceptor()], - data_converter=pydantic_data_converter, - runtime=init_runtime_with_telemetry(), + interceptors=[ # TODO: Use ClientPlugin.configure_client for this + TracingInterceptor(trace.get_tracer('temporal')) + ], + data_converter=pydantic_data_converter, # TODO: Use ClientPlugin.configure_client for this + runtime=init_runtime_with_telemetry(), # TODO: Use ClientPlugin.connect_service_client for this ) async with Worker( client, - task_queue='my-agent-task-queue', + task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], activities=activities, + workflow_runner=SandboxedWorkflowRunner( # TODO: Use WorkerPlugin.configure_worker for this, see https://github.com/temporalio/sdk-python/blob/da6616a93e9ee5170842bb5a056e2383e18d07c6/tests/test_plugins.py#L71 + restrictions=SandboxRestrictions.default.with_passthrough_modules( + 'pydantic_ai', + 'logfire', # TODO: Only if module available? + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', # TODO: Only if module available? + 'pandas', # TODO: Only if module available? + ), + ), ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, @@ -112,7 +130,7 @@ async def main(): Deps(country='Mexico'), ], id=f'my-agent-workflow-id-{random.random()}', - task_queue='my-agent-task-queue', + task_queue=TASK_QUEUE, ) print(output) From 4c87691603a1b4cc444493b6464517a443a5d045 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 2025 17:34:55 +0000 Subject: [PATCH 09/11] with_passthrough_modules doesn't import itself --- temporal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporal.py b/temporal.py index 8b2baec60..ea64f6797 100644 --- a/temporal.py +++ b/temporal.py @@ -114,12 +114,12 @@ def init_runtime_with_telemetry() -> Runtime: workflow_runner=SandboxedWorkflowRunner( # TODO: Use WorkerPlugin.configure_worker for this, see https://github.com/temporalio/sdk-python/blob/da6616a93e9ee5170842bb5a056e2383e18d07c6/tests/test_plugins.py#L71 restrictions=SandboxRestrictions.default.with_passthrough_modules( 'pydantic_ai', - 'logfire', # TODO: Only if module available? + 'logfire', # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize 'attrs', # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', # TODO: Only if module available? - 'pandas', # TODO: Only if module available? + 'numpy', + 'pandas', ), ), ): From 694fa6b2f8705ab0988cdae4197cff579ef1ee9b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 2025 20:56:49 +0000 Subject: [PATCH 10/11] Use Temporal plugins --- .../pydantic_ai/temporal/__init__.py | 77 ++++++++++++++++--- pyproject.toml | 1 + temporal.py | 53 +++---------- uv.lock | 23 +++--- 4 files changed, 92 insertions(+), 62 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py index f01028c17..32ea06775 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -1,9 +1,18 @@ from __future__ import annotations -import contextlib +from collections.abc import Sequence +from dataclasses import replace from typing import Any, Callable -from temporalio import workflow +import logfire # TODO: Not always available +from opentelemetry import trace # TODO: Not always available +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.service import ConnectConfig, ServiceClient +from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from pydantic_ai.agent import Agent from pydantic_ai.toolsets.abstract import AbstractToolset @@ -15,17 +24,66 @@ from ._toolset import temporalize_toolset __all__ = [ - 'initialize_temporal', 'TemporalSettings', 'TemporalRunContext', + 'PydanticAIPlugin', + 'LogfirePlugin', + 'AgentPlugin', ] -def initialize_temporal(): - """Initialize Temporal.""" - with workflow.unsafe.imports_passed_through(): - with contextlib.suppress(ModuleNotFoundError): - import pandas # pyright: ignore[reportUnusedImport] # noqa: F401 +class PydanticAIPlugin(ClientPlugin, WorkerPlugin): + """Temporal client and worker plugin for Pydantic AI.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config['data_converter'] = pydantic_data_converter + return super().configure_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] + if isinstance(runner, SandboxedWorkflowRunner): + config['workflow_runner'] = replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + 'pydantic_ai', + 'logfire', + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', + 'pandas', + ), + ) + return super().configure_worker(config) + + +class LogfirePlugin(ClientPlugin): + """Temporal client plugin for Logfire.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config['interceptors'] = [TracingInterceptor(trace.get_tracer('temporal'))] + return super().configure_client(config) + + async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: + # TODO: Do we need this here? + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) + + config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + return await super().connect_service_client(config) + + +class AgentPlugin(WorkerPlugin): + """Temporal worker plugin for a specific Pydantic AI agent.""" + + def __init__(self, agent: Agent[Any, Any], settings: TemporalSettings | None = None): + self.activities = temporalize_agent(agent, settings) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] + config['activities'] = [*activities, *self.activities] + return super().configure_worker(config) def temporalize_agent( @@ -47,7 +105,8 @@ def temporalize_agent( settings = settings or TemporalSettings() - # TODO: Doesn't consider model/toolsets passed at iter time. + # TODO: Doesn't consider model/toolsets passed at iter time, raise an error if that happens. + # Similarly, passing event_stream_handler at iter time should raise an error. activities: list[Callable[..., Any]] = [] if isinstance(agent.model, Model): diff --git a/pyproject.toml b/pyproject.toml index fdd67f183..e3f2ed07f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ pydantic-ai-slim = { workspace = true } pydantic-evals = { workspace = true } pydantic-graph = { workspace = true } pydantic-ai-examples = { workspace = true } +temporalio = { git = "https://github.com/temporalio/sdk-python.git", rev = "main" } [tool.uv.workspace] members = [ diff --git a/temporal.py b/temporal.py index ea64f6797..2b8a4a319 100644 --- a/temporal.py +++ b/temporal.py @@ -1,37 +1,25 @@ -# /// script -# dependencies = [ -# "temporalio", -# "logfire", -# ] -# /// import asyncio import random from collections.abc import AsyncIterable from datetime import timedelta import logfire -from opentelemetry import trace from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.opentelemetry import TracingInterceptor -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.worker import Worker -from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner, SandboxRestrictions from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent from pydantic_ai.temporal import ( + AgentPlugin, + LogfirePlugin, + PydanticAIPlugin, TemporalSettings, - initialize_temporal, - temporalize_agent, ) from pydantic_ai.toolsets import FunctionToolset -initialize_temporal() - class Deps(TypedDict): country: str @@ -59,7 +47,7 @@ async def event_stream_handler( logfire.info(f'{event=}') -my_agent = Agent( +agent = Agent( 'openai:gpt-4o', toolsets=[toolset, mcp_server], event_stream_handler=event_stream_handler, @@ -74,7 +62,6 @@ async def event_stream_handler( }, }, ) -activities = temporalize_agent(my_agent, temporal_settings) TASK_QUEUE = 'pydantic-ai-agent-task-queue' @@ -84,44 +71,26 @@ async def event_stream_handler( class MyAgentWorkflow: @workflow.run async def run(self, prompt: str, deps: Deps) -> str: - result = await my_agent.run(prompt, deps=deps) + result = await agent.run(prompt, deps=deps) return result.output -async def main(): - def init_runtime_with_telemetry() -> Runtime: - logfire.configure(console=False) - logfire.instrument_pydantic_ai() - logfire.instrument_httpx(capture_all=True) +# TODO: For some reason, when I put this (specifically the temporalize_agent call) inside `async def main()`, +# we get tons of errors. +plugin = AgentPlugin(agent, temporal_settings) - # Setup SDK metrics to OTel endpoint - return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) +async def main(): client = await Client.connect( 'localhost:7233', - interceptors=[ # TODO: Use ClientPlugin.configure_client for this - TracingInterceptor(trace.get_tracer('temporal')) - ], - data_converter=pydantic_data_converter, # TODO: Use ClientPlugin.configure_client for this - runtime=init_runtime_with_telemetry(), # TODO: Use ClientPlugin.connect_service_client for this + plugins=[PydanticAIPlugin(), LogfirePlugin()], ) async with Worker( client, task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], - activities=activities, - workflow_runner=SandboxedWorkflowRunner( # TODO: Use WorkerPlugin.configure_worker for this, see https://github.com/temporalio/sdk-python/blob/da6616a93e9ee5170842bb5a056e2383e18d07c6/tests/test_plugins.py#L71 - restrictions=SandboxRestrictions.default.with_passthrough_modules( - 'pydantic_ai', - 'logfire', - # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize - 'attrs', - # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', - 'pandas', - ), - ), + plugins=[plugin], ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, diff --git a/uv.lock b/uv.lock index b33d028f9..dcc5090da 100644 --- a/uv.lock +++ b/uv.lock @@ -2286,6 +2286,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695, upload-time = "2023-02-04T12:11:25.002Z" }, ] +[[package]] +name = "nexus-rpc" +version = "1.1.0" +source = { git = "https://github.com/nexus-rpc/sdk-python.git?rev=35f574c711193a6e2560d3e6665732a5bb7ae92c#35f574c711193a6e2560d3e6665732a5bb7ae92c" } +dependencies = [ + { name = "typing-extensions" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -3243,7 +3251,7 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, - { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.13.0" }, + { name = "temporalio", marker = "extra == 'temporal'", git = "https://github.com/temporalio/sdk-python.git?rev=main" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "temporal", "vertexai"] @@ -4160,22 +4168,15 @@ wheels = [ [[package]] name = "temporalio" -version = "1.13.0" -source = { registry = "https://pypi.org/simple" } +version = "1.14.1" +source = { git = "https://github.com/temporalio/sdk-python.git?rev=main#e767013acca543345e0408a167556bbb987eb130" } dependencies = [ + { name = "nexus-rpc" }, { name = "protobuf" }, { name = "python-dateutil", marker = "python_full_version < '3.11'" }, { name = "types-protobuf" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3e/a3/a76477b523937f47a21941188c16b3c6b1eef6baadc7c8efeea497d909de/temporalio-1.13.0.tar.gz", hash = "sha256:5a979eee5433da6ab5d8a2bcde25a1e7d454e91920acb0bf7ca93d415750828b", size = 1558745, upload-time = "2025-06-20T19:57:26.944Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f5/f4/a5a74284c671bd50ce7353ad1dad7dab1a795f891458454049e95bc5378f/temporalio-1.13.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:7ee14cab581352e77171d1e4ce01a899231abfe75c5f7233e3e260f361a344cc", size = 12086961, upload-time = "2025-06-20T19:57:15.25Z" }, - { url = "https://files.pythonhosted.org/packages/1f/b7/5dc6e34f4e9a3da8b75cb3fe0d32edca1d9201d598c38d022501d38650a9/temporalio-1.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:575a0c57dbb089298b4775f3aca86ebaf8d58d5ba155e7fc5509877c25e6bb44", size = 11745239, upload-time = "2025-06-20T19:57:17.934Z" }, - { url = "https://files.pythonhosted.org/packages/04/30/4b9b15af87c181fd9364b61971faa0faa07d199320d7ff1712b5d51b5bbb/temporalio-1.13.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf099a27f22c0dbc22f3d86dba76d59be5da812ff044ba3fa183e3e14bd5e9a", size = 12119197, upload-time = "2025-06-20T19:57:20.509Z" }, - { url = "https://files.pythonhosted.org/packages/46/9f/a5b627d773974c654b6cd22ed3937e7e2471023af244ea417f0e917e617b/temporalio-1.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7e20c711f41c66877b9d54ab33c79a14ccaac9ed498a174274f6129110f4d84", size = 12413459, upload-time = "2025-06-20T19:57:22.816Z" }, - { url = "https://files.pythonhosted.org/packages/a3/73/efb6957212eb8c8dfff26c7c2c6ddf745aa5990a3b722cff17c8feaa66fc/temporalio-1.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:9286cb84c1e078b2bcc6e8c6bd0be878d8ed395be991ac0d7cff555e3a82ac0b", size = 12440644, upload-time = "2025-06-20T19:57:25.175Z" }, -] [[package]] name = "tenacity" From 5e858d310e0034b37707d3febc5cb343dc82e65c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 2025 23:20:46 +0000 Subject: [PATCH 11/11] Polish --- .../pydantic_ai/ext/temporal/__init__.py | 70 ++++++++++ .../pydantic_ai/ext/temporal/_agent.py | 120 +++++++++++++++++ .../{ => ext}/temporal/_function_toolset.py | 49 +++++-- .../pydantic_ai/ext/temporal/_logfire.py | 34 +++++ .../{ => ext}/temporal/_mcp_server.py | 28 +++- .../pydantic_ai/{ => ext}/temporal/_model.py | 49 +++++-- .../pydantic_ai/ext/temporal/_run_context.py | 68 ++++++++++ .../{ => ext}/temporal/_settings.py | 22 ++-- .../pydantic_ai/ext/temporal/_toolset.py | 41 ++++++ .../pydantic_ai/temporal/__init__.py | 124 ------------------ .../pydantic_ai/temporal/_run_context.py | 41 ------ .../pydantic_ai/temporal/_toolset.py | 26 ---- temporal.py | 67 +++++----- 13 files changed, 472 insertions(+), 267 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_function_toolset.py (51%) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_mcp_server.py (72%) rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_model.py (82%) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_settings.py (71%) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/__init__.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_run_context.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_toolset.py diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py new file mode 100644 index 000000000..f1ec642a5 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import replace +from typing import Any, Callable + +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + +from pydantic_ai.agent import Agent + +from ._agent import temporalize_agent, untemporalize_agent +from ._logfire import LogfirePlugin +from ._run_context import TemporalRunContext +from ._settings import TemporalSettings + +__all__ = [ + 'TemporalSettings', + 'TemporalRunContext', + 'PydanticAIPlugin', + 'LogfirePlugin', + 'AgentPlugin', + 'temporalize_agent', + 'untemporalize_agent', +] + + +class PydanticAIPlugin(ClientPlugin, WorkerPlugin): + """Temporal client and worker plugin for Pydantic AI.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config['data_converter'] = pydantic_data_converter + return super().configure_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] + if isinstance(runner, SandboxedWorkflowRunner): + config['workflow_runner'] = replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + 'pydantic_ai', + 'logfire', + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', + 'pandas', + ), + ) + return super().configure_worker(config) + + +class AgentPlugin(WorkerPlugin): + """Temporal worker plugin for a specific Pydantic AI agent.""" + + def __init__(self, agent: Agent[Any, Any]): + self.agent = agent + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + agent_activities = getattr(self.agent, '__temporal_activities', None) + if agent_activities is None: + raise ValueError( + 'The agent has not been temporalized yet, call `temporalize_agent(agent)` (or `with temporalized_agent(agent): ...`) first.' + ) + + activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] + config['activities'] = [*activities, *agent_activities] + return super().configure_worker(config) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py new file mode 100644 index 000000000..f3717b6da --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any, Callable + +from pydantic_ai.agent import Agent +from pydantic_ai.models import Model +from pydantic_ai.toolsets.abstract import AbstractToolset + +from ._model import temporalize_model, untemporalize_model +from ._settings import TemporalSettings +from ._toolset import temporalize_toolset, untemporalize_toolset + + +def temporalize_agent( + agent: Agent[Any, Any], + settings: TemporalSettings | None = None, + toolset_settings: dict[str, TemporalSettings] = {}, + tool_settings: dict[str, dict[str, TemporalSettings]] = {}, + temporalize_toolset_func: Callable[ + [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings]], list[Callable[..., Any]] + ] = temporalize_toolset, +) -> list[Callable[..., Any]]: + """Temporalize an agent. + + Args: + agent: The agent to temporalize. + settings: The temporal settings to use. + toolset_settings: The temporal settings to use for specific toolsets identified by ID. + tool_settings: The temporal settings to use for specific tools identified by toolset ID and tool name. + temporalize_toolset_func: The function to use to temporalize the toolsets. + """ + if existing_activities := getattr(agent, '__temporal_activities', None): + return existing_activities + + settings = settings or TemporalSettings() + + activities: list[Callable[..., Any]] = [] + if isinstance(agent.model, Model): + activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] + + def temporalize_toolset(toolset: AbstractToolset) -> None: + id = toolset.id + if not id: + raise ValueError( + "A toolset needs to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." + ) + activities.extend( + temporalize_toolset_func(toolset, settings.merge(toolset_settings.get(id)), tool_settings.get(id, {})) + ) + + agent.toolset.apply(temporalize_toolset) + + original_iter = agent.iter + original_override = agent.override + setattr(agent, '__original_iter', original_iter) + setattr(agent, '__original_override', original_override) + + def iter(*args: Any, **kwargs: Any) -> Any: + if kwargs.get('model') is not None: + raise ValueError( + 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + if kwargs.get('toolsets') is not None: + raise ValueError( + 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + if kwargs.get('event_stream_handler') is not None: + raise ValueError( + 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + return original_iter(*args, **kwargs) + + def override(*args: Any, **kwargs: Any) -> Any: + if kwargs.get('model') is not None: + raise ValueError('Model cannot be overridden when using Temporal, it must be set at agent creation time.') + if kwargs.get('toolsets') is not None: + raise ValueError( + 'Toolsets cannot be overridden when using Temporal, it must be set at agent creation time.' + ) + return original_override(*args, **kwargs) + + agent.iter = iter + agent.override = override + + setattr(agent, '__temporal_activities', activities) + return activities + + +def untemporalize_agent(agent: Agent[Any, Any]) -> None: + """Untemporalize an agent. + + Args: + agent: The agent to untemporalize. + """ + if not hasattr(agent, '__temporal_activities'): + return + + if isinstance(agent.model, Model): + untemporalize_model(agent.model) + + agent.toolset.apply(untemporalize_toolset) + + agent.iter = getattr(agent, '__original_iter') + agent.override = getattr(agent, '__original_override') + delattr(agent, '__original_iter') + delattr(agent, '__original_override') + + delattr(agent, '__temporal_activities') + + +@contextmanager +def temporalized_agent(agent: Agent[Any, Any], settings: TemporalSettings | None = None) -> Generator[None, None, None]: + temporalize_agent(agent, settings) + try: + yield + finally: + untemporalize_agent(agent) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py similarity index 51% rename from pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index 2c371a1ff..054420453 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -6,10 +6,10 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow -from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai._run_context import RunContext +from pydantic_ai.toolsets import FunctionToolset, ToolsetTool -from .._run_context import RunContext -from ..toolsets import ToolsetTool +from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -24,40 +24,50 @@ class _CallToolParams: def temporalize_function_toolset( toolset: FunctionToolset, settings: TemporalSettings | None = None, + tool_settings: dict[str, TemporalSettings] = {}, ) -> list[Callable[..., Any]]: """Temporalize a function toolset. Args: toolset: The function toolset to temporalize. settings: The temporal settings to use. + tool_settings: The temporal settings to use for specific tools identified by tool name. """ if activities := getattr(toolset, '__temporal_activities', None): return activities id = toolset.id - if not id: - raise ValueError( - "A function toolset needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the toolset's activities within the workflow." - ) + assert id is not None settings = settings or TemporalSettings() original_call_tool = toolset.call_tool + setattr(toolset, '__original_call_tool', original_call_tool) @activity.defn(name=f'function_toolset__{id}__call_tool') async def call_tool_activity(params: _CallToolParams) -> Any: name = params.name - ctx = settings.for_tool(id, name).deserialize_run_context(params.serialized_run_context) - tool = (await toolset.get_tools(ctx))[name] + settings_for_tool = settings.merge(tool_settings.get(name)) + ctx = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings_for_tool.deserialize_run_context + ) + try: + tool = (await toolset.get_tools(ctx))[name] + except KeyError as e: + raise ValueError( + f'Tool {name!r} not found in toolset {toolset.id!r}. ' + 'Removing or renaming tools during an agent run is not supported with Temporal.' + ) from e + return await original_call_tool(name, params.tool_args, ctx, tool) async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: - tool_settings = settings.for_tool(id, name) - serialized_run_context = tool_settings.serialize_run_context(ctx) + settings_for_tool = settings.merge(tool_settings.get(name)) + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), - **tool_settings.execute_activity_options, + **settings_for_tool.execute_activity_options, ) toolset.call_tool = call_tool @@ -65,3 +75,18 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: activities = [call_tool_activity] setattr(toolset, '__temporal_activities', activities) return activities + + +def untemporalize_function_toolset(toolset: FunctionToolset) -> None: + """Untemporalize a function toolset. + + Args: + toolset: The function toolset to untemporalize. + """ + if not hasattr(toolset, '__temporal_activities'): + return + + toolset.call_tool = getattr(toolset, '__original_call_tool') + delattr(toolset, '__original_call_tool') + + delattr(toolset, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py new file mode 100644 index 000000000..bb307b990 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Callable + +from opentelemetry.trace import get_tracer +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.service import ConnectConfig, ServiceClient + + +def _default_setup_logfire(): + import logfire + + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + + +class LogfirePlugin(ClientPlugin): + """Temporal client plugin for Logfire.""" + + def __init__(self, setup_logfire: Callable[[], None] = _default_setup_logfire): + self.setup_logfire = setup_logfire + + def configure_client(self, config: ClientConfig) -> ClientConfig: + interceptors = config.get('interceptors', []) + config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporal'))] + return super().configure_client(config) + + async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: + self.setup_logfire() + + config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + return await super().connect_service_client(config) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py similarity index 72% rename from pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 6a93e0c34..edc021a90 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -23,26 +23,27 @@ class _CallToolParams: def temporalize_mcp_server( server: MCPServer, settings: TemporalSettings | None = None, + tool_settings: dict[str, TemporalSettings] = {}, ) -> list[Callable[..., Any]]: """Temporalize an MCP server. Args: server: The MCP server to temporalize. settings: The temporal settings to use. + tool_settings: The temporal settings to use for each tool. """ if activities := getattr(server, '__temporal_activities', None): return activities id = server.id - if not id: - raise ValueError( - "An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the server's activities within the workflow." - ) + assert id is not None settings = settings or TemporalSettings() original_list_tools = server.list_tools original_direct_call_tool = server.direct_call_tool + setattr(server, '__original_list_tools', original_list_tools) + setattr(server, '__original_direct_call_tool', original_direct_call_tool) @activity.defn(name=f'mcp_server__{id}__list_tools') async def list_tools_activity() -> list[mcp_types.Tool]: @@ -66,7 +67,7 @@ async def direct_call_tool( return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), - **settings.for_tool(id, name).execute_activity_options, + **tool_settings.get(name, settings).execute_activity_options, ) server.list_tools = list_tools @@ -75,3 +76,20 @@ async def direct_call_tool( activities = [list_tools_activity, call_tool_activity] setattr(server, '__temporal_activities', activities) return activities + + +def untemporalize_mcp_server(server: MCPServer) -> None: + """Untemporalize an MCP server. + + Args: + server: The MCP server to untemporalize. + """ + if not hasattr(server, '__temporal_activities'): + return + + server.list_tools = getattr(server, '__original_list_tools') + server.direct_call_tool = getattr(server, '__original_direct_call_tool') + delattr(server, '__original_list_tools') + delattr(server, '__original_direct_call_tool') + + delattr(server, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py similarity index 82% rename from pydantic_ai_slim/pydantic_ai/temporal/_model.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index cbadf7bd1..9f2240e58 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -9,10 +9,10 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow -from .._run_context import RunContext -from ..agent import EventStreamHandler -from ..exceptions import UserError -from ..messages import ( +from pydantic_ai._run_context import RunContext +from pydantic_ai.agent import EventStreamHandler +from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ( FinalResultEvent, ModelMessage, ModelResponse, @@ -21,9 +21,11 @@ TextPart, ToolCallPart, ) -from ..models import Model, ModelRequestParameters, StreamedResponse -from ..settings import ModelSettings -from ..usage import Usage +from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse +from pydantic_ai.settings import ModelSettings +from pydantic_ai.usage import Usage + +from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -84,13 +86,18 @@ def temporalize_model( # noqa: C901 original_request = model.request original_request_stream = model.request_stream + setattr(model, '__original_request', original_request) + setattr(model, '__original_request_stream', original_request_stream) + @activity.defn(name='model_request') async def request_activity(params: _RequestParams) -> ModelResponse: return await original_request(params.messages, params.model_settings, params.model_request_parameters) @activity.defn(name='model_request_stream') async def request_stream_activity(params: _RequestParams) -> ModelResponse: - run_context = settings.deserialize_run_context(params.serialized_run_context) + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings.deserialize_run_context + ) async with original_request_stream( params.messages, params.model_settings, params.model_request_parameters, run_context ) as streamed_response: @@ -102,6 +109,7 @@ async def request_stream_activity(params: _RequestParams) -> ModelResponse: ] } + # Keep in sync with `AgentStream.__aiter__` async def aiter(): def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" @@ -119,9 +127,9 @@ def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | N elif tool_def.kind == 'deferred': return FinalResultEvent(tool_name=None, tool_call_id=None) - # TODO: usage_checking_stream = _get_usage_checking_stream_response( - # self._raw_stream_response, self._usage_limits, self.usage - # ) + # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, + # but we don't have access to the `_usage_limits`. + async for event in streamed_response: yield event if (final_result_event := _get_final_result_event(event)) is not None: @@ -168,7 +176,7 @@ async def request_stream( if run_context is None: raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') - serialized_run_context = settings.serialize_run_context(run_context) + serialized_run_context = TemporalRunContext.serialize_run_context(run_context, settings.serialize_run_context) response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=request_stream_activity, arg=_RequestParams( @@ -187,3 +195,20 @@ async def request_stream( activities = [request_activity, request_stream_activity] setattr(model, '__temporal_activities', activities) return activities + + +def untemporalize_model(model: Model) -> None: + """Untemporalize a model. + + Args: + model: The model to untemporalize. + """ + if not hasattr(model, '__temporal_activities'): + return + + model.request = getattr(model, '__original_request') + model.request_stream = getattr(model, '__original_request_stream') + + delattr(model, '__original_request') + delattr(model, '__original_request_stream') + delattr(model, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py new file mode 100644 index 000000000..c1e896508 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai._run_context import RunContext + + +class TemporalRunContext(RunContext[Any]): + def __init__(self, **kwargs: Any): + self.__dict__ = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. ' + 'To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` with a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + ) + else: + raise e + + @classmethod + def serialize_run_context( + cls, + ctx: RunContext[Any], + extra_serializer: Callable[[RunContext[Any]], dict[str, Any]] | None = None, + ) -> dict[str, Any]: + return { + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + **(extra_serializer(ctx) if extra_serializer else {}), + } + + @classmethod + def deserialize_run_context( + cls, ctx: dict[str, Any], extra_deserializer: Callable[[dict[str, Any]], dict[str, Any]] | None = None + ) -> RunContext[Any]: + return cls( + retries=ctx['retries'], + tool_call_id=ctx['tool_call_id'], + tool_name=ctx['tool_name'], + retry=ctx['retry'], + run_step=ctx['run_step'], + **(extra_deserializer(ctx) if extra_deserializer else {}), + ) + + +def serialize_run_context_deps(ctx: RunContext[Any]) -> dict[str, Any]: + if not isinstance(ctx.deps, dict): + raise ValueError( + 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' + 'To use a different type, pass a `TemporalSettings` object to `temporalize_agent` with custom `serialize_run_context` and `deserialize_run_context` functions.' + ) + return {'deps': ctx.deps} # pyright: ignore[reportUnknownMemberType] + + +def deserialize_run_context_deps(ctx: dict[str, Any]) -> dict[str, Any]: + return {'deps': ctx['deps']} diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_settings.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py similarity index 71% rename from pydantic_ai_slim/pydantic_ai/temporal/_settings.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py index 14c9d595e..4340f1863 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_settings.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, fields, replace from datetime import timedelta from typing import Any, Callable @@ -9,7 +9,7 @@ from pydantic_ai._run_context import RunContext -from ._run_context import TemporalRunContext +from ._run_context import deserialize_run_context_deps, serialize_run_context_deps @dataclass @@ -29,16 +29,8 @@ class TemporalSettings: summary: str | None = None priority: Priority = Priority.default - # Pydantic AI specific - tool_settings: dict[str, dict[str, TemporalSettings]] | None = None - - def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: - if self.tool_settings is None: - return self - return self.tool_settings.get(toolset_id, {}).get(tool_id, self) - - serialize_run_context: Callable[[RunContext], Any] = TemporalRunContext.serialize_run_context - deserialize_run_context: Callable[[dict[str, Any]], RunContext] = TemporalRunContext.deserialize_run_context + serialize_run_context: Callable[[RunContext], dict[str, Any]] = serialize_run_context_deps + deserialize_run_context: Callable[[dict[str, Any]], dict[str, Any]] = deserialize_run_context_deps @property def execute_activity_options(self) -> dict[str, Any]: @@ -55,3 +47,9 @@ def execute_activity_options(self) -> dict[str, Any]: 'summary': self.summary, 'priority': self.priority, } + + def merge(self, other: TemporalSettings | None) -> TemporalSettings: + """Merge non-default values from another TemporalSettings instance into this one, returning a new instance.""" + if not other: + return self + return replace(self, **{f.name: value for f in fields(other) if (value := getattr(other, f.name)) != f.default}) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py new file mode 100644 index 000000000..bf8b8281d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from ._function_toolset import temporalize_function_toolset, untemporalize_function_toolset +from ._mcp_server import temporalize_mcp_server, untemporalize_mcp_server +from ._settings import TemporalSettings + + +def temporalize_toolset( + toolset: AbstractToolset, settings: TemporalSettings | None, tool_settings: dict[str, TemporalSettings] = {} +) -> list[Callable[..., Any]]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + settings: The temporal settings to use. + tool_settings: The temporal settings to use for specific tools identified by tool name. + """ + if isinstance(toolset, FunctionToolset): + return temporalize_function_toolset(toolset, settings, tool_settings) + elif isinstance(toolset, MCPServer): + return temporalize_mcp_server(toolset, settings, tool_settings) + else: + return [] + + +def untemporalize_toolset(toolset: AbstractToolset) -> None: + """Untemporalize a toolset. + + Args: + toolset: The toolset to untemporalize. + """ + if isinstance(toolset, FunctionToolset): + untemporalize_function_toolset(toolset) + elif isinstance(toolset, MCPServer): + untemporalize_mcp_server(toolset) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py deleted file mode 100644 index 32ea06775..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import replace -from typing import Any, Callable - -import logfire # TODO: Not always available -from opentelemetry import trace # TODO: Not always available -from temporalio.client import ClientConfig, Plugin as ClientPlugin -from temporalio.contrib.opentelemetry import TracingInterceptor -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig -from temporalio.service import ConnectConfig, ServiceClient -from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig -from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner - -from pydantic_ai.agent import Agent -from pydantic_ai.toolsets.abstract import AbstractToolset - -from ..models import Model -from ._model import temporalize_model -from ._run_context import TemporalRunContext -from ._settings import TemporalSettings -from ._toolset import temporalize_toolset - -__all__ = [ - 'TemporalSettings', - 'TemporalRunContext', - 'PydanticAIPlugin', - 'LogfirePlugin', - 'AgentPlugin', -] - - -class PydanticAIPlugin(ClientPlugin, WorkerPlugin): - """Temporal client and worker plugin for Pydantic AI.""" - - def configure_client(self, config: ClientConfig) -> ClientConfig: - config['data_converter'] = pydantic_data_converter - return super().configure_client(config) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] - if isinstance(runner, SandboxedWorkflowRunner): - config['workflow_runner'] = replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules( - 'pydantic_ai', - 'logfire', - # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize - 'attrs', - # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', - 'pandas', - ), - ) - return super().configure_worker(config) - - -class LogfirePlugin(ClientPlugin): - """Temporal client plugin for Logfire.""" - - def configure_client(self, config: ClientConfig) -> ClientConfig: - config['interceptors'] = [TracingInterceptor(trace.get_tracer('temporal'))] - return super().configure_client(config) - - async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: - # TODO: Do we need this here? - logfire.configure(console=False) - logfire.instrument_pydantic_ai() - logfire.instrument_httpx(capture_all=True) - - config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) - return await super().connect_service_client(config) - - -class AgentPlugin(WorkerPlugin): - """Temporal worker plugin for a specific Pydantic AI agent.""" - - def __init__(self, agent: Agent[Any, Any], settings: TemporalSettings | None = None): - self.activities = temporalize_agent(agent, settings) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] - config['activities'] = [*activities, *self.activities] - return super().configure_worker(config) - - -def temporalize_agent( - agent: Agent[Any, Any], - settings: TemporalSettings | None = None, - temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] - ] = temporalize_toolset, -) -> list[Callable[..., Any]]: - """Temporalize an agent. - - Args: - agent: The agent to temporalize. - settings: The temporal settings to use. - temporalize_toolset_func: The function to use to temporalize the toolsets. - """ - if existing_activities := getattr(agent, '__temporal_activities', None): - return existing_activities - - settings = settings or TemporalSettings() - - # TODO: Doesn't consider model/toolsets passed at iter time, raise an error if that happens. - # Similarly, passing event_stream_handler at iter time should raise an error. - - activities: list[Callable[..., Any]] = [] - if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] - - def temporalize_toolset(toolset: AbstractToolset) -> None: - activities.extend(temporalize_toolset_func(toolset, settings)) - - agent.toolset.apply(temporalize_toolset) - - setattr(agent, '__temporal_activities', activities) - return activities - - -# TODO: untemporalize_agent diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py deleted file mode 100644 index 8bc7029e6..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from pydantic_ai._run_context import AgentDepsT, RunContext - - -class TemporalRunContext(RunContext[AgentDepsT]): - def __init__(self, **kwargs: Any): - self.__dict__ = kwargs - setattr( - self, - '__dataclass_fields__', - {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, - ) - - def __getattribute__(self, name: str) -> Any: - try: - return super().__getattribute__(name) - except AttributeError as e: - if name in RunContext.__dataclass_fields__: - raise AttributeError( - f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' - ) - else: - raise e - - @classmethod - def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: - return { - 'deps': ctx.deps, - 'retries': ctx.retries, - 'tool_call_id': ctx.tool_call_id, - 'tool_name': ctx.tool_name, - 'retry': ctx.retry, - 'run_step': ctx.run_step, - } - - @classmethod - def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: - return cls(**ctx) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py deleted file mode 100644 index 289d90071..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable - -from pydantic_ai.mcp import MCPServer -from pydantic_ai.toolsets.abstract import AbstractToolset -from pydantic_ai.toolsets.function import FunctionToolset - -from ._function_toolset import temporalize_function_toolset -from ._mcp_server import temporalize_mcp_server -from ._settings import TemporalSettings - - -def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: - """Temporalize a toolset. - - Args: - toolset: The toolset to temporalize. - settings: The temporal settings to use. - """ - if isinstance(toolset, FunctionToolset): - return temporalize_function_toolset(toolset, settings) - elif isinstance(toolset, MCPServer): - return temporalize_mcp_server(toolset, settings) - else: - return [] diff --git a/temporal.py b/temporal.py index 2b8a4a319..677e831ba 100644 --- a/temporal.py +++ b/temporal.py @@ -10,38 +10,22 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext -from pydantic_ai.mcp import MCPServerStdio -from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent -from pydantic_ai.temporal import ( +from pydantic_ai.ext.temporal import ( AgentPlugin, LogfirePlugin, PydanticAIPlugin, TemporalSettings, + temporalize_agent, ) -from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.mcp import MCPServerStdio +from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent class Deps(TypedDict): country: str -def get_country(ctx: RunContext[Deps]) -> str: - return ctx.deps['country'] - - -toolset = FunctionToolset[Deps](tools=[get_country], id='country') -mcp_server = MCPServerStdio( - 'python', - ['-m', 'tests.mcp_server'], - timeout=20, - id='test', -) - - -async def event_stream_handler( - ctx: RunContext[Deps], - stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], -): +async def event_stream_handler(ctx: RunContext[Deps], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]): logfire.info(f'{ctx.run_step=}') async for event in stream: logfire.info(f'{event=}') @@ -49,24 +33,33 @@ async def event_stream_handler( agent = Agent( 'openai:gpt-4o', - toolsets=[toolset, mcp_server], - event_stream_handler=event_stream_handler, deps_type=Deps, + toolsets=[MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='test')], + event_stream_handler=event_stream_handler, ) -temporal_settings = TemporalSettings( - start_to_close_timeout=timedelta(seconds=60), - tool_settings={ # TODO: Allow default temporal settings to be set for all activities in a toolset + +@agent.tool +def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + +# This needs to be called in the same scope where the `agent` is bound to the workflow, +# as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. +temporalize_agent( + agent, + settings=TemporalSettings(start_to_close_timeout=timedelta(seconds=60)), + toolset_settings={ + 'country': TemporalSettings(start_to_close_timeout=timedelta(seconds=120)), + }, + tool_settings={ 'country': { - 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=180)), }, }, ) -TASK_QUEUE = 'pydantic-ai-agent-task-queue' - - @workflow.defn class MyAgentWorkflow: @workflow.run @@ -75,22 +68,26 @@ async def run(self, prompt: str, deps: Deps) -> str: return result.output -# TODO: For some reason, when I put this (specifically the temporalize_agent call) inside `async def main()`, -# we get tons of errors. -plugin = AgentPlugin(agent, temporal_settings) +TASK_QUEUE = 'pydantic-ai-agent-task-queue' + + +def setup_logfire(): + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) async def main(): client = await Client.connect( 'localhost:7233', - plugins=[PydanticAIPlugin(), LogfirePlugin()], + plugins=[PydanticAIPlugin(), LogfirePlugin(setup_logfire)], ) async with Worker( client, task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], - plugins=[plugin], + plugins=[AgentPlugin(agent)], ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run,