15
15
import warnings
16
16
from abc import ABC , abstractmethod
17
17
from contextlib import contextmanager
18
- from dataclasses import dataclass
18
+ from dataclasses import dataclass , field
19
19
from datetime import datetime , timedelta , timezone
20
20
from typing import (
21
21
Any ,
@@ -216,7 +216,13 @@ def _cancel(
216
216
warnings .warn (f"Cannot find activity to cancel for token { task_token !r} " )
217
217
return
218
218
logger .debug ("Cancelling activity %s, reason: %s" , task_token , cancel .reason )
219
- activity .cancel (cancelled_by_request = True )
219
+ activity .cancellation_details .details = (
220
+ temporalio .activity .ActivityCancellationDetails ._from_proto (cancel .details )
221
+ )
222
+ activity .cancel (
223
+ cancelled_by_request = cancel .details .is_cancelled
224
+ or cancel .details .is_worker_shutdown
225
+ )
220
226
221
227
def _heartbeat (self , task_token : bytes , * details : Any ) -> None :
222
228
# We intentionally make heartbeating non-async, but since the data
@@ -303,6 +309,24 @@ async def _run_activity(
303
309
await self ._data_converter .encode_failure (
304
310
err , completion .result .failed .failure
305
311
)
312
+ elif (
313
+ isinstance (
314
+ err ,
315
+ (asyncio .CancelledError , temporalio .exceptions .CancelledError ),
316
+ )
317
+ and running_activity .cancellation_details .details
318
+ and running_activity .cancellation_details .details .paused
319
+ ):
320
+ temporalio .activity .logger .warning (
321
+ f"Completing as failure due to unhandled cancel error produced by activity pause" ,
322
+ )
323
+ await self ._data_converter .encode_failure (
324
+ temporalio .exceptions .ApplicationError (
325
+ type = "ActivityPause" ,
326
+ message = "Unhandled activity cancel error produced by activity pause" ,
327
+ ),
328
+ completion .result .failed .failure ,
329
+ )
306
330
elif (
307
331
isinstance (
308
332
err ,
@@ -336,7 +360,6 @@ async def _run_activity(
336
360
await self ._data_converter .encode_failure (
337
361
err , completion .result .failed .failure
338
362
)
339
-
340
363
# For broken executors, we have to fail the entire worker
341
364
if isinstance (err , concurrent .futures .BrokenExecutor ):
342
365
self ._fail_worker_exception_queue .put_nowait (err )
@@ -524,6 +547,7 @@ async def _execute_activity(
524
547
else running_activity .cancel_thread_raiser .shielded ,
525
548
payload_converter_class_or_instance = self ._data_converter .payload_converter ,
526
549
runtime_metric_meter = None if sync_non_threaded else self ._metric_meter ,
550
+ cancellation_details = running_activity .cancellation_details ,
527
551
)
528
552
)
529
553
temporalio .activity .logger .debug ("Starting activity" )
@@ -570,6 +594,9 @@ class _RunningActivity:
570
594
done : bool = False
571
595
cancelled_by_request : bool = False
572
596
cancelled_due_to_heartbeat_error : Optional [Exception ] = None
597
+ cancellation_details : temporalio .activity ._ActivityCancellationDetailsHolder = (
598
+ field (default_factory = temporalio .activity ._ActivityCancellationDetailsHolder )
599
+ )
573
600
574
601
def cancel (
575
602
self ,
@@ -659,6 +686,7 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
659
686
# can set the initializer on the executor).
660
687
ctx = temporalio .activity ._Context .current ()
661
688
info = ctx .info ()
689
+ cancellation_details = ctx .cancellation_details
662
690
663
691
# Heartbeat calls internally use a data converter which is async so
664
692
# they need to be called on the event loop
@@ -717,6 +745,7 @@ async def heartbeat_with_context(*details: Any) -> None:
717
745
worker_shutdown_event .thread_event ,
718
746
payload_converter_class_or_instance ,
719
747
ctx .runtime_metric_meter ,
748
+ cancellation_details ,
720
749
input .fn ,
721
750
* input .args ,
722
751
]
@@ -732,7 +761,6 @@ async def heartbeat_with_context(*details: Any) -> None:
732
761
finally :
733
762
if shared_manager :
734
763
await shared_manager .unregister_heartbeater (info .task_token )
735
-
736
764
# Otherwise for async activity, just run
737
765
return await input .fn (* input .args )
738
766
@@ -764,6 +792,7 @@ def _execute_sync_activity(
764
792
temporalio .converter .PayloadConverter ,
765
793
],
766
794
runtime_metric_meter : Optional [temporalio .common .MetricMeter ],
795
+ cancellation_details : temporalio .activity ._ActivityCancellationDetailsHolder ,
767
796
fn : Callable [..., Any ],
768
797
* args : Any ,
769
798
) -> Any :
@@ -795,6 +824,7 @@ def _execute_sync_activity(
795
824
else cancel_thread_raiser .shielded ,
796
825
payload_converter_class_or_instance = payload_converter_class_or_instance ,
797
826
runtime_metric_meter = runtime_metric_meter ,
827
+ cancellation_details = cancellation_details ,
798
828
)
799
829
)
800
830
return fn (* args )
0 commit comments