diff --git a/src/kasumi/__init__.py b/src/kasumi/__init__.py index a3a7281..f4d3944 100644 --- a/src/kasumi/__init__.py +++ b/src/kasumi/__init__.py @@ -1,9 +1,14 @@ from importlib.metadata import version +from importlib.metadata import PackageNotFoundError from starlette.requests import Request from . import responses from .applications import Kasumi from .gear import Gear +from .websocket import WebSocket -__version__ = version("kasumi") +try: + __version__ = version("kasumi") +except PackageNotFoundError: + __version__ = "0.0.0" \ No newline at end of file diff --git a/src/kasumi/applications.py b/src/kasumi/applications.py index c66210c..6df4882 100644 --- a/src/kasumi/applications.py +++ b/src/kasumi/applications.py @@ -1,18 +1,57 @@ import inspect from http.client import responses +from typing import Dict, Callable, Coroutine, Any from starlette.requests import Request from starlette.responses import PlainTextResponse from .exceptions import AlreadyRegistedError, GearException from .gear import Gear - +from .websocket import WebSocket class Kasumi: def __init__(self) -> None: - self.__requests = {} self.__err = {} - + self.__lifespan = { + "startup": [], + "shutdown": [], + "lifespan": [] + } + self._requests: Dict[str, Dict[str, Callable[..., Coroutine[Any, Any, Any]]]] = {} + self._err: Dict[int, Callable[..., Coroutine[Any, Any, Any]]] = {} + self._gears: Dict[str, Gear] = {} + + def include_gear(self, gear: Gear): + """ + Include routes and error handlers from a Gear instance into the Kasumi app. + + Parameters: + - gear: The Gear instance to include. + """ + if gear.prefix in self._gears: + raise GearException(f"Gear with prefix '{gear.prefix}' is already included.") + self._gears[gear.prefix] = gear + for route, methods in gear._requests.items(): + full_route = gear.prefix + route + if full_route not in self._requests: + self._requests[full_route] = methods + else: + for method, func in methods.items(): + if method in self._requests[full_route]: + raise AlreadyRegistedError( + f"The function is already registered in the method “{method}” of the route “{full_route}”." + ) + self._requests[full_route][method] = func + for error_code, func in gear._err.items(): + if error_code in self._err: + raise AlreadyRegistedError( + f"The function is already registered in the ErrorCode “{error_code}”." + ) + self._err[error_code] = func + + def __normalize_path(self, path: str) -> str: + return path.rstrip('/') + async def __call__(self, scope, receive, send): """ This function handles incoming HTTP requests by routing them to the appropriate handler based on @@ -31,31 +70,47 @@ async def __call__(self, scope, receive, send): client. The `send` function takes a single argument, which is a dictionary representing the message to be """ - if scope['type'] == 'http': handler = None request = Request(scope, receive) - if self.__requests.get(scope['path']): - req: dict = self.__requests[scope['path']] - handler = req.get(request.method) - if handler is None: - handler = 405 - elif self.__requests.get(request.base_url.hostname): - if self.__requests[request.base_url.hostname].get(scope['path']): - req: dict = self.__requests[request.base_url.hostname][scope['path']] - handler = req.get(request.method) - if handler is None: - handler = 405 - else: - handler = None + handler = self.__find_route_handler(scope, request.method) if handler: if isinstance(handler, int): - await self.__handle_err(request, scope, receive, send, 405) + await self.__handle_err(request, scope, receive, send, handler) else: response = await handler(request) await response(scope, receive, send) else: await self.__handle_err(request, scope, receive, send, 404) + elif scope['type'] == 'lifespan': + while True: + message = await receive() + if message['type'] == 'lifespan.startup': + for lifespan in self.__lifespan["startup"]: + await lifespan() + await send({'type': 'lifespan.startup.complete'}) + elif message['type'] == 'lifespan.shutdown': + for lifespan in self.__lifespan["shutdown"]: + await lifespan() + await send({'type': 'lifespan.shutdown.complete'}) + return + elif scope['type'] == 'websocket': + print(self._requests[scope['path']]) + handler = None + handler = self.__find_route_handler(scope, "WS") + if handler: + websocket = WebSocket(scope, receive, send) + await handler(websocket) + + def __find_route_handler(self, scope, method: str) -> Callable[..., Coroutine[Any, Any, Any]]: + path = scope['path'] + method = method.upper() + if path in self._requests and method in self._requests[path]: + return self._requests[path][method] + for gear in self._gears.values(): + if path.startswith(gear.prefix): + return gear._find_route_handler(path, method) + return None async def __handle_err(self, request, scope, receive, send, status_code: int=404): if self.__err.get(status_code): @@ -69,59 +124,68 @@ async def __handle_err(self, request, scope, receive, send, status_code: int=404 response = PlainTextResponse(resp_msg, status_code=status_code) await response(scope, receive, send) - def route(self, route: str, method: list="GET"): + def lifespan(self, event: str): def decorator(func): if isinstance(func, staticmethod): func = func.__func__ if not inspect.iscoroutinefunction(func): - raise TypeError("Routes that listen for requests must be coroutines.") - for m in method: - met = m.upper() - ev = self.__requests.get(route) - if ev is None: - self.__requests[route] = {} - ev = self.__requests.get(route) - else: - if ev.get(met) and ev.get(met) != {}: - raise AlreadyRegistedError(f'The function is already registered in the method “{met}” of the route “{route}”.') - ev[met] = func + raise TypeError("lifespan that listen for requests must be coroutines.") + if event not in ["startup", "shutdown"]: + raise TypeError("Only startup or shutdown can be set for event.") + else: + self.__lifespan[event].append(func) return func return decorator - def get(self, route: str): + def ws(self, route: str): def decorator(func): + route_normalized = self.__normalize_path(route) if isinstance(func, staticmethod): func = func.__func__ if not inspect.iscoroutinefunction(func): raise TypeError("Routes that listen for requests must be coroutines.") - ev = self.__requests.get(route) - if ev is None: - self.__requests[route] = {} - ev = self.__requests.get(route) - else: - if ev.get("GET") and ev.get("GET") != {}: - raise AlreadyRegistedError(f'The function is already registered in the method “GET” of the route “{route}”.') - ev["GET"] = func + if route not in self._requests or route_normalized not in self._requests: + self._requests[route] = {} + self._requests[route_normalized] = {} + if "WS" in self._requests[route] or "WS" in self._requests[route_normalized]: + raise AlreadyRegistedError( + f"The function is already registered in the method “WebSocket” of the route “{route_normalized} ({route})”." + ) + self._requests[route]["WS"] = func + if route != route_normalized: + self._requests[route_normalized]["WS"] = func return func return decorator - def post(self, route: str): + def route(self, route: str, method: list = ["GET", "POST"]): def decorator(func): + route_normalized = self.__normalize_path(route) if isinstance(func, staticmethod): func = func.__func__ + func._router_method = method if not inspect.iscoroutinefunction(func): raise TypeError("Routes that listen for requests must be coroutines.") - ev = self.__requests.get(route) - if ev is None: - self.__requests[route] = {} - ev = self.__requests.get(route) - else: - if ev.get("POST") and ev.get("POST") != {}: - raise AlreadyRegistedError(f'The function is already registered in the method “POST” of the route “{route}”.') - ev["POST"] = func + for m in method: + met = m.upper() + if route not in self._requests or route_normalized not in self._requests: + self._requests[route] = {} + self._requests[route_normalized] = {} + if met in self._requests[route] or met in self._requests[route_normalized]: + raise AlreadyRegistedError( + f"The function is already registered in the method “{met}” of the route “{route_normalized} ({route})”." + ) + self._requests[route][met] = func + if route != route_normalized: + self._requests[route_normalized][met] = func return func return decorator + def get(self, route: str): + return self.route(route, method=["GET"]) + + def post(self, route: str): + return self.route(route, method=["POST"]) + def err(self, error_code: int): def decorator(func): if isinstance(func, staticmethod): @@ -134,56 +198,4 @@ def decorator(func): raise AlreadyRegistedError(f'The function is already registered in the ErrorCode “{error_code}”.') self.__err[error_code] = func return func - return decorator - - def combine_route(self, route: dict, name: str, routeType: str="normal", host: str=None): - if host is None: - if routeType == "normal": - self.__requests[name] = route - elif routeType == "err": - self.__err[name] = route - else: - if routeType == "normal": - self.__requests[host][name] = route - elif routeType == "err": - self.__err[host][name] = route - - def include_gear(self, module: Gear, host: str=None): - if host is None: - route = module._requests - for k in route.keys(): - if self.__requests.get(k): - for router in route[k].keys(): - if self.__requests[k].get(router): - raise GearException(f"""The Route "{k}" registered in the gear has another function registered""") - else: - self.combine_route( - route[k], k - ) - del k - err = module._err - for k in err.keys(): - if self.__err.get(k): - for error in err[k].keys(): - if self.__requests[k].get(error): - raise GearException(f"""The Route "{k}" registered in the gear has another function registered""") - else: - self.combine_route( - err[k], k - ) - else: - if self.__requests.get(host): - raise GearException(f"""Another gear is registered to the requested host "{k}".""") - else: - route = module._requests - self.__requests[host] = {} - route_host = self.__requests[host] - for k in route.keys(): - if route_host.get(k): - for router in route[k].keys(): - if route_host[k].get(router): - raise GearException(f"""The Route "{k}" (on {host}) registered in the gear has another function registered""") - else: - self.combine_route( - route[k], k, host=host - ) \ No newline at end of file + return decorator \ No newline at end of file diff --git a/src/kasumi/exceptions.py b/src/kasumi/exceptions.py index c779a80..70b0fb7 100644 --- a/src/kasumi/exceptions.py +++ b/src/kasumi/exceptions.py @@ -2,4 +2,7 @@ class AlreadyRegistedError(Exception): pass class GearException(Exception): + pass + +class ConnectionClosed(Exception): pass \ No newline at end of file diff --git a/src/kasumi/models/websocket.py b/src/kasumi/models/websocket.py new file mode 100644 index 0000000..b89279b --- /dev/null +++ b/src/kasumi/models/websocket.py @@ -0,0 +1,40 @@ +from typing import Callable, Iterable, Tuple, Optional, Union + +from pydantic import dataclasses +from pydantic.dataclasses import dataclass +from yarl import URL + +from .asgi import ASGI +from .connection import Client, Server + +""" +@dataclass() +class WebSocket: + asgi: ASGI + url: URL + json: Callable | None + path: str + headers: Iterable[Union[bytes, str, Tuple[bytes, bytes]]] + client: Client + server: Server + subprotocols: Iterable[str] + state: Optional[dict[str]] + + raw_path: bytes | None = dataclasses.Field( + default=None) + http_version: str = dataclasses.Field( + default="2.0") + text: str | bytes | None = dataclasses.Field( + default=None) + query_string: bytes = dataclasses.Field( + default=b'') + root_path: str = dataclasses.Field( + default='') +""" + +@dataclass +class WSMessage: + text: str | bytes | None = dataclasses.Field( + default=None) + json: dict | None = dataclasses.Field( + default=None) diff --git a/src/kasumi/websocket.py b/src/kasumi/websocket.py new file mode 100644 index 0000000..17cc8cb --- /dev/null +++ b/src/kasumi/websocket.py @@ -0,0 +1,58 @@ +import json as pyjson +from typing import Callable, Dict, Any + +from .exceptions import ConnectionClosed +from .models.websocket import WSMessage +from .models.connection import Client, Server +from .models.asgi import ASGI + +async def json(text: str) -> dict: + return pyjson.loads(text) + +class WebSocket: + def __init__(self, scope: Dict[str, Any], receive: Callable, send: Callable): + self.scope = scope + self.receive = receive + self.send = send + + async def accept(self): + await self.send({ + "type": "websocket.accept" + }) + + async def close(self, code: int = 1000): + await self.send({ + "type": "websocket.close", + "code": code + }) + + async def send_str(self, data: str): + await self.send({ + "type": "websocket.send", + "text": data + }) + + async def send_json(self, data: dict): + await self.send({ + "type": "websocket.send", + "text": pyjson.dumps(data, ensure_ascii=False) + }) + + async def recv(self) -> WSMessage | None: + message = await self.receive() + if message["type"] == "websocket.disconnect": + raise ConnectionClosed + elif message["type"] == "websocket.connect": + return None + elif message["type"] == "websocket.receive": + try: + text = pyjson.loads(message["text"]) + except pyjson.JSONDecodeError: + text = None + if message["type"] == "websocket.receive": + return WSMessage( + text=message["text"], + json=text + ) + else: + raise RuntimeError("Unexpected message type: " + message["type"]) \ No newline at end of file diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 0d6224e..12df0fd 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -54,4 +54,4 @@ async def test_websocket(uvicorn_server): response = await websocket.recv() assert response == "Message text was: Hello, WebSocket!" except ConnectionClosedError: - pass + pass \ No newline at end of file