Skip to content

Commit 862e4f9

Browse files
committed
Clean up test
1 parent 1474b11 commit 862e4f9

File tree

2 files changed

+127
-80
lines changed

2 files changed

+127
-80
lines changed

temporalio/worker/_workflow_instance.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,20 @@ async def run_child() -> Any:
18211821
async def _outbound_start_nexus_operation(
18221822
self, input: StartNexusOperationInput
18231823
) -> _NexusOperationHandle[Any]:
1824+
# A Nexus operation handle has two futures: self._start_fut is resolved as a
1825+
# result of the Nexus operation starting (activation job:
1826+
# resolve_nexus_operation_start), and self._result_fut is resolved as a result of
1827+
# the Nexus operation completing (activation job: resolve_nexus_operation). The
1828+
# handle itself corresponds to an asyncio.Task which waits on self._result_fut,
1829+
# handling CancelledError by emitting a RequestCancelNexusOperation command. We do
1830+
# not return the handle until we receive resolve_nexus_operation_start, like
1831+
# ChildWorkflowHandle and unlike ActivityHandle. Note that a Nexus operation may
1832+
# complete synchronously (in which case both jobs will be sent in the same
1833+
# activation, and start will be resolved without an operation token), or
1834+
# asynchronously (in which case they may be sent in separate activations,
1835+
# and start will be resolved with an operation token). See comments in
1836+
# tests/worker/test_nexus.py for worked examples of the evolution of the resulting
1837+
# handle state machine in the sync and async Nexus response cases.
18241838
handle: _NexusOperationHandle
18251839

18261840
async def operation_handle_fn() -> Any:
@@ -2977,9 +2991,17 @@ async def result(self) -> O:
29772991
def __await__(self) -> Generator[Any, Any, O]:
29782992
return self._task.__await__()
29792993

2994+
def __repr__(self) -> str:
2995+
return (
2996+
f"{self._start_fut} "
2997+
f"{self._result_fut} "
2998+
f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})"
2999+
)
3000+
29803001
def cancel(self) -> bool:
29813002
# TODO(dan): what do we do when the start result has been delivered and we know
2982-
# this cannot be canceled (e.g. because it was a sync result, or because it failed.)
3003+
# this cannot be canceled (e.g. because it was a sync result, or because it
3004+
# failed.)
29833005
return self._task.cancel()
29843006

29853007
def _resolve_start_success(self, operation_token: Optional[str]) -> None:

tests/worker/test_nexus.py

+104-79
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import asyncio
21
import uuid
32
from dataclasses import dataclass
43
from datetime import timedelta
5-
from typing import Union, cast
4+
from typing import Union
65

76
import nexusrpc
87
import nexusrpc.handler
@@ -34,6 +33,7 @@ class SyncResponse:
3433
@dataclass
3534
class AsyncResponse:
3635
operation_workflow_id: str
36+
block_forever_waiting_for_cancellation: bool
3737

3838

3939
# The ordering in this union is critical since the data converter matches eagerly,
@@ -47,7 +47,6 @@ class AsyncResponse:
4747
@dataclass
4848
class MyInput:
4949
response_type: ResponseType
50-
block_forever_waiting_for_cancellation: bool
5150

5251

5352
@dataclass
@@ -86,7 +85,7 @@ async def start(
8685
elif isinstance(input.response_type, AsyncResponse):
8786
return await temporalio.nexus.handler.start_workflow(
8887
MyHandlerWorkflow.run,
89-
input.block_forever_waiting_for_cancellation,
88+
input.response_type.block_forever_waiting_for_cancellation,
9089
id=input.response_type.operation_workflow_id,
9190
options=options,
9291
)
@@ -126,7 +125,7 @@ class MyCallerWorkflow:
126125
def __init__(
127126
self,
128127
response_type: ResponseType,
129-
should_cancel: bool,
128+
request_cancel: bool,
130129
task_queue: str,
131130
) -> None:
132131
self.nexus_service = workflow.NexusClient(
@@ -141,39 +140,33 @@ def __init__(
141140
async def run(
142141
self,
143142
response_type: ResponseType,
144-
should_cancel: bool,
143+
request_cancel: bool,
145144
task_queue: str,
146145
) -> str:
147146
op_handle = await self.nexus_service.start_operation(
148147
MyService.my_operation,
149-
MyInput(
150-
response_type,
151-
block_forever_waiting_for_cancellation=should_cancel,
152-
),
148+
MyInput(response_type),
153149
)
150+
print(f"🌈 {'after await start':<24}: {op_handle}")
154151
self._nexus_operation_started = True
155-
task = cast(asyncio.Task, getattr(op_handle, "_task"))
156152
if isinstance(response_type, SyncResponse):
157153
assert op_handle.operation_token is None
158-
# TODO(dan): I expected task to be done at this point
159-
# assert task.done()
160-
# assert not task.exception()
161-
if should_cancel:
162-
# TODO(dan): why does this assert pass (same Q as above re task.done())
163-
assert op_handle.cancel()
164-
elif isinstance(response_type, AsyncResponse):
154+
else:
165155
assert op_handle.operation_token
166-
assert not task.done()
167-
# Allow the test to control when we proceed so that it can make initial
168-
# assertions.
156+
# Allow the test to make assertions before signalling us to proceed.
169157
await workflow.wait_condition(lambda: self._proceed)
158+
print(f"🌈 {'after await proceed':<24}: {op_handle}")
170159

171-
if should_cancel:
172-
# We cannot assert that cancel() returns True because it's possible that a
173-
# resolve_nexus_operation job has already come in.
174-
op_handle.cancel()
160+
if request_cancel:
161+
# Even for SyncResponse, the op_handle future is not done at this point; that
162+
# transition doesn't happen until the handle is awaited.
163+
print(f"🌈 {'before op_handle.cancel':<24}: {op_handle}")
164+
cancel_ret = op_handle.cancel()
165+
print(f"🌈 {'cancel returned':<24}: {cancel_ret}")
175166

167+
print(f"🌈 {'before await op_handle':<24}: {op_handle}")
176168
result = await op_handle
169+
print(f"🌈 {'after await op_handle':<24}: {op_handle}")
177170
return result.val
178171

179172
@workflow.update
@@ -190,32 +183,64 @@ def proceed(self) -> None:
190183
#
191184

192185

193-
# TODO(dan): cross-namespace tests
194-
# TODO(dan): nexus endpoint pytest fixture?
195-
@pytest.mark.parametrize("should_attempt_cancel", [False, True])
196-
async def test_sync_response(client: Client, should_attempt_cancel: bool):
197-
task_queue = str(uuid.uuid4())
198-
async with Worker(
199-
client,
200-
nexus_services=[MyServiceImpl()],
201-
workflows=[MyCallerWorkflow, MyHandlerWorkflow],
202-
task_queue=task_queue,
203-
workflow_runner=UnsandboxedWorkflowRunner(),
204-
):
205-
await create_nexus_endpoint(task_queue, client)
206-
wf_handle = await client.start_workflow(
207-
MyCallerWorkflow.run,
208-
args=[SyncResponse(), should_attempt_cancel, task_queue],
209-
id=str(uuid.uuid4()),
210-
task_queue=task_queue,
211-
)
212-
# The response is synchronous, so the workflow's attempt to cancel the
213-
# NexusOperationHandle do not result in cancellation.
214-
result = await wf_handle.result()
215-
assert result == "sync response"
186+
# When request_cancel is True, the NexusOperationHandle in the workflow evolves
187+
# through the following states:
188+
# start_fut result_fut handle_task w/ fut_waiter (task._must_cancel)
189+
#
190+
# Case 1: Sync Nexus operation response w/ cancellation of NexusOperationHandle
191+
# -----------------------------------------------------------------------------
192+
# >>>>>>>>>>>> WFT 1
193+
# after await start : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False)
194+
# before op_handle.cancel : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False)
195+
# Future_8240[FINISHED].cancel() -> False # no state transition; fut_waiter is already finished
196+
# cancel returned : True
197+
# before await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (True)
198+
# --> Despite cancel having been requested, this await on the nexus op handle does not
199+
# raise CancelledError, because the task's underlying fut_waiter is already finished.
200+
# after await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[FINISHED] fut_waiter = None) (False)
201+
# <<<<<<<<<<<< END WFT 1
202+
#
203+
204+
# Case 2: Async Nexus operation response w/ cancellation of NexusOperationHandle
205+
# ------------------------------------------------------------------------------
206+
# >>>>>>>>>>>> WFT 1
207+
# after await start : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False)
208+
# >>>>>>>>>>>> WFT 2
209+
# >>>>>>>>>>>> WFT 3
210+
# after await proceed : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False)
211+
# before op_handle.cancel : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False)
212+
# Future_7952[PENDING].cancel() -> True # transition to cancelled state; fut_waiter was not finished
213+
# cancel returned : True
214+
# before await op_handle : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[CANCELLED]) (False)
215+
# --> This await on the nexus op handle raises CancelledError, because the task's underlying fut_waiter is cancelled.
216+
217+
# Thus in the sync case, although the caller workflow attempted to cancel the
218+
# NexusOperationHandle, this did not result in a CancelledError when the handle was
219+
# awaited, because both resolve_nexus_operation_start and resolve_nexus_operation jobs
220+
# were sent in the same activation and hence the task's fut_waiter was already finished.
221+
#
222+
# But in the async case, at the time that we cancel the NexusOperationHandle, only the
223+
# resolve_nexus_operation_start job had been sent; the result_fut was unresolved. Thus
224+
# when the handle was awaited, CancelledError was raised.
225+
226+
# To create output like that above, set the following __repr__s:
227+
# asyncio.Future:
228+
# def __repr__(self):
229+
# return f"{self.__class__.__name__}_{str(id(self))[-4:]}[{self._state}]"
230+
# _NexusOperationHandle:
231+
# def __repr__(self) -> str:
232+
# return (
233+
# f"{self._start_fut} "
234+
# f"{self._result_fut} "
235+
# f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})"
236+
# )
216237

217238

218-
async def test_async_response(client: Client):
239+
# TODO(dan): cross-namespace tests
240+
# TODO(dan): nexus endpoint pytest fixture?
241+
# TODO: get rid of UnsandboxedWorkflowRunner (due to xray)
242+
@pytest.mark.parametrize("request_cancel", [False, True])
243+
async def test_sync_response(client: Client, request_cancel: bool):
219244
task_queue = str(uuid.uuid4())
220245
async with Worker(
221246
client,
@@ -224,37 +249,22 @@ async def test_async_response(client: Client):
224249
task_queue=task_queue,
225250
workflow_runner=UnsandboxedWorkflowRunner(),
226251
):
227-
operation_workflow_id = str(uuid.uuid4())
228-
operation_workflow_handle = client.get_workflow_handle(operation_workflow_id)
229252
await create_nexus_endpoint(task_queue, client)
230-
231-
# Start the caller workflow
232253
wf_handle = await client.start_workflow(
233254
MyCallerWorkflow.run,
234-
args=[AsyncResponse(operation_workflow_id), False, task_queue],
255+
args=[SyncResponse(), request_cancel, task_queue],
235256
id=str(uuid.uuid4()),
236257
task_queue=task_queue,
237258
)
238259

239-
# Wait for the Nexus operation to start and check that the operation-backing workflow now exists.
240-
await wf_handle.execute_update(MyCallerWorkflow.wait_nexus_operation_started)
241-
wf_details = await operation_workflow_handle.describe()
242-
assert wf_details.status in [
243-
WorkflowExecutionStatus.RUNNING,
244-
WorkflowExecutionStatus.COMPLETED,
245-
]
246-
247-
# Wait for the Nexus operation to complete and check that the operation-backing
248-
# workflow has completed.
249-
await wf_handle.signal(MyCallerWorkflow.proceed)
250-
251-
wf_details = await operation_workflow_handle.describe()
252-
assert wf_details.status == WorkflowExecutionStatus.COMPLETED
260+
# The operation result is returned even when request_cancel=True, because the
261+
# response was synchronous and it could not be cancelled. See explanation above.
253262
result = await wf_handle.result()
254-
assert result == "workflow result"
263+
assert result == "sync response"
255264

256265

257-
async def test_cancellation_of_async_response(client: Client):
266+
@pytest.mark.parametrize("request_cancel", [False, True])
267+
async def test_async_response(client: Client, request_cancel: bool):
258268
task_queue = str(uuid.uuid4())
259269
async with Worker(
260270
client,
@@ -268,9 +278,16 @@ async def test_cancellation_of_async_response(client: Client):
268278
await create_nexus_endpoint(task_queue, client)
269279

270280
# Start the caller workflow
281+
block_forever_waiting_for_cancellation = request_cancel
271282
wf_handle = await client.start_workflow(
272283
MyCallerWorkflow.run,
273-
args=[AsyncResponse(operation_workflow_id), True, task_queue],
284+
args=[
285+
AsyncResponse(
286+
operation_workflow_id, block_forever_waiting_for_cancellation
287+
),
288+
request_cancel,
289+
task_queue,
290+
],
274291
id=str(uuid.uuid4()),
275292
task_queue=task_queue,
276293
)
@@ -284,15 +301,23 @@ async def test_cancellation_of_async_response(client: Client):
284301
]
285302

286303
await wf_handle.signal(MyCallerWorkflow.proceed)
287-
# The caller workflow will now cancel the op_handle, and await it.
288304

289-
# TODO(dan): assert what type of exception is raised here
290-
with pytest.raises(BaseException) as ei:
291-
await wf_handle.result()
292-
e = ei.value
293-
print(f"🌈 workflow failed: {e.__class__.__name__}({e})")
294-
wf_details = await operation_workflow_handle.describe()
295-
assert wf_details.status == WorkflowExecutionStatus.CANCELED
305+
# The operation response was asynchronous and so request_cancel is honored. See
306+
# explanation above.
307+
if request_cancel:
308+
# The caller workflow now cancels the op_handle, and awaits it, resulting in a
309+
# CancelledError in the caller workflow.
310+
with pytest.raises(BaseException) as ei:
311+
await wf_handle.result()
312+
e = ei.value
313+
print(f"🌈 workflow failed: {e.__class__.__name__}({e})")
314+
wf_details = await operation_workflow_handle.describe()
315+
assert wf_details.status == WorkflowExecutionStatus.CANCELED
316+
else:
317+
wf_details = await operation_workflow_handle.describe()
318+
assert wf_details.status == WorkflowExecutionStatus.COMPLETED
319+
result = await wf_handle.result()
320+
assert result == "workflow result"
296321

297322

298323
def make_nexus_endpoint_name(task_queue: str) -> str:

0 commit comments

Comments
 (0)