Skip to content

Commit 0d1521d

Browse files
authored
Make sure to wait for a pending side effect to flush (#89)
1 parent aaedf2f commit 0d1521d

File tree

3 files changed

+68
-53
lines changed

3 files changed

+68
-53
lines changed

python/restate/server.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from restate.discovery import compute_discovery_json
1717
from restate.endpoint import Endpoint
1818
from restate.server_context import ServerInvocationContext, DisconnectedException
19-
from restate.server_types import Receive, Scope, Send, binary_to_header, header_to_binary
19+
from restate.server_types import Receive, ReceiveChannel, Scope, Send, binary_to_header, header_to_binary # pylint: disable=line-too-long
2020
from restate.vm import VMWrapper
2121
from restate._internal import PyIdentityVerifier, IdentityVerificationException # pylint: disable=import-error,no-name-in-module
2222
from restate._internal import SDK_VERSION # pylint: disable=import-error,no-name-in-module
@@ -85,7 +85,7 @@ async def send_health_check(send: Send):
8585
async def process_invocation_to_completion(vm: VMWrapper,
8686
handler,
8787
attempt_headers: Dict[str, str],
88-
receive: Receive,
88+
receive: ReceiveChannel,
8989
send: Send):
9090
"""Invoke the user code."""
9191
status, res_headers = vm.get_response_head()
@@ -171,6 +171,7 @@ def parse_path(request: str) -> ParsedPath:
171171
# anything other than invoke is 404
172172
return { "type": "unknown" , "service": None, "handler": None }
173173

174+
174175
def asgi_app(endpoint: Endpoint):
175176
"""Create an ASGI-3 app for the given endpoint."""
176177

@@ -201,7 +202,7 @@ async def app(scope: Scope, receive: Receive, send: Send):
201202
identity_verifier.verify(request_headers, request_path)
202203
except IdentityVerificationException:
203204
# Identify verification failed, send back unauthorized and close
204-
await send_status(send, receive,401)
205+
await send_status(send, receive, 401)
205206
return
206207

207208
# might be a discovery request
@@ -228,11 +229,15 @@ async def app(scope: Scope, receive: Receive, send: Send):
228229
# At this point we have a valid handler.
229230
# Let us setup restate's execution context for this invocation and handler.
230231
#
231-
await process_invocation_to_completion(VMWrapper(request_headers),
232-
handler,
233-
dict(request_headers),
234-
receive,
235-
send)
232+
receive_channel = ReceiveChannel(receive)
233+
try:
234+
await process_invocation_to_completion(VMWrapper(request_headers),
235+
handler,
236+
dict(request_headers),
237+
receive_channel,
238+
send)
239+
finally:
240+
await receive_channel.close()
236241
except LifeSpanNotImplemented as e:
237242
raise e
238243
except Exception as e:

python/restate/server_context.py

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from restate.exceptions import TerminalError
3030
from restate.handler import Handler, handler_from_callable, invoke_handler
3131
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
32-
from restate.server_types import Receive, Send
32+
from restate.server_types import ReceiveChannel, Send
3333
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
3434
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
3535

@@ -220,25 +220,6 @@ def peek(self) -> Awaitable[Any | None]:
220220
# disable too many public method
221221
# pylint: disable=R0904
222222

223-
class SyncPoint:
224-
"""
225-
This class implements a synchronization point.
226-
"""
227-
228-
def __init__(self) -> None:
229-
self.cond: asyncio.Event | None = None
230-
231-
def awaiter(self):
232-
"""Wait for the sync point."""
233-
if self.cond is None:
234-
self.cond = asyncio.Event()
235-
return self.cond.wait()
236-
237-
async def arrive(self):
238-
"""arrive at the sync point."""
239-
if self.cond is not None:
240-
self.cond.set()
241-
242223
class Tasks:
243224
"""
244225
This class implements a list of tasks.
@@ -284,7 +265,8 @@ def __init__(self,
284265
invocation: Invocation,
285266
attempt_headers: Dict[str, str],
286267
send: Send,
287-
receive: Receive) -> None:
268+
receive: ReceiveChannel
269+
) -> None:
288270
super().__init__()
289271
self.vm = vm
290272
self.handler = handler
@@ -293,7 +275,6 @@ def __init__(self,
293275
self.send = send
294276
self.receive = receive
295277
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
296-
self.sync_point = SyncPoint()
297278
self.request_finished_event = asyncio.Event()
298279
self.tasks = Tasks()
299280

@@ -365,18 +346,6 @@ def on_attempt_finished(self):
365346
# ignore the cancelled error
366347
pass
367348

368-
369-
async def receive_and_notify_input(self):
370-
"""Receive input from the state machine."""
371-
chunk = await self.receive()
372-
if chunk.get('type') == 'http.disconnect':
373-
raise DisconnectedException()
374-
if chunk.get('body', None) is not None:
375-
assert isinstance(chunk['body'], bytes)
376-
self.vm.notify_input(chunk['body'])
377-
if not chunk.get('more_body', False):
378-
self.vm.notify_input_closed()
379-
380349
async def take_and_send_output(self):
381350
"""Take output from state machine and send it"""
382351
output = self.vm.take_output()
@@ -417,21 +386,22 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
417386
async def wrapper(f):
418387
await f()
419388
await self.take_and_send_output()
420-
await self.sync_point.arrive()
389+
await self.receive.tx({ 'type' : 'restate.run_completed', 'body' : bytes(), 'more_body' : True})
421390

422391
task = asyncio.create_task(wrapper(fn))
423392
self.tasks.add(task)
424393
continue
425394
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
426-
sync_task = asyncio.create_task(self.sync_point.awaiter())
427-
self.tasks.add(sync_task)
428-
429-
read_task = asyncio.create_task(self.receive_and_notify_input())
430-
self.tasks.add(read_task)
431-
432-
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
433-
if read_task in done:
434-
_ = read_task.result() # propagate exception
395+
chunk = await self.receive()
396+
if chunk.get('type') == 'restate.run_completed':
397+
continue
398+
if chunk.get('type') == 'http.disconnect':
399+
raise DisconnectedException()
400+
if chunk.get('body', None) is not None:
401+
assert isinstance(chunk['body'], bytes)
402+
self.vm.notify_input(chunk['body'])
403+
if not chunk.get('more_body', False):
404+
self.vm.notify_input_closed()
435405

436406
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
437407
"""Create a coroutine that fetches a result from a notification handle."""

python/restate/server_types.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
:see: https://github.com/django/asgiref/blob/main/asgiref/typing.py
1515
"""
1616

17+
import asyncio
1718
from typing import (Awaitable, Callable, Dict, Iterable, List,
1819
Tuple, Union, TypedDict, Literal, Optional, NotRequired, Any)
1920

@@ -41,7 +42,7 @@ class Scope(TypedDict):
4142

4243
class HTTPRequestEvent(TypedDict):
4344
"""ASGI Request event"""
44-
type: Literal["http.request"]
45+
type: Literal["http.request", "restate.run_completed"]
4546
body: bytes
4647
more_body: bool
4748

@@ -86,3 +87,42 @@ def header_to_binary(headers: Iterable[Tuple[str, str]]) -> List[Tuple[bytes, by
8687
def binary_to_header(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[str, str]]:
8788
"""Convert a list of binary headers to a list of headers."""
8889
return [ (k.decode('utf-8'), v.decode('utf-8')) for k,v in headers ]
90+
91+
class ReceiveChannel:
92+
"""ASGI receive channel."""
93+
94+
def __init__(self, receive: Receive):
95+
self.queue = asyncio.Queue[ASGIReceiveEvent]()
96+
97+
async def loop():
98+
"""Receive loop."""
99+
while True:
100+
event = await receive()
101+
await self.queue.put(event)
102+
if event.get('type') == 'http.disconnect':
103+
break
104+
105+
self.task = asyncio.create_task(loop())
106+
107+
async def rx(self) -> ASGIReceiveEvent:
108+
"""Get the next message."""
109+
what = await self.queue.get()
110+
self.queue.task_done()
111+
return what
112+
113+
async def __call__(self):
114+
"""Get the next message."""
115+
return await self.rx()
116+
117+
async def tx(self, what: ASGIReceiveEvent):
118+
"""Add a message."""
119+
await self.queue.put(what)
120+
121+
async def close(self):
122+
"""Close the channel."""
123+
if self.task and not self.task.done():
124+
self.task.cancel()
125+
try:
126+
await self.task
127+
except asyncio.CancelledError:
128+
pass

0 commit comments

Comments
 (0)