Skip to content

Commit

Permalink
Allow route decorator stacking.
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed May 6, 2024
1 parent 517e151 commit 39bf145
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions starlette_plus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, **kwargs: Unpack[RouteOptions]) -> None:
self._prefix: bool = kwargs["prefix"]
self._limits: list[RateLimitData] = kwargs.get("limits", [])
self._is_websocket: bool = kwargs.get("websocket", False)
self._view: View | None = None
self._view: View | Application | None = None
self._include_in_schema: bool = kwargs["include_in_schema"]

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
Expand All @@ -117,8 +117,8 @@ def route(
prefix: bool = True,
websocket: bool = False,
include_in_schema: bool = True,
) -> Callable[..., _Route]:
def decorator(coro: Callable[..., RouteCoro]) -> _Route:
) -> Callable[..., Callable[..., RouteCoro]]:
def decorator(coro: Callable[..., RouteCoro]) -> Callable[..., RouteCoro]:
if not asyncio.iscoroutinefunction(coro):
raise RuntimeError("Route callback must be a coroutine function.")

Expand All @@ -127,7 +127,7 @@ def decorator(coro: Callable[..., RouteCoro]) -> _Route:
raise ValueError(f"Route callback function must not be named any: {', '.join(disallowed)}")

limits: list[RateLimitData] = getattr(coro, "__limits__", [])
return _Route(
route = _Route(
path=path,
coro=coro,
methods=methods,
Expand All @@ -137,6 +137,13 @@ def decorator(coro: Callable[..., RouteCoro]) -> _Route:
include_in_schema=include_in_schema,
)

try:
coro.__routes__.append(route) # type: ignore
except AttributeError:
setattr(coro, "__routes__", [route])

return coro

return decorator


Expand Down Expand Up @@ -215,8 +222,11 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
self.__routes__ = []

name: str = cls.__name__
members: list[Any] = [
r for (_, m) in inspect.getmembers(self, predicate=lambda m: hasattr(m, "__routes__")) for r in m.__routes__
]

for _, member in inspect.getmembers(self, predicate=lambda m: isinstance(m, _Route)):
for member in members:
member._view = self
path: str = member._path

Expand Down Expand Up @@ -294,8 +304,11 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
prefix = cls.__prefix__ or name

self.__routes__ = []
members: list[Any] = [
r for (_, m) in inspect.getmembers(self, predicate=lambda m: hasattr(m, "__routes__")) for r in m.__routes__
]

for _, member in inspect.getmembers(self, predicate=lambda m: isinstance(m, _Route)):
for member in members:
member._view = self
path: str = member._path

Expand Down

0 comments on commit 39bf145

Please sign in to comment.