-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
494 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from flama.telemetry.data_structures import * # noqa | ||
from flama.telemetry.middleware import * # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import dataclasses | ||
import datetime | ||
import logging | ||
import typing as t | ||
from http.cookies import SimpleCookie | ||
|
||
from flama import Flama, types | ||
from flama.authentication.types import AccessToken, RefreshToken | ||
from flama.exceptions import HTTPException | ||
from flama.http import Request as HTTPRequest | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
__all__ = ["Endpoint", "Authentication", "Request", "Response", "Error", "TelemetryData"] | ||
|
||
|
||
@dataclasses.dataclass | ||
class Endpoint: | ||
path: str | ||
name: t.Optional[str] | ||
tags: dict[str, t.Any] | ||
|
||
@classmethod | ||
async def from_scope(cls, *, scope: types.Scope, receive: types.Receive, send: types.Send) -> "Endpoint": | ||
app: Flama = scope["app"] | ||
|
||
route, _ = app.router.resolve_route(scope) | ||
|
||
return cls(path=str(route.path), name=route.name, tags=route.tags) | ||
|
||
def to_dict(self) -> dict[str, t.Any]: | ||
return {"path": self.path, "name": self.name, "tags": self.tags} | ||
|
||
|
||
@dataclasses.dataclass | ||
class Authentication: | ||
access: t.Optional[AccessToken] | ||
refresh: t.Optional[RefreshToken] | ||
|
||
@classmethod | ||
async def from_scope(cls, *, scope: types.Scope, receive: types.Receive, send: types.Send) -> "Authentication": | ||
app: Flama = scope["app"] | ||
context = {"scope": scope, "request": HTTPRequest(scope, receive=receive)} | ||
|
||
try: | ||
access = await app.injector.resolve(AccessToken).value(context) | ||
except Exception: | ||
access = None | ||
|
||
try: | ||
refresh = await app.injector.resolve(RefreshToken).value(context) | ||
except Exception: | ||
refresh = None | ||
|
||
return cls(access=access, refresh=refresh) | ||
|
||
def to_dict(self) -> dict[str, t.Any]: | ||
return {"access": self.access, "refresh": self.refresh} | ||
|
||
|
||
@dataclasses.dataclass | ||
class Request: | ||
headers: dict[str, t.Any] | ||
cookies: dict[str, t.Any] | ||
query_parameters: dict[str, t.Any] | ||
path_parameters: dict[str, t.Any] | ||
body: bytes = b"" | ||
timestamp: datetime.datetime = dataclasses.field( | ||
init=False, default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) | ||
) | ||
|
||
@classmethod | ||
async def from_scope(cls, *, scope: types.Scope, receive: types.Receive, send: types.Send) -> "Request": | ||
app: Flama = scope["app"] | ||
context = {"scope": scope, "request": HTTPRequest(scope, receive=receive)} | ||
|
||
headers = dict(await app.injector.resolve(types.Headers).value(context)) | ||
cookies = dict(await app.injector.resolve(types.Cookies).value(context)) | ||
query = dict(await app.injector.resolve(types.QueryParams).value(context)) | ||
path = dict(await app.injector.resolve(types.PathParams).value(context)) | ||
|
||
return cls(headers=headers, cookies=cookies, query_parameters=query, path_parameters=path) | ||
|
||
def to_dict(self) -> dict[str, t.Any]: | ||
return { | ||
"timestamp": self.timestamp.isoformat(), | ||
"headers": self.headers, | ||
"cookies": self.headers, | ||
"query_parameters": self.query_parameters, | ||
"path_parameters": self.path_parameters, | ||
"body": self.body, | ||
} | ||
|
||
|
||
@dataclasses.dataclass | ||
class Response: | ||
headers: t.Optional[dict[str, t.Any]] | ||
cookies: t.Optional[dict[str, t.Any]] = dataclasses.field(init=False) | ||
body: bytes = b"" | ||
status_code: t.Optional[int] = None | ||
timestamp: datetime.datetime = dataclasses.field( | ||
init=False, default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) | ||
) | ||
|
||
def __post_init__(self): | ||
if self.headers: | ||
cookie = SimpleCookie() | ||
cookie.load(self.headers.get("cookie", "")) | ||
self.cookies = { | ||
str(name): {**{str(k): str(v) for k, v in morsel.items()}, "value": morsel.value} | ||
for name, morsel in cookie.items() | ||
} | ||
|
||
def to_dict(self) -> dict[str, t.Any]: | ||
return { | ||
"timestamp": self.timestamp.isoformat(), | ||
"headers": self.headers, | ||
"cookies": self.headers, | ||
"body": self.body, | ||
"status_code": self.status_code, | ||
} | ||
|
||
|
||
@dataclasses.dataclass | ||
class Error: | ||
detail: str | ||
status_code: t.Optional[int] = None | ||
timestamp: datetime.datetime = dataclasses.field( | ||
init=False, default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) | ||
) | ||
|
||
@classmethod | ||
async def from_exception(cls, *, exception: Exception) -> "Error": | ||
if isinstance(exception, HTTPException): | ||
return cls(status_code=exception.status_code, detail=str(exception.detail)) | ||
|
||
return cls(detail=str(exception)) | ||
|
||
def to_dict(self) -> dict[str, t.Any]: | ||
return {"timestamp": self.timestamp.isoformat(), "detail": self.detail, "status_code": self.status_code} | ||
|
||
|
||
@dataclasses.dataclass | ||
class TelemetryData: | ||
type: t.Literal["http", "websocket"] | ||
endpoint: Endpoint | ||
authentication: Authentication | ||
request: Request | ||
response: t.Optional[Response] = None | ||
error: t.Optional[Error] = None | ||
extra: dict[t.Any, t.Any] = dataclasses.field(default_factory=dict) | ||
|
||
@classmethod | ||
async def from_scope(cls, *, scope: types.Scope, receive: types.Receive, send: types.Send) -> "TelemetryData": | ||
return cls( | ||
type=scope["type"], | ||
endpoint=await Endpoint.from_scope(scope=scope, receive=receive, send=send), | ||
authentication=await Authentication.from_scope(scope=scope, receive=receive, send=send), | ||
request=await Request.from_scope(scope=scope, receive=receive, send=send), | ||
) | ||
|
||
def to_dict(self) -> dict[str, t.Any]: | ||
return { | ||
"type": self.type, | ||
"endpoint": self.endpoint.to_dict(), | ||
"authentication": self.authentication.to_dict(), | ||
"request": self.request.to_dict(), | ||
"response": self.response.to_dict() if self.response else None, | ||
"error": self.error.to_dict() if self.error else None, | ||
"extra": self.extra, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import abc | ||
import logging | ||
import typing as t | ||
|
||
from flama import Flama, concurrency, types | ||
from flama.telemetry.data_structures import Error, Response, TelemetryData | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
__all__ = ["TelemetryMiddleware"] | ||
|
||
|
||
PROJECT = "vortico-core" | ||
SERVICE = "elektrococo" | ||
TOPIC_ID = "telemetry-bus" | ||
|
||
HookFunction = t.Callable[[TelemetryData], t.Union[None, t.Awaitable[None]]] | ||
|
||
|
||
class Wrapper(abc.ABC): | ||
def __init__(self, app: Flama, data: TelemetryData) -> None: | ||
self.app = app | ||
self.data = data | ||
|
||
@classmethod | ||
def build(cls, type: t.Literal["http", "websocket"], app: Flama, data: TelemetryData) -> "Wrapper": | ||
if type == "websocket": | ||
return WebSocketWrapper(app, data) | ||
|
||
return HTTPWrapper(app, data) | ||
|
||
async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None: | ||
self._scope = scope | ||
self._receive = receive | ||
self._send = send | ||
self._response_body = b"" | ||
self._response_headers = None | ||
self._response_status_code = None | ||
|
||
try: | ||
await self.app(self._scope, self.receive, self.send) | ||
self.data.response = Response(headers=self._response_headers, status_code=self._response_status_code) | ||
except Exception as e: | ||
self.data.error = await Error.from_exception(exception=e) | ||
raise | ||
|
||
@abc.abstractmethod | ||
async def receive(self) -> types.Message: | ||
... | ||
|
||
@abc.abstractmethod | ||
async def send(self, message: types.Message) -> None: | ||
... | ||
|
||
|
||
class HTTPWrapper(Wrapper): | ||
async def receive(self) -> types.Message: | ||
message = await self._receive() | ||
|
||
if message["type"] == "http.request": | ||
self.data.request.body += message.get("body", b"") | ||
|
||
return message | ||
|
||
async def send(self, message: types.Message) -> None: | ||
if message["type"] == "http.response.start": | ||
self._response_headers = message.get("headers", []) | ||
self._response_status_code = message.get("status") | ||
elif message["type"] == "http.response.body": | ||
self._response_body += message.get("body", b"") | ||
|
||
await self._send(message) | ||
|
||
|
||
class WebSocketWrapper(Wrapper): | ||
async def receive(self) -> types.Message: | ||
message = await self._receive() | ||
|
||
if message["type"] == "websocket.receive": | ||
self._response_body += message.get("body", b"") | ||
elif message["type"] == "websocket.disconnect": | ||
self._response_status_code = message.get("code", None) | ||
self._response_body = message.get("reason", "").encode() | ||
|
||
return message | ||
|
||
async def send(self, message: types.Message) -> None: | ||
if message["type"] == "websocket.send": | ||
self.data.request.body += message.get("bytes", message.get("text", "").encode()) | ||
elif message["type"] == "websocket.close": | ||
self._response_status_code = message.get("code") | ||
self._response_body = message.get("reason", "").encode() | ||
|
||
await self._send(message) | ||
|
||
|
||
class TelemetryDataCollector: | ||
data: TelemetryData | ||
|
||
def __init__(self, app: Flama, scope: types.Scope, receive: types.Receive, send: types.Send) -> None: | ||
self.app = app | ||
self._scope = scope | ||
self._receive = receive | ||
self._send = send | ||
|
||
@classmethod | ||
async def build( | ||
cls, app: Flama, scope: types.Scope, receive: types.Receive, send: types.Send | ||
) -> "TelemetryDataCollector": | ||
self = cls(app, scope, receive, send) | ||
self.data = await TelemetryData.from_scope(scope=scope, receive=receive, send=send) | ||
return self | ||
|
||
async def __call__(self) -> None: | ||
await Wrapper.build(self._scope["type"], self.app, self.data)( | ||
scope=self._scope, receive=self._receive, send=self._send | ||
) | ||
|
||
|
||
class TelemetryMiddleware: | ||
def __init__( | ||
self, | ||
app: types.App, | ||
log_level: int = logging.INFO, | ||
*, | ||
before: t.Optional[HookFunction] = None, | ||
after: t.Optional[HookFunction] = None, | ||
) -> None: | ||
self.app: Flama = t.cast(Flama, app) | ||
self._log_level = log_level | ||
self._before = before | ||
self._after = after | ||
|
||
async def before(self, data: TelemetryData): | ||
if self._before: | ||
await concurrency.run(self._before, data) | ||
|
||
async def after(self, data: TelemetryData): | ||
if self._after: | ||
await concurrency.run(self._after, data) | ||
|
||
async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None: | ||
if scope["type"] not in ("http", "websocket"): | ||
await self.app(scope, receive, send) | ||
return | ||
|
||
collector = await TelemetryDataCollector.build(self.app, scope, receive, send) | ||
|
||
await self.before(collector.data) | ||
|
||
try: | ||
await collector() | ||
finally: | ||
await self.after(collector.data) | ||
logger.log(self._log_level, "Telemetry: %s", str(collector.data)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import datetime | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from flama.telemetry.data_structures import Error, TelemetryData | ||
|
||
|
||
@pytest.fixture(scope="function", autouse=True) | ||
def add_routes(app): | ||
@app.route("/") | ||
def root(): | ||
return {"puppy": "Canna"} | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def asgi_scope(app, asgi_scope): | ||
asgi_scope["app"] = app | ||
return asgi_scope | ||
|
||
|
||
class TestCaseAuthentication: | ||
def test_from_scope(self, asgi_scope, asgi_receive, asgi_send): | ||
... | ||
|
||
|
||
class TestCaseEndpoint: | ||
def test_from_scope(self, asgi_scope, asgi_receive, asgi_send): | ||
... | ||
|
||
|
||
class TestCaseRequest: | ||
def test_from_scope(self, asgi_scope, asgi_receive, asgi_send): | ||
... | ||
|
||
|
||
class TestCaseError: | ||
async def test_from_exception(self): | ||
now = datetime.datetime.now() | ||
with patch("datetime.datetime", MagicMock(now=MagicMock(return_value=now))): | ||
try: | ||
raise ValueError("Foo") | ||
except ValueError as e: | ||
error = await Error.from_exception(exception=e) | ||
|
||
assert error.to_dict() == {"detail": "Foo", "status_code": None, "timestamp": now.isoformat()} | ||
|
||
|
||
class TestCaseTelemetryData: | ||
async def test_from_scope(self, asgi_scope, asgi_receive, asgi_send): | ||
data = await TelemetryData.from_scope(scope=asgi_scope, receive=asgi_receive, send=asgi_send) | ||
|
||
assert data.to_dict() == {} |
Oops, something went wrong.