Skip to content

Commit 2aa17da

Browse files
authored
enhancement: Add Lifespan Events Handling In ASGI app (#113)
1 parent 2a6f139 commit 2aa17da

File tree

10 files changed

+614
-30
lines changed

10 files changed

+614
-30
lines changed

.github/workflows/unit-test.yaml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Unit Tests
2+
3+
on:
4+
pull_request:
5+
workflow_dispatch:
6+
push:
7+
branches:
8+
- main
9+
10+
permissions:
11+
contents: read
12+
checks: write
13+
pull-requests: write
14+
15+
jobs:
16+
restate-sdk-unit-tests:
17+
name: "Run Unit Tests (Python ${{ matrix.python }})"
18+
runs-on: ubuntu-latest
19+
strategy:
20+
matrix:
21+
python: ["3.11", "3.12"]
22+
steps:
23+
- uses: actions/checkout@v4
24+
- uses: extractions/setup-just@v2
25+
- uses: actions/setup-python@v5
26+
with:
27+
python-version: ${{ matrix.python }}
28+
- name: Build Rust module
29+
uses: PyO3/maturin-action@v1
30+
with:
31+
args: --out dist --interpreter python${{ matrix.python }}
32+
sccache: "true"
33+
container: off
34+
- run: pip install -r requirements.txt
35+
- run: pip install dist/*
36+
- name: Run Unit Tests
37+
run: python -m pytest tests/ -v

examples/example.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,27 @@
1212
# pylint: disable=C0116
1313
# pylint: disable=W0613
1414

15+
import asyncio
1516
import restate
1617

18+
from concurrent_greeter import concurrent_greeter
1719
from greeter import greeter
20+
from pydantic_greeter import pydantic_greeter
1821
from virtual_object import counter
1922
from workflow import payment
20-
from pydantic_greeter import pydantic_greeter
21-
from concurrent_greeter import concurrent_greeter
23+
24+
2225

2326
app = restate.app(services=[greeter,
2427
counter,
2528
payment,
2629
pydantic_greeter,
27-
concurrent_greeter])
30+
concurrent_greeter,
31+
],)
2832

2933
if __name__ == "__main__":
3034
import hypercorn
3135
import hypercorn.asyncio
32-
import asyncio
3336
conf = hypercorn.Config()
3437
conf.bind = ["0.0.0.0:9080"]
3538
asyncio.run(hypercorn.asyncio.serve(app, conf))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Source = "https://github.com/restatedev/sdk-python"
2121
"Bug Tracker" = "https://github.com/restatedev/sdk-python/issues"
2222

2323
[project.optional-dependencies]
24-
test = ["pytest", "hypercorn"]
24+
test = ["pytest", "hypercorn", "pytest-asyncio==1.1.0"]
2525
lint = ["mypy", "pylint"]
2626
harness = ["testcontainers", "hypercorn", "httpx"]
2727
serde = ["dacite", "pydantic"]

python/restate/aws_lambda.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Receive,
2222
HTTPResponseStartEvent,
2323
HTTPResponseBodyEvent,
24-
HTTPRequestEvent)
24+
HTTPRequestEvent, Send)
2525

2626
class RestateLambdaRequest(TypedDict):
2727
"""
@@ -162,8 +162,9 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR
162162
scope = create_scope(event)
163163
recv = request_to_receive(event)
164164
send = ResponseCollector()
165+
send_typed = cast(Send, send)
165166

166-
asgi_instance = asgi_app(scope, recv, send)
167+
asgi_instance = asgi_app(scope, recv, send_typed)
167168
asgi_task = loop.create_task(asgi_instance) # type: ignore[var-annotated, arg-type]
168169
loop.run_until_complete(asgi_task)
169170

python/restate/context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def request(self) -> Request:
220220
"""
221221

222222

223-
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
224223
@overload
224+
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
225225
@abc.abstractmethod
226226
def run(self,
227227
name: str,
@@ -250,8 +250,8 @@ def run(self,
250250
251251
"""
252252

253-
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
254253
@overload
254+
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
255255
@abc.abstractmethod
256256
def run(self,
257257
name: str,
@@ -280,6 +280,7 @@ def run(self,
280280
281281
"""
282282

283+
@overload
283284
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
284285
@abc.abstractmethod
285286
def run(self,

python/restate/endpoint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import typing
1717

18+
from restate.server_types import LifeSpan
1819
from restate.service import Service
1920
from restate.object import VirtualObject
2021
from restate.workflow import Workflow
@@ -80,7 +81,7 @@ def identity_key(self, identity_key: str):
8081
"""Add an identity key to this endpoint."""
8182
self.identity_keys.append(identity_key)
8283

83-
def app(self):
84+
def app(self, lifespan: typing.Optional[LifeSpan] = None):
8485
"""
8586
Returns the ASGI application for this endpoint.
8687
@@ -94,12 +95,13 @@ def app(self):
9495
# pylint: disable=C0415
9596
# pylint: disable=R0401
9697
from restate.server import asgi_app
97-
return asgi_app(self)
98+
return asgi_app(self, lifespan)
9899

99100
def app(
100101
services: typing.Iterable[typing.Union[Service, VirtualObject, Workflow]],
101102
protocol: typing.Optional[typing.Literal["bidi", "request_response"]] = None,
102-
identity_keys: typing.Optional[typing.List[str]] = None):
103+
identity_keys: typing.Optional[typing.List[str]] = None,
104+
lifespan: typing.Optional[LifeSpan] = None):
103105
"""A restate ASGI application that hosts the given services."""
104106
endpoint = Endpoint()
105107
if protocol == "bidi":
@@ -111,4 +113,4 @@ def app(
111113
if identity_keys:
112114
for key in identity_keys:
113115
endpoint.identity_key(key)
114-
return endpoint.app()
116+
return endpoint.app(lifespan)

python/restate/server.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
"""This module contains the ASGI server for the restate framework."""
1212

1313
import asyncio
14-
from typing import Dict, TypedDict, Literal
14+
import inspect
15+
from typing import Any, Dict, TypedDict, Literal
1516
import traceback
17+
import typing
1618
from restate.discovery import compute_discovery_json
1719
from restate.endpoint import Endpoint
1820
from restate.server_context import ServerInvocationContext, DisconnectedException
19-
from restate.server_types import Receive, ReceiveChannel, Scope, Send, binary_to_header, header_to_binary # pylint: disable=line-too-long
21+
from restate.server_types import Receive, ReceiveChannel, Scope, Send, binary_to_header, header_to_binary, LifeSpan # pylint: disable=line-too-long
2022
from restate.vm import VMWrapper
2123
from restate._internal import PyIdentityVerifier, IdentityVerificationException # pylint: disable=import-error,no-name-in-module
2224
from restate._internal import SDK_VERSION # pylint: disable=import-error,no-name-in-module
@@ -186,10 +188,6 @@ async def process_invocation_to_completion(vm: VMWrapper,
186188
finally:
187189
context.on_attempt_finished()
188190

189-
class LifeSpanNotImplemented(ValueError):
190-
"""Signal to the asgi server that we didn't implement lifespans"""
191-
192-
193191
class ParsedPath(TypedDict):
194192
"""Parsed path from the request."""
195193
type: Literal["invocation", "health", "discover", "unknown"]
@@ -216,8 +214,55 @@ def parse_path(request: str) -> ParsedPath:
216214
# anything other than invoke is 404
217215
return { "type": "unknown" , "service": None, "handler": None }
218216

217+
def is_async_context_manager(obj: Any):
218+
"""check if passed object is an async context manager"""
219+
return (hasattr(obj, '__aenter__') and
220+
hasattr(obj, '__aexit__') and
221+
inspect.iscoroutinefunction(obj.__aenter__) and
222+
inspect.iscoroutinefunction(obj.__aexit__))
219223

220-
def asgi_app(endpoint: Endpoint):
224+
225+
async def lifespan_processor(
226+
scope: Scope,
227+
receive: Receive,
228+
send: Send,
229+
lifespan: LifeSpan
230+
) -> None:
231+
"""Process lifespan context manager."""
232+
started = False
233+
assert scope["type"] in ["lifespan", "lifespan.startup", "lifespan.shutdown"]
234+
assert is_async_context_manager(lifespan()), "lifespan must be an async context manager"
235+
await receive()
236+
try:
237+
async with lifespan() as maybe_state:
238+
if maybe_state is not None:
239+
if "state" not in scope:
240+
raise RuntimeError("The server does not support state in lifespan")
241+
scope["state"] = maybe_state
242+
await send({
243+
"type": "lifespan.startup.complete", # type: ignore
244+
})
245+
started = True
246+
await receive()
247+
except Exception:
248+
exc_text = traceback.format_exc()
249+
if started:
250+
await send({
251+
"type": "lifespan.shutdown.failed",
252+
"message": exc_text
253+
})
254+
else:
255+
await send({
256+
"type": "lifespan.startup.failed",
257+
"message": exc_text
258+
})
259+
raise
260+
await send({
261+
"type": "lifespan.shutdown.complete" # type: ignore
262+
})
263+
264+
# pylint: disable=too-many-return-statements
265+
def asgi_app(endpoint: Endpoint, lifespan: typing.Optional[LifeSpan] = None):
221266
"""Create an ASGI-3 app for the given endpoint."""
222267

223268
# Prepare request signer
@@ -226,14 +271,17 @@ def asgi_app(endpoint: Endpoint):
226271
async def app(scope: Scope, receive: Receive, send: Send):
227272
try:
228273
if scope['type'] == 'lifespan':
229-
raise LifeSpanNotImplemented()
274+
if lifespan is not None:
275+
await lifespan_processor(scope, receive, send, lifespan)
276+
return
277+
return
278+
230279
if scope['type'] != 'http':
231280
raise NotImplementedError(f"Unknown scope type {scope['type']}")
232281

233282
request_path = scope['path']
234283
assert isinstance(request_path, str)
235284
request: ParsedPath = parse_path(request_path)
236-
237285
# Health check
238286
if request['type'] == 'health':
239287
await send_health_check(send)
@@ -249,7 +297,6 @@ async def app(scope: Scope, receive: Receive, send: Send):
249297
# Identify verification failed, send back unauthorized and close
250298
await send_status(send, receive, 401)
251299
return
252-
253300
# might be a discovery request
254301
if request['type'] == 'discover':
255302
await send_discovery(scope, send, endpoint)
@@ -283,8 +330,6 @@ async def app(scope: Scope, receive: Receive, send: Send):
283330
send)
284331
finally:
285332
await receive_channel.close()
286-
except LifeSpanNotImplemented as e:
287-
raise e
288333
except Exception as e:
289334
traceback.print_exc()
290335
raise e

python/restate/server_types.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
import asyncio
1818
from typing import (Awaitable, Callable, Dict, Iterable, List,
19-
Tuple, Union, TypedDict, Literal, Optional, NotRequired, Any)
19+
Tuple, Union, TypedDict, Literal, Optional,
20+
NotRequired, Any, AsyncContextManager)
2021

2122
class ASGIVersions(TypedDict):
2223
"""ASGI Versions"""
@@ -25,7 +26,7 @@ class ASGIVersions(TypedDict):
2526

2627
class Scope(TypedDict):
2728
"""ASGI Scope"""
28-
type: Literal["http"]
29+
type: Literal["http", "lifespan", "lifespan.startup", "lifespan.shutdown"]
2930
asgi: ASGIVersions
3031
http_version: str
3132
method: str
@@ -64,18 +65,29 @@ class HTTPResponseBodyEvent(TypedDict):
6465
body: bytes
6566
more_body: bool
6667

68+
class LifeSpanEvent(TypedDict):
69+
"""ASGI LifeSpan event"""
70+
type: Literal["lifespan.startup.complete",
71+
"lifespan.shutdown.complete",
72+
"lifespan.startup.failed",
73+
"lifespan.shutdown.failed"]
74+
message: Optional[str]
6775

68-
ASGIReceiveEvent = HTTPRequestEvent
76+
77+
ASGIReceiveEvent = Union[HTTPRequestEvent, LifeSpanEvent]
6978

7079

7180
ASGISendEvent = Union[
7281
HTTPResponseStartEvent,
73-
HTTPResponseBodyEvent
82+
HTTPResponseBodyEvent,
83+
LifeSpanEvent
7484
]
7585

7686
Receive = Callable[[], Awaitable[ASGIReceiveEvent]]
7787
Send = Callable[[ASGISendEvent], Awaitable[None]]
7888

89+
LifeSpan = Callable[[], AsyncContextManager[Any]]
90+
7991
ASGIApp = Callable[
8092
[
8193
Scope,
@@ -84,7 +96,6 @@ class HTTPResponseBodyEvent(TypedDict):
8496
],
8597
Awaitable[None],
8698
]
87-
8899
def header_to_binary(headers: Iterable[Tuple[str, str]]) -> List[Tuple[bytes, bytes]]:
89100
"""Convert a list of headers to a list of binary headers."""
90101
return [ (k.encode('utf-8'), v.encode('utf-8')) for k,v in headers ]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pydantic
77
httpx
88
testcontainers
99
typing-extensions>=4.14.0
10+
pytest-asyncio==1.1.0

0 commit comments

Comments
 (0)