diff --git a/sanic_session/__init__.py b/sanic_session/__init__.py index 1341bea..263d745 100644 --- a/sanic_session/__init__.py +++ b/sanic_session/__init__.py @@ -1,4 +1,5 @@ from .aioredis import AIORedisSessionInterface +from .extension import Session from .memcache import MemcacheSessionInterface from .memory import InMemorySessionInterface from .mongodb import MongoDBSessionInterface @@ -12,35 +13,3 @@ "AIORedisSessionInterface", "Session", ) - - -class Session: - def __init__(self, app=None, interface=None): - self.interface = None - if app: - self.init_app(app, interface) - - def init_app(self, app, interface): - self.interface = interface or InMemorySessionInterface() - if not hasattr(app.ctx, "extensions"): - app.ctx.extensions = {} - - app.ctx.extensions[ - self.interface.session_name - ] = self # session_name defaults to 'session' - - # @app.middleware('request') - async def add_session_to_request(request): - """Before each request initialize a session - using the client's request.""" - await self.interface.open(request) - - # @app.middleware('response') - async def save_session(request, response): - """After each request save the session, pass - the response to set client cookies. - """ - await self.interface.save(request, response) - - app.request_middleware.appendleft(add_session_to_request) - app.response_middleware.append(save_session) diff --git a/sanic_session/extension.py b/sanic_session/extension.py new file mode 100644 index 0000000..832ac21 --- /dev/null +++ b/sanic_session/extension.py @@ -0,0 +1,49 @@ +from .memory import InMemorySessionInterface + +try: + from sanic_ext import Extension + + SANIC_EXTENSIONS = True +except ImportError: + Extension = type("Extension", (), {}) # type: ignore + SANIC_EXTENSIONS = False + + +class Session(Extension): + name = "session" + + def __init__(self, app=None, interface=None): + self.interface = interface + if app: + self.init_app(app, interface) + + def init_app(self, app, interface): + self.interface = interface or InMemorySessionInterface() + if not hasattr(app.ctx, "extensions"): + app.ctx.extensions = {} + + app.ctx.extensions[ + self.interface.session_name + ] = self # session_name defaults to 'session' + + async def add_session_to_request(request): + """Before each request initialize a session + using the client's request.""" + await self.interface.open(request) + + async def save_session(request, response): + """After each request save the session, pass + the response to set client cookies. + """ + await self.interface.save(request, response) + + app.request_middleware.appendleft(add_session_to_request) + app.response_middleware.append(save_session) + + def startup(self, _) -> None: + if not SANIC_EXTENSIONS: + raise RuntimeError("Sanic Extensions is not installed") + self.init_app(self.app, self.interface) + + def label(self): + return self.interface.__class__.__name__