diff --git a/sanic_session/__init__.py b/sanic_session/__init__.py index 1341bea..a99e916 100644 --- a/sanic_session/__init__.py +++ b/sanic_session/__init__.py @@ -1,7 +1,13 @@ +from typing import Optional, Union + +from sanic import Sanic + from .aioredis import AIORedisSessionInterface +from .base import BaseSessionInterface from .memcache import MemcacheSessionInterface from .memory import InMemorySessionInterface from .mongodb import MongoDBSessionInterface +from .policy import RenewalPolicy from .redis import RedisSessionInterface __all__ = ( @@ -15,13 +21,27 @@ class Session: - def __init__(self, app=None, interface=None): - self.interface = None + def __init__( + self, + app: Optional[Sanic] = None, + interface: Optional[BaseSessionInterface] = None, + renew_cookie: Union[str, RenewalPolicy] = RenewalPolicy.NEVER, + ): + self.interface = interface + self.renew_cookie = ( + renew_cookie + if isinstance(renew_cookie, RenewalPolicy) + else RenewalPolicy[renew_cookie.upper()] + ) if app: self.init_app(app, interface) - def init_app(self, app, interface): + def init_app( + self, app: Sanic, interface: Optional[BaseSessionInterface] = None + ): self.interface = interface or InMemorySessionInterface() + self.interface.renew_cookie = self.renew_cookie + if not hasattr(app.ctx, "extensions"): app.ctx.extensions = {} diff --git a/sanic_session/base.py b/sanic_session/base.py index 274960e..43d8c2a 100644 --- a/sanic_session/base.py +++ b/sanic_session/base.py @@ -5,6 +5,7 @@ import ujson +from sanic_session.policy import RenewalPolicy from sanic_session.utils import CallbackDict @@ -47,6 +48,7 @@ def __init__( self.samesite = samesite self.session_name = session_name self.secure = secure + self.renew_cookie: RenewalPolicy = RenewalPolicy.NEVER def _delete_cookie(self, request, response): req = get_request_container(request) @@ -65,6 +67,12 @@ def _calculate_expires(expiry): def _set_cookie_props(self, request, response): req = get_request_container(request) + if ( + self.renew_cookie is not RenewalPolicy.ALWAYS + and request.cookies.get(self.cookie_name) + == req[self.session_name].sid + ): + return # session_id same with client, do nothing response.cookies[self.cookie_name] = req[self.session_name].sid response.cookies[self.cookie_name]["httponly"] = self.httponly