29
29
from restate .exceptions import TerminalError
30
30
from restate .handler import Handler , handler_from_callable , invoke_handler
31
31
from restate .serde import BytesSerde , DefaultSerde , JsonSerde , Serde
32
- from restate .server_types import Receive , Send
32
+ from restate .server_types import ReceiveChannel , Send
33
33
from restate .vm import Failure , Invocation , NotReady , SuspendedException , VMWrapper , RunRetryConfig # pylint: disable=line-too-long
34
34
from restate .vm import DoProgressAnyCompleted , DoProgressCancelSignalReceived , DoProgressReadFromInput , DoProgressExecuteRun , DoWaitPendingRun
35
35
@@ -220,25 +220,6 @@ def peek(self) -> Awaitable[Any | None]:
220
220
# disable too many public method
221
221
# pylint: disable=R0904
222
222
223
- class SyncPoint :
224
- """
225
- This class implements a synchronization point.
226
- """
227
-
228
- def __init__ (self ) -> None :
229
- self .cond : asyncio .Event | None = None
230
-
231
- def awaiter (self ):
232
- """Wait for the sync point."""
233
- if self .cond is None :
234
- self .cond = asyncio .Event ()
235
- return self .cond .wait ()
236
-
237
- async def arrive (self ):
238
- """arrive at the sync point."""
239
- if self .cond is not None :
240
- self .cond .set ()
241
-
242
223
class Tasks :
243
224
"""
244
225
This class implements a list of tasks.
@@ -284,7 +265,8 @@ def __init__(self,
284
265
invocation : Invocation ,
285
266
attempt_headers : Dict [str , str ],
286
267
send : Send ,
287
- receive : Receive ) -> None :
268
+ receive : ReceiveChannel
269
+ ) -> None :
288
270
super ().__init__ ()
289
271
self .vm = vm
290
272
self .handler = handler
@@ -293,7 +275,6 @@ def __init__(self,
293
275
self .send = send
294
276
self .receive = receive
295
277
self .run_coros_to_execute : dict [int , Callable [[], Awaitable [None ]]] = {}
296
- self .sync_point = SyncPoint ()
297
278
self .request_finished_event = asyncio .Event ()
298
279
self .tasks = Tasks ()
299
280
@@ -365,18 +346,6 @@ def on_attempt_finished(self):
365
346
# ignore the cancelled error
366
347
pass
367
348
368
-
369
- async def receive_and_notify_input (self ):
370
- """Receive input from the state machine."""
371
- chunk = await self .receive ()
372
- if chunk .get ('type' ) == 'http.disconnect' :
373
- raise DisconnectedException ()
374
- if chunk .get ('body' , None ) is not None :
375
- assert isinstance (chunk ['body' ], bytes )
376
- self .vm .notify_input (chunk ['body' ])
377
- if not chunk .get ('more_body' , False ):
378
- self .vm .notify_input_closed ()
379
-
380
349
async def take_and_send_output (self ):
381
350
"""Take output from state machine and send it"""
382
351
output = self .vm .take_output ()
@@ -417,21 +386,22 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
417
386
async def wrapper (f ):
418
387
await f ()
419
388
await self .take_and_send_output ()
420
- await self .sync_point . arrive ( )
389
+ await self .receive . tx ({ 'type' : 'restate.run_completed' , 'body' : bytes (), 'more_body' : True } )
421
390
422
391
task = asyncio .create_task (wrapper (fn ))
423
392
self .tasks .add (task )
424
393
continue
425
394
if isinstance (do_progress_response , (DoWaitPendingRun , DoProgressReadFromInput )):
426
- sync_task = asyncio .create_task (self .sync_point .awaiter ())
427
- self .tasks .add (sync_task )
428
-
429
- read_task = asyncio .create_task (self .receive_and_notify_input ())
430
- self .tasks .add (read_task )
431
-
432
- done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
433
- if read_task in done :
434
- _ = read_task .result () # propagate exception
395
+ chunk = await self .receive ()
396
+ if chunk .get ('type' ) == 'restate.run_completed' :
397
+ continue
398
+ if chunk .get ('type' ) == 'http.disconnect' :
399
+ raise DisconnectedException ()
400
+ if chunk .get ('body' , None ) is not None :
401
+ assert isinstance (chunk ['body' ], bytes )
402
+ self .vm .notify_input (chunk ['body' ])
403
+ if not chunk .get ('more_body' , False ):
404
+ self .vm .notify_input_closed ()
435
405
436
406
def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
437
407
"""Create a coroutine that fetches a result from a notification handle."""
0 commit comments