Skip to content

Commit 8475d2c

Browse files
committed
Implement
1 parent ed89330 commit 8475d2c

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

temporalio/activity.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import temporalio.common
3838
import temporalio.converter
39+
from temporalio.client import Client
3940

4041
from .types import CallableType
4142

@@ -147,6 +148,7 @@ class _Context:
147148
temporalio.converter.PayloadConverter,
148149
]
149150
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
151+
client: Client
150152
_logger_details: Optional[Mapping[str, Any]] = None
151153
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
152154
_metric_meter: Optional[temporalio.common.MetricMeter] = None
@@ -238,13 +240,25 @@ def wait_sync(self, timeout: Optional[float] = None) -> None:
238240
self.thread_event.wait(timeout)
239241

240242

243+
def client() -> Client:
244+
"""Return a Temporal Client for use in the current activity.
245+
246+
Returns:
247+
:py:class:`temporalio.client.Client` for use in the current activity.
248+
249+
Raises:
250+
RuntimeError: When not in an activity.
251+
"""
252+
return _Context.current().client
253+
254+
241255
def in_activity() -> bool:
242256
"""Whether the current code is inside an activity.
243257
244258
Returns:
245259
True if in an activity, False otherwise.
246260
"""
247-
return not _current_context.get(None) is None
261+
return _current_context.get(None) is not None
248262

249263

250264
def info() -> Info:

temporalio/worker/_activity.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
data_converter: temporalio.converter.DataConverter,
7070
interceptors: Sequence[Interceptor],
7171
metric_meter: temporalio.common.MetricMeter,
72+
client: temporalio.client.Client,
7273
) -> None:
7374
self._bridge_worker = bridge_worker
7475
self._task_queue = task_queue
@@ -84,6 +85,7 @@ def __init__(
8485
None
8586
)
8687
self._seen_sync_activity = False
88+
self._client = client
8789

8890
# Validate and build activity dict
8991
self._activities: Dict[str, temporalio.activity._Definition] = {}
@@ -428,13 +430,16 @@ async def _run_activity(
428430
heartbeat=None,
429431
cancelled_event=running_activity.cancelled_event,
430432
worker_shutdown_event=self._worker_shutdown_event,
431-
shield_thread_cancel_exception=None
432-
if not running_activity.cancel_thread_raiser
433-
else running_activity.cancel_thread_raiser.shielded,
433+
shield_thread_cancel_exception=(
434+
None
435+
if not running_activity.cancel_thread_raiser
436+
else running_activity.cancel_thread_raiser.shielded
437+
),
434438
payload_converter_class_or_instance=self._data_converter.payload_converter,
435-
runtime_metric_meter=None
436-
if sync_non_threaded
437-
else self._metric_meter,
439+
runtime_metric_meter=(
440+
None if sync_non_threaded else self._metric_meter
441+
),
442+
client=self._client,
438443
)
439444
)
440445
temporalio.activity.logger.debug("Starting activity")
@@ -692,6 +697,7 @@ async def heartbeat_with_context(*details: Any) -> None:
692697
worker_shutdown_event.thread_event,
693698
payload_converter_class_or_instance,
694699
ctx.runtime_metric_meter,
700+
ctx.client,
695701
input.fn,
696702
*input.args,
697703
]
@@ -739,6 +745,7 @@ def _execute_sync_activity(
739745
temporalio.converter.PayloadConverter,
740746
],
741747
runtime_metric_meter: Optional[temporalio.common.MetricMeter],
748+
client: temporalio.client.Client,
742749
fn: Callable[..., Any],
743750
*args: Any,
744751
) -> Any:
@@ -770,6 +777,7 @@ def _execute_sync_activity(
770777
else cancel_thread_raiser.shielded,
771778
payload_converter_class_or_instance=payload_converter_class_or_instance,
772779
runtime_metric_meter=runtime_metric_meter,
780+
client=client,
773781
)
774782
)
775783
return fn(*args)

temporalio/worker/_worker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from ._activity import SharedStateManager, _ActivityWorker
3131
from ._interceptor import Interceptor
32-
from ._tuning import WorkerTuner, _to_bridge_slot_supplier
32+
from ._tuning import WorkerTuner
3333
from ._workflow import _WorkflowWorker
3434
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
3535
from .workflow_sandbox import SandboxedWorkflowRunner
@@ -303,6 +303,7 @@ def __init__(
303303
data_converter=client_config["data_converter"],
304304
interceptors=interceptors,
305305
metric_meter=self._runtime.metric_meter,
306+
client=client,
306307
)
307308
self._workflow_worker: Optional[_WorkflowWorker] = None
308309
if workflows:

0 commit comments

Comments
 (0)