-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathhandler.py
189 lines (159 loc) · 6.11 KB
/
handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-python/blob/main/LICENSE
#
# pylint: disable=R0917
"""
This module contains the definition of the Handler class,
which is used to define the handlers for the services.
"""
from dataclasses import dataclass
from inspect import Signature
from typing import Any, Callable, Awaitable, Generic, Literal, Optional, TypeVar
from restate.exceptions import TerminalError
from restate.serde import JsonSerde, Serde, PydanticJsonSerde
# Type variables for a handler's input (I) and output (O), plus a generic
# one (T) used by TypeHint.
I = TypeVar("I")
O = TypeVar("O")
T = TypeVar("T")

# Key under which make_handler stores the Handler in the wrapped function's
# __dict__, and under which handler_from_callable looks it up again.
# str(object()) is computed once at import time and produces a text like
# "<object object at 0x...>", which is not a valid identifier and therefore
# unlikely to clash with any regular attribute of the function.
RESTATE_UNIQUE_HANDLER_SYMBOL: str = str(object())
def try_import_pydantic_base_model():
"""
Try to import PydanticBaseModel from Pydantic.
"""
try:
from pydantic import BaseModel # type: ignore # pylint: disable=import-outside-toplevel
return BaseModel
except ImportError:
class Dummy: # pylint: disable=too-few-public-methods
"""a dummy class to use when Pydantic is not available"""
return Dummy
PydanticBaseModel = try_import_pydantic_base_model()
@dataclass
class ServiceTag:
    """
    Identifies the service that a handler belongs to.

    Attributes:
        kind: the kind of service: "object", "service" or "workflow".
        name: the name of the service.
    """
    kind: Literal["object", "service", "workflow"]
    name: str
@dataclass
class TypeHint(Generic[T]):
    """
    A type annotation captured from a handler signature.

    Attributes:
        annotation: the captured annotation; None when nothing was recorded.
        is_pydantic: flag that is set when the annotation is a Pydantic
            model class.
    """
    annotation: Optional[T] = None
    is_pydantic: bool = False
@dataclass
class HandlerIO(Generic[I, O]):
    """
    Input/output configuration of a handler.

    Attributes:
        accept (str): The accept header value for the handler.
        content_type (str): The content type header value for the handler.
        input_serde (Serde[I]): Turns the raw input bytes into the handler's
            input argument.
        output_serde (Serde[O]): Turns the handler's return value into bytes.
        input_type (Optional[TypeHint[I]]): Type hint captured for the
            handler's input, if any has been recorded.
        output_type (Optional[TypeHint[O]]): Type hint captured from the
            handler's return annotation, if any has been recorded.
    """
    accept: str
    content_type: str
    input_serde: Serde[I]
    output_serde: Serde[O]
    input_type: Optional[TypeHint[I]] = None
    output_type: Optional[TypeHint[O]] = None
def is_pydantic(annotation) -> bool:
    """
    Tell whether ``annotation`` is a Pydantic model class.

    Anything for which ``issubclass`` raises TypeError (i.e. values that are
    not classes) is reported as not Pydantic. When Pydantic is not installed,
    PydanticBaseModel is a placeholder class, so ordinary types yield False.
    """
    try:
        outcome = issubclass(annotation, PydanticBaseModel)
    except TypeError:
        # issubclass() rejects a first argument that is not a class.
        outcome = False
    return outcome
def extract_io_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
    """
    Augment handler_io with additional information about the input and output types.

    This function has a special check for Pydantic models when these are provided.
    It inspects the signature of a handler to find its input and return types, and:

    * captures them as TypeHint instances, flagging any Pydantic models
      (to be used later at discovery);
    * replaces the default JSON serde (if left unchanged by the user) with a
      Pydantic serde for the corresponding model.

    The input is the last parameter, and only exists when the handler takes
    two or more parameters. A handler whose only parameter is the context has
    no input. Its input type is recorded as ``Signature.empty`` (the same
    marker inspect uses for an unannotated parameter) rather than as the
    context's own annotation.

    Args:
        handler_io: the handler's I/O configuration, updated in place.
        signature: the signature of the handler function.
    """
    params = list(signature.parameters.values())
    if len(params) >= 2:
        input_annotation = params[-1].annotation
    else:
        # Only the context (or no parameter at all): there is no input
        # argument, so the context's annotation must not become the input type.
        input_annotation = Signature.empty

    handler_io.input_type = TypeHint(annotation=input_annotation,
                                     is_pydantic=is_pydantic(input_annotation))
    if handler_io.input_type.is_pydantic and \
            isinstance(handler_io.input_serde, JsonSerde):  # type: ignore
        # The user kept the default JSON serde: use the Pydantic-aware one.
        handler_io.input_serde = PydanticJsonSerde(input_annotation)

    output_annotation = signature.return_annotation
    handler_io.output_type = TypeHint(annotation=output_annotation,
                                      is_pydantic=is_pydantic(output_annotation))
    if handler_io.output_type.is_pydantic and \
            isinstance(handler_io.output_serde, JsonSerde):  # type: ignore
        # The user kept the default JSON serde: use the Pydantic-aware one.
        handler_io.output_serde = PydanticJsonSerde(output_annotation)
@dataclass
class Handler(Generic[I, O]):
    """
    A handler registered for a service.

    Attributes:
        service_tag: the service this handler belongs to.
        handler_io: input/output configuration (serdes, content types and
            captured type hints).
        kind: the handler kind ("exclusive", "shared" or "workflow"), or None.
        name: the handler name.
        fn: the user's async callable; invoke_handler awaits ``fn(ctx)`` or
            ``fn(ctx, input)`` depending on ``arity``.
        arity: the number of parameters in ``fn``'s signature.
    """
    service_tag: ServiceTag
    handler_io: HandlerIO[I, O]
    kind: Optional[Literal["exclusive", "shared", "workflow"]]
    name: str
    fn: Callable[[Any, I], Awaitable[O]] | Callable[[Any], Awaitable[O]]
    arity: int
# disable too many arguments warning
# pylint: disable=R0913
def make_handler(service_tag: ServiceTag,
                 handler_io: HandlerIO[I, O],
                 name: str | None,
                 kind: Optional[Literal["exclusive", "shared", "workflow"]],
                 wrapped: Any,
                 signature: Signature) -> Handler[I, O]:
    """
    Factory function to create a handler.

    Builds a Handler for ``wrapped`` and stores it in the function's own
    ``__dict__`` under RESTATE_UNIQUE_HANDLER_SYMBOL, so that it can later be
    retrieved from the callable.

    Args:
        service_tag: the service the handler belongs to.
        handler_io: I/O configuration; updated in place with the type hints
            (and possibly Pydantic serdes) derived from ``signature``.
        name: explicit handler name; when empty or None, ``wrapped.__name__``
            is used instead.
        kind: the handler kind, or None.
        wrapped: the user function implementing the handler.
        signature: the signature of ``wrapped``.

    Returns:
        The newly created Handler.

    Raises:
        ValueError: if no handler name can be determined, or if the signature
            declares no parameters.
    """
    handler_name = name or wrapped.__name__
    if not handler_name:
        raise ValueError("Handler name must be provided")

    arity = len(signature.parameters)
    if arity == 0:
        raise ValueError("Handler must have at least one parameter")

    extract_io_type_hints(handler_io, signature)

    handler = Handler[I, O](service_tag=service_tag,
                            handler_io=handler_io,
                            kind=kind,
                            name=handler_name,
                            fn=wrapped,
                            arity=arity)
    # Attach the handler to the function object itself; handler_from_callable
    # reads it back using the same key.
    vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler
    return handler
def handler_from_callable(wrapper: Callable[[Any, I], Awaitable[O]]) -> Handler[I, O]:
    """
    Return the Handler that make_handler attached to ``wrapper``.

    Raises:
        ValueError: if ``wrapper`` does not carry a handler.
    """
    try:
        found = vars(wrapper)[RESTATE_UNIQUE_HANDLER_SYMBOL]
    except KeyError:
        raise ValueError("Handler not found")  # pylint: disable=raise-missing-from
    return found
async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) -> bytes:
    """
    Invoke the handler with the given context and input.

    When the handler declares two parameters, ``in_buffer`` is decoded with
    the handler's input serde and passed as the second argument. Otherwise
    the handler is called with the context only and ``in_buffer`` is unused.
    The handler's result is encoded with the output serde.

    Returns:
        The serialized result as ``bytes``.

    Raises:
        ValueError: if the handler declares more than two parameters.
        TerminalError: if the input bytes cannot be deserialized.
    """
    arity = handler.arity
    if arity > 2:
        raise ValueError(f"Expected num of args for handler {handler.name}: 1-2. Received: {handler.arity}")

    if arity == 2:
        try:
            in_arg = handler.handler_io.input_serde.deserialize(in_buffer)  # type: ignore
        except Exception as err:
            # Any failure while decoding the payload is surfaced as a TerminalError.
            raise TerminalError(message=f"Unable to parse an input argument. {err}") from err
        out_arg = await handler.fn(ctx, in_arg)  # type: ignore [call-arg, arg-type]
    else:
        out_arg = await handler.fn(ctx)  # type: ignore [call-arg]

    return bytes(handler.handler_io.output_serde.serialize(out_arg))  # type: ignore