Ctx result consumption #319

Draft
wants to merge 6 commits into base: master
Changes from all commits
1 change: 0 additions & 1 deletion tractor/_root.py
@@ -253,7 +253,6 @@ async def open_root_actor(
logger.cancel("Shutting down root actor")
await actor.cancel()
finally:
- _state._current_actor = None
logger.runtime("Root actor terminated")


4 changes: 2 additions & 2 deletions tractor/_runtime.py
@@ -199,8 +199,8 @@ async def _invoke(
except BaseExceptionGroup:
# if a context error was set then likely
# the multierror was raised due to that
- if ctx._error is not None:
-     raise ctx._error from None
+ if ctx._remote_ctx_error is not None:
+     raise ctx._remote_ctx_error from None

raise

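For readers skimming the hunk above: it is the usual "prefer a recorded error over the wrapping exception group" pattern. Below is a minimal, self-contained sketch (not code from this PR; only the `_remote_ctx_error` attribute name is taken from the rename above, everything else is illustrative):

# hypothetical sketch of the re-raise pattern in `_invoke()` above;
# needs Python 3.11+ for the builtin BaseExceptionGroup
import trio

class FakeContext:
    # stand-in for tractor's Context; only the renamed attr matters here
    _remote_ctx_error: BaseException | None = None

async def invoke_like(ctx: FakeContext) -> None:
    try:
        async with trio.open_nursery(strict_exception_groups=True) as nursery:
            async def boom() -> None:
                raise RuntimeError('remote task failed')
            # pretend the runtime already recorded the remote error on the ctx
            ctx._remote_ctx_error = RuntimeError('remote task failed')
            nursery.start_soon(boom)
    except BaseExceptionGroup:
        # prefer the remote context error over the wrapping group,
        # mirroring the `_invoke()` change above
        if ctx._remote_ctx_error is not None:
            raise ctx._remote_ctx_error from None
        raise

if __name__ == '__main__':
    try:
        trio.run(invoke_like, FakeContext())
    except RuntimeError as err:
        print(f'caught unwrapped remote error: {err}')
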
103 changes: 72 additions & 31 deletions tractor/_streaming.py
@@ -27,7 +27,8 @@
Optional,
Callable,
AsyncGenerator,
- AsyncIterator
+ AsyncIterator,
+ TYPE_CHECKING,
)

import warnings
@@ -41,6 +42,10 @@
from .trionics import broadcast_receiver, BroadcastReceiver


if TYPE_CHECKING:
from ._portal import Portal


log = get_logger(__name__)


@@ -70,7 +75,7 @@ class MsgStream(trio.abc.Channel):
'''
def __init__(
self,
- ctx: 'Context', # typing: ignore # noqa
+ ctx: Context, # typing: ignore # noqa
rx_chan: trio.MemoryReceiveChannel,
_broadcaster: Optional[BroadcastReceiver] = None,

@@ -83,6 +88,9 @@ def __init__(
self._eoc: bool = False
self._closed: bool = False

def ctx(self) -> Context:
return self._ctx

# delegate directly to underlying mem channel
def receive_nowait(self):
msg = self._rx_chan.receive_nowait()
@@ -278,7 +286,6 @@ async def aclose(self):
@asynccontextmanager
async def subscribe(
self,

) -> AsyncIterator[BroadcastReceiver]:
'''
Allocate and return a ``BroadcastReceiver`` which delegates
@@ -335,8 +342,8 @@ async def send(
Send a message over this stream to the far end.

'''
- if self._ctx._error:
-     raise self._ctx._error # from None
+ if self._ctx._remote_ctx_error:
+     raise self._ctx._remote_ctx_error # from None

if self._closed:
raise trio.ClosedResourceError('This stream was already closed')
@@ -375,9 +382,10 @@ class Context:
_remote_func_type: Optional[str] = None

# only set on the caller side
- _portal: Optional['Portal'] = None # type: ignore # noqa
+ _portal: Optional[Portal] = None # type: ignore # noqa
+ _stream: Optional[MsgStream] = None
_result: Optional[Any] = False
- _error: Optional[BaseException] = None
+ _remote_ctx_error: Optional[BaseException] = None

# status flags
_cancel_called: bool = False
@@ -390,7 +398,7 @@ class Context:
# only set on the callee side
_scope_nursery: Optional[trio.Nursery] = None

- _backpressure: bool = False
+ _backpressure: bool = True

async def send_yield(self, data: Any) -> None:

@@ -435,29 +443,34 @@ async def _maybe_raise_from_remote_msg(
# (currently) that other portal APIs (``Portal.run()``,
# ``.run_in_actor()``) do their own error checking at the point
# of the call and result processing.
- log.error(
- f'Remote context error for {self.chan.uid}:{self.cid}:\n'
- f'{msg["error"]["tb_str"]}'
- )
error = unpack_error(msg, self.chan)
if (
- isinstance(error, ContextCancelled) and
- self._cancel_called
+ isinstance(error, ContextCancelled)
):
- # this is an expected cancel request response message
- # and we don't need to raise it in scope since it will
- # potentially override a real error
- return
+ log.cancel(
+ f'Remote context error for {self.chan.uid}:{self.cid}:\n'
+ f'{msg["error"]["tb_str"]}'
+ )
+ if self._cancel_called:
+ # this is an expected cancel request response message
+ # and we don't need to raise it in scope since it will
+ # potentially override a real error
+ return
+ else:
+ log.error(
+ f'Remote context error for {self.chan.uid}:{self.cid}:\n'
+ f'{msg["error"]["tb_str"]}'
+ )

- self._error = error
+ self._remote_ctx_error = error

# TODO: tempted to **not** do this by-reraising in a
# nursery and instead cancel a surrounding scope, detect
# the cancellation, then lookup the error that was set?
if self._scope_nursery:

async def raiser():
- raise self._error from None
+ raise self._remote_ctx_error from None

# from trio.testing import wait_all_tasks_blocked
# await wait_all_tasks_blocked()
@@ -483,6 +496,7 @@ async def cancel(
log.cancel(f'Cancelling {side} side of context to {self.chan.uid}')

self._cancel_called = True
ipc_broken: bool = False

if side == 'caller':
if not self._portal:
@@ -500,7 +514,14 @@
# NOTE: we're telling the far end actor to cancel a task
# corresponding to *this actor*. The far end local channel
# instance is passed to `Actor._cancel_task()` implicitly.
- await self._portal.run_from_ns('self', '_cancel_task', cid=cid)
+ try:
+ await self._portal.run_from_ns(
+ 'self',
+ '_cancel_task',
+ cid=cid,
+ )
+ except trio.BrokenResourceError:
+ ipc_broken = True

if cs.cancelled_caught:
# XXX: there's no way to know if the remote task was indeed
@@ -516,7 +537,10 @@
"Timed out on cancelling remote task "
f"{cid} for {self._portal.channel.uid}")

- # callee side remote task
+ elif ipc_broken:
+ log.cancel(
+ "Transport layer was broken before cancel request "
+ f"{cid} for {self._portal.channel.uid}")
else:
self._cancel_msg = msg

@@ -604,6 +628,7 @@ async def open_stream(
ctx=self,
rx_chan=ctx._recv_chan,
) as stream:
self._stream = stream

if self._portal:
self._portal._streams.add(stream)
@@ -645,25 +670,22 @@ async def result(self) -> Any:

if not self._recv_chan._closed: # type: ignore

- # wait for a final context result consuming
- # and discarding any bi dir stream msgs still
- # in transit from the far end.
- while True:
+ def consume(
+ msg: dict,

- msg = await self._recv_chan.receive()
+ ) -> Optional[dict]:
try:
- self._result = msg['return']
- break
+ return msg['return']
except KeyError as msgerr:

if 'yield' in msg:
# far end task is still streaming to us so discard
log.warning(f'Discarding stream delivered {msg}')
- continue
+ return

elif 'stop' in msg:
log.debug('Remote stream terminated')
- continue
+ return

# internal error should never get here
assert msg.get('cid'), (
@@ -673,6 +695,25 @@ async def result(self) -> Any:
msg, self._portal.channel
) from msgerr

# wait for a final context result consuming
# and discarding any bi dir stream msgs still
# in transit from the far end.
if self._stream:
async with self._stream.subscribe() as bstream:
async for msg in bstream:
result = consume(msg)
if result:
self._result = result
break

if not self._result:
while True:
msg = await self._recv_chan.receive()
result = consume(msg)
if result:
self._result = result
break

return self._result

async def started(
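As a closing usage sketch of what "ctx result consumption" is meant to enable (this is not code from the diff; it assumes tractor's public `@tractor.context` / `Portal.open_context()` API and treats the draining behaviour as this draft's intent, not a documented guarantee): the caller below never opens or reads the stream, and `ctx.result()` is expected to consume and discard any `yield` messages still in flight before handing back the callee's final return value.

import tractor
import trio


@tractor.context
async def streamer(ctx: tractor.Context) -> str:
    # callee: stream a few values, then return a final result
    await ctx.started('ready')
    async with ctx.open_stream() as stream:
        for i in range(3):
            await stream.send(i)
    return 'final'


async def main() -> None:
    async with tractor.open_nursery() as an:
        portal = await an.start_actor(
            'streamer_actor',  # hypothetical child actor name
            enable_modules=[__name__],
        )
        async with portal.open_context(streamer) as (ctx, first):
            assert first == 'ready'
            # no stream consumption on purpose: per this draft, .result()
            # should discard the buffered 'yield' msgs and return 'final'
            assert await ctx.result() == 'final'

        await portal.cancel_actor()


if __name__ == '__main__':
    trio.run(main)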