Skip to content

Commit 8940d51

Browse files
committed
Support client deriving types from operation impl or interface
1 parent 963de97 commit 8940d51

File tree

4 files changed

+63
-25
lines changed

4 files changed

+63
-25
lines changed

temporalio/nexus/handler.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
TypeVar,
2121
)
2222

23-
import nexusrpc.handler
2423
from hyperlinked import hyperlinked, print
25-
from nexusrpc.handler import _ServiceImpl
2624
from typing_extensions import Concatenate, overload
2725

26+
import nexusrpc.handler
2827
import temporalio.api.common.v1
2928
import temporalio.api.enums.v1
3029
import temporalio.common
30+
from nexusrpc.handler import _ServiceImpl
3131
from temporalio.client import (
3232
Client,
3333
WorkflowHandle,
@@ -298,7 +298,7 @@ def workflow_operation(
298298
def factory(service: S) -> WorkflowOperation[I, O]:
299299
return WorkflowOperation(service, start_method)
300300

301-
factory.__nexus_operation__ = nexusrpc.handler._NexusOperationDefinition(
301+
factory.__nexus_operation__ = nexusrpc.handler.NexusOperationDefinition(
302302
name=start_method.__name__
303303
)
304304
return factory

temporalio/worker/_interceptor.py

+39-17
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import concurrent.futures
6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
77
from datetime import timedelta
88
from typing import (
99
Any,
@@ -22,6 +22,8 @@
2222
)
2323

2424
import nexusrpc
25+
import nexusrpc.handler
26+
import nexusrpc.interface
2527

2628
import temporalio.activity
2729
import temporalio.api.common.v1
@@ -289,6 +291,7 @@ class StartChildWorkflowInput:
289291
ret_type: Optional[Type]
290292

291293

294+
# TODO(dan): Put these in a better location. Type variance?
292295
I = TypeVar("I")
293296
O = TypeVar("O")
294297

@@ -299,34 +302,53 @@ class StartNexusOperationInput(Generic[I, O]):
299302

300303
endpoint: str
301304
service: str
302-
operation: Union[nexusrpc.Operation[I, O], str]
305+
operation: Union[
306+
nexusrpc.interface.Operation[I, O],
307+
Callable[[Any], nexusrpc.handler.Operation[I, O]],
308+
str,
309+
]
303310
input: I
304311
schedule_to_close_timeout: Optional[timedelta]
305312
headers: Optional[Mapping[str, str]]
306313

314+
# Cached properties, initialized in __post_init__
315+
_operation_name: str = field(init=False, repr=False)
316+
_input_type: Optional[Type[I]] = field(init=False, repr=False)
317+
_output_type: Optional[Type[O]] = field(init=False, repr=False)
318+
319+
def __post_init__(self) -> None:
320+
if isinstance(self.operation, str):
321+
self._operation_name = self.operation
322+
self._input_type = None
323+
self._output_type = None
324+
elif isinstance(self.operation, nexusrpc.interface.Operation):
325+
self._operation_name = self.operation.name
326+
self._input_type = self.operation.input_type
327+
self._output_type = self.operation.output_type
328+
elif isinstance(self.operation, Callable):
329+
defn = getattr(self.operation, "__nexus_operation__", None)
330+
if isinstance(defn, nexusrpc.handler.NexusOperationDefinition):
331+
self._operation_name = defn.name
332+
self._input_type = defn.input_type
333+
self._output_type = defn.output_type
334+
else:
335+
raise ValueError(
336+
f"Operation callable is not a Nexus operation: {self.operation}"
337+
)
338+
else:
339+
raise ValueError(f"Operation is not a Nexus operation: {self.operation}")
340+
307341
@property
308342
def operation_name(self) -> str:
309-
return (
310-
self.operation.name
311-
if isinstance(self.operation, nexusrpc.Operation)
312-
else self.operation
313-
)
343+
return self._operation_name
314344

315345
@property
316346
def input_type(self) -> Optional[Type[I]]:
317-
return (
318-
self.operation.input_type
319-
if isinstance(self.operation, nexusrpc.Operation)
320-
else None
321-
)
347+
return self._input_type
322348

323349
@property
324350
def output_type(self) -> Optional[Type[O]]:
325-
return (
326-
self.operation.output_type
327-
if isinstance(self.operation, nexusrpc.Operation)
328-
else None
329-
)
351+
return self._output_type
330352

331353

332354
@dataclass

temporalio/worker/_workflow_instance.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,11 @@ async def workflow_start_nexus_operation(
15591559
self,
15601560
endpoint: str,
15611561
service: str,
1562-
operation: Union[nexusrpc.Operation[I, O], str],
1562+
operation: Union[
1563+
nexusrpc.interface.Operation[I, O],
1564+
Callable[[Any], nexusrpc.handler.Operation[I, O]],
1565+
str,
1566+
],
15631567
input: Any,
15641568
schedule_to_close_timeout: Optional[timedelta] = None,
15651569
headers: Optional[Mapping[str, str]] = None,

temporalio/workflow.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,11 @@ async def workflow_start_nexus_operation(
840840
self,
841841
endpoint: str,
842842
service: str,
843-
operation: Union[nexusrpc.Operation[I, O], str],
843+
operation: Union[
844+
nexusrpc.interface.Operation[I, O],
845+
Callable[[Any], nexusrpc.handler.Operation[I, O]],
846+
str,
847+
],
844848
input: Any,
845849
schedule_to_close_timeout: Optional[timedelta] = None,
846850
headers: Optional[Mapping[str, str]] = None,
@@ -4385,7 +4389,11 @@ def operation_token(self) -> Optional[str]:
43854389
async def start_nexus_operation(
43864390
endpoint: str,
43874391
service: str,
4388-
operation: Union[nexusrpc.Operation[I, O], str],
4392+
operation: Union[
4393+
nexusrpc.interface.Operation[I, O],
4394+
Callable[[Any], nexusrpc.handler.Operation[I, O]],
4395+
str,
4396+
],
43894397
input: Any,
43904398
*,
43914399
schedule_to_close_timeout: Optional[timedelta] = None,
@@ -5158,7 +5166,11 @@ def __init__(
51585166
# TODO(dan): overloads: no-input, operation name, ret type
51595167
async def start_operation(
51605168
self,
5161-
operation: Union[nexusrpc.interface.Operation[I, O], str],
5169+
operation: Union[
5170+
nexusrpc.interface.Operation[I, O],
5171+
Callable[[Any], nexusrpc.handler.Operation[I, O]],
5172+
str,
5173+
],
51625174
input: I,
51635175
*,
51645176
schedule_to_close_timeout: Optional[timedelta] = None,
@@ -5180,7 +5192,7 @@ async def execute_operation(
51805192
self,
51815193
operation: Union[
51825194
nexusrpc.interface.Operation[I, O],
5183-
Callable[..., nexusrpc.handler.Operation[I, O]],
5195+
Callable[[Any], nexusrpc.handler.Operation[I, O]],
51845196
str,
51855197
],
51865198
input: I,

0 commit comments

Comments
 (0)