Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions sanic_session/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Optional, Union

from sanic import Sanic

from .aioredis import AIORedisSessionInterface
from .base import BaseSessionInterface
from .memcache import MemcacheSessionInterface
from .memory import InMemorySessionInterface
from .mongodb import MongoDBSessionInterface
from .policy import RenewalPolicy
from .redis import RedisSessionInterface

__all__ = (
Expand All @@ -15,13 +21,27 @@


class Session:
def __init__(self, app=None, interface=None):
self.interface = None
def __init__(
self,
app: Optional[Sanic] = None,
interface: Optional[BaseSessionInterface] = None,
renew_cookie: Union[str, RenewalPolicy] = RenewalPolicy.NEVER,
):
self.interface = interface
self.renew_cookie = (
renew_cookie
if isinstance(renew_cookie, RenewalPolicy)
else RenewalPolicy[renew_cookie.upper()]
)
if app:
self.init_app(app, interface)

def init_app(self, app, interface):
def init_app(
self, app: Sanic, interface: Optional[BaseSessionInterface] = None
):
self.interface = interface or InMemorySessionInterface()
self.interface.renew_cookie = self.renew_cookie

if not hasattr(app.ctx, "extensions"):
app.ctx.extensions = {}

Expand Down
8 changes: 8 additions & 0 deletions sanic_session/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import ujson

from sanic_session.policy import RenewalPolicy
from sanic_session.utils import CallbackDict


Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(
self.samesite = samesite
self.session_name = session_name
self.secure = secure
self.renew_cookie: RenewalPolicy = RenewalPolicy.NEVER

def _delete_cookie(self, request, response):
req = get_request_container(request)
Expand All @@ -65,6 +67,12 @@ def _calculate_expires(expiry):

def _set_cookie_props(self, request, response):
req = get_request_container(request)
if (
self.renew_cookie is not RenewalPolicy.ALWAYS
and request.cookies.get(self.cookie_name)
== req[self.session_name].sid
):
return # session_id same with client, do nothing
response.cookies[self.cookie_name] = req[self.session_name].sid
response.cookies[self.cookie_name]["httponly"] = self.httponly

Expand Down