Skip to content
This repository has been archived by the owner on Feb 8, 2025. It is now read-only.

Commit

Permalink
Merge pull request #6 from AmaseCocoa/websocket-support
Browse files Browse the repository at this point in the history
WebSocket Support
  • Loading branch information
AmaseCocoa authored Jul 8, 2024
2 parents 54a65f3 + 6c4ccb3 commit aba5cf9
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 103 deletions.
7 changes: 6 additions & 1 deletion src/kasumi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from importlib.metadata import version
from importlib.metadata import PackageNotFoundError

from starlette.requests import Request

from . import responses
from .applications import Kasumi
from .gear import Gear
from .websocket import WebSocket

__version__ = version("kasumi")
try:
__version__ = version("kasumi")
except PackageNotFoundError:
__version__ = "0.0.0"
214 changes: 113 additions & 101 deletions src/kasumi/applications.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,57 @@
import inspect
from http.client import responses
from typing import Dict, Callable, Coroutine, Any

from starlette.requests import Request
from starlette.responses import PlainTextResponse

from .exceptions import AlreadyRegistedError, GearException
from .gear import Gear

from .websocket import WebSocket
class Kasumi:

def __init__(self) -> None:
self.__requests = {}
self.__err = {}

self.__lifespan = {
"startup": [],
"shutdown": [],
"lifespan": []
}
self._requests: Dict[str, Dict[str, Callable[..., Coroutine[Any, Any, Any]]]] = {}
self._err: Dict[int, Callable[..., Coroutine[Any, Any, Any]]] = {}
self._gears: Dict[str, Gear] = {}

def include_gear(self, gear: Gear):
"""
Include routes and error handlers from a Gear instance into the Kasumi app.
Parameters:
- gear: The Gear instance to include.
"""
if gear.prefix in self._gears:
raise GearException(f"Gear with prefix '{gear.prefix}' is already included.")
self._gears[gear.prefix] = gear
for route, methods in gear._requests.items():
full_route = gear.prefix + route
if full_route not in self._requests:
self._requests[full_route] = methods
else:
for method, func in methods.items():
if method in self._requests[full_route]:
raise AlreadyRegistedError(
f"The function is already registered in the method “{method}” of the route “{full_route}”."
)
self._requests[full_route][method] = func
for error_code, func in gear._err.items():
if error_code in self._err:
raise AlreadyRegistedError(
f"The function is already registered in the ErrorCode “{error_code}”."
)
self._err[error_code] = func

def __normalize_path(self, path: str) -> str:
return path.rstrip('/')

async def __call__(self, scope, receive, send):
"""
This function handles incoming HTTP requests by routing them to the appropriate handler based on
Expand All @@ -31,31 +70,47 @@ async def __call__(self, scope, receive, send):
client. The `send` function takes a single argument, which is a dictionary representing the
message to be
"""

if scope['type'] == 'http':
handler = None
request = Request(scope, receive)
if self.__requests.get(scope['path']):
req: dict = self.__requests[scope['path']]
handler = req.get(request.method)
if handler is None:
handler = 405
elif self.__requests.get(request.base_url.hostname):
if self.__requests[request.base_url.hostname].get(scope['path']):
req: dict = self.__requests[request.base_url.hostname][scope['path']]
handler = req.get(request.method)
if handler is None:
handler = 405
else:
handler = None
handler = self.__find_route_handler(scope, request.method)
if handler:
if isinstance(handler, int):
await self.__handle_err(request, scope, receive, send, 405)
await self.__handle_err(request, scope, receive, send, handler)
else:
response = await handler(request)
await response(scope, receive, send)
else:
await self.__handle_err(request, scope, receive, send, 404)
elif scope['type'] == 'lifespan':
while True:
message = await receive()
if message['type'] == 'lifespan.startup':
for lifespan in self.__lifespan["startup"]:
await lifespan()
await send({'type': 'lifespan.startup.complete'})
elif message['type'] == 'lifespan.shutdown':
for lifespan in self.__lifespan["shutdown"]:
await lifespan()
await send({'type': 'lifespan.shutdown.complete'})
return
elif scope['type'] == 'websocket':
print(self._requests[scope['path']])
handler = None
handler = self.__find_route_handler(scope, "WS")
if handler:
websocket = WebSocket(scope, receive, send)
await handler(websocket)

def __find_route_handler(self, scope, method: str) -> Callable[..., Coroutine[Any, Any, Any]]:
path = scope['path']
method = method.upper()
if path in self._requests and method in self._requests[path]:
return self._requests[path][method]
for gear in self._gears.values():
if path.startswith(gear.prefix):
return gear._find_route_handler(path, method)
return None

async def __handle_err(self, request, scope, receive, send, status_code: int=404):
if self.__err.get(status_code):
Expand All @@ -69,59 +124,68 @@ async def __handle_err(self, request, scope, receive, send, status_code: int=404
response = PlainTextResponse(resp_msg, status_code=status_code)
await response(scope, receive, send)

def route(self, route: str, method: list="GET"):
def lifespan(self, event: str):
def decorator(func):
if isinstance(func, staticmethod):
func = func.__func__
if not inspect.iscoroutinefunction(func):
raise TypeError("Routes that listen for requests must be coroutines.")
for m in method:
met = m.upper()
ev = self.__requests.get(route)
if ev is None:
self.__requests[route] = {}
ev = self.__requests.get(route)
else:
if ev.get(met) and ev.get(met) != {}:
raise AlreadyRegistedError(f'The function is already registered in the method “{met}” of the route “{route}”.')
ev[met] = func
raise TypeError("lifespan that listen for requests must be coroutines.")
if event not in ["startup", "shutdown"]:
raise TypeError("Only startup or shutdown can be set for event.")
else:
self.__lifespan[event].append(func)
return func
return decorator

def get(self, route: str):
def ws(self, route: str):
def decorator(func):
route_normalized = self.__normalize_path(route)
if isinstance(func, staticmethod):
func = func.__func__
if not inspect.iscoroutinefunction(func):
raise TypeError("Routes that listen for requests must be coroutines.")
ev = self.__requests.get(route)
if ev is None:
self.__requests[route] = {}
ev = self.__requests.get(route)
else:
if ev.get("GET") and ev.get("GET") != {}:
raise AlreadyRegistedError(f'The function is already registered in the method “GET” of the route “{route}”.')
ev["GET"] = func
if route not in self._requests or route_normalized not in self._requests:
self._requests[route] = {}
self._requests[route_normalized] = {}
if "WS" in self._requests[route] or "WS" in self._requests[route_normalized]:
raise AlreadyRegistedError(
f"The function is already registered in the method “WebSocket” of the route “{route_normalized} ({route})”."
)
self._requests[route]["WS"] = func
if route != route_normalized:
self._requests[route_normalized]["WS"] = func
return func
return decorator

def post(self, route: str):
def route(self, route: str, method: list = ["GET", "POST"]):
def decorator(func):
route_normalized = self.__normalize_path(route)
if isinstance(func, staticmethod):
func = func.__func__
func._router_method = method
if not inspect.iscoroutinefunction(func):
raise TypeError("Routes that listen for requests must be coroutines.")
ev = self.__requests.get(route)
if ev is None:
self.__requests[route] = {}
ev = self.__requests.get(route)
else:
if ev.get("POST") and ev.get("POST") != {}:
raise AlreadyRegistedError(f'The function is already registered in the method “POST” of the route “{route}”.')
ev["POST"] = func
for m in method:
met = m.upper()
if route not in self._requests or route_normalized not in self._requests:
self._requests[route] = {}
self._requests[route_normalized] = {}
if met in self._requests[route] or met in self._requests[route_normalized]:
raise AlreadyRegistedError(
f"The function is already registered in the method “{met}” of the route “{route_normalized} ({route})”."
)
self._requests[route][met] = func
if route != route_normalized:
self._requests[route_normalized][met] = func
return func
return decorator

def get(self, route: str):
return self.route(route, method=["GET"])

def post(self, route: str):
return self.route(route, method=["POST"])

def err(self, error_code: int):
def decorator(func):
if isinstance(func, staticmethod):
Expand All @@ -134,56 +198,4 @@ def decorator(func):
raise AlreadyRegistedError(f'The function is already registered in the ErrorCode “{error_code}”.')
self.__err[error_code] = func
return func
return decorator

def combine_route(self, route: dict, name: str, routeType: str="normal", host: str=None):
if host is None:
if routeType == "normal":
self.__requests[name] = route
elif routeType == "err":
self.__err[name] = route
else:
if routeType == "normal":
self.__requests[host][name] = route
elif routeType == "err":
self.__err[host][name] = route

def include_gear(self, module: Gear, host: str=None):
if host is None:
route = module._requests
for k in route.keys():
if self.__requests.get(k):
for router in route[k].keys():
if self.__requests[k].get(router):
raise GearException(f"""The Route "{k}" registered in the gear has another function registered""")
else:
self.combine_route(
route[k], k
)
del k
err = module._err
for k in err.keys():
if self.__err.get(k):
for error in err[k].keys():
if self.__requests[k].get(error):
raise GearException(f"""The Route "{k}" registered in the gear has another function registered""")
else:
self.combine_route(
err[k], k
)
else:
if self.__requests.get(host):
raise GearException(f"""Another gear is registered to the requested host "{k}".""")
else:
route = module._requests
self.__requests[host] = {}
route_host = self.__requests[host]
for k in route.keys():
if route_host.get(k):
for router in route[k].keys():
if route_host[k].get(router):
raise GearException(f"""The Route "{k}" (on {host}) registered in the gear has another function registered""")
else:
self.combine_route(
route[k], k, host=host
)
return decorator
3 changes: 3 additions & 0 deletions src/kasumi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ class AlreadyRegistedError(Exception):
pass

class GearException(Exception):
pass

class ConnectionClosed(Exception):
pass
40 changes: 40 additions & 0 deletions src/kasumi/models/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Callable, Iterable, Tuple, Optional, Union

from pydantic import dataclasses
from pydantic.dataclasses import dataclass
from yarl import URL

from .asgi import ASGI
from .connection import Client, Server

"""
@dataclass()
class WebSocket:
asgi: ASGI
url: URL
json: Callable | None
path: str
headers: Iterable[Union[bytes, str, Tuple[bytes, bytes]]]
client: Client
server: Server
subprotocols: Iterable[str]
state: Optional[dict[str]]
raw_path: bytes | None = dataclasses.Field(
default=None)
http_version: str = dataclasses.Field(
default="2.0")
text: str | bytes | None = dataclasses.Field(
default=None)
query_string: bytes = dataclasses.Field(
default=b'')
root_path: str = dataclasses.Field(
default='')
"""

@dataclass
class WSMessage:
text: str | bytes | None = dataclasses.Field(
default=None)
json: dict | None = dataclasses.Field(
default=None)
Loading

0 comments on commit aba5cf9

Please sign in to comment.