From 783ca52592e82ac81c144026896060e8072bc9cb Mon Sep 17 00:00:00 2001 From: Mostafa Khalil Date: Sat, 6 Jan 2024 23:17:13 +0400 Subject: [PATCH] feat: Add PingingPool check on session age --- google/cloud/spanner_v1/pool.py | 10 +++++++- tests/unit/test_pool.py | 42 +++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 56837bfc0b..f25414f2ad 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -385,6 +385,8 @@ class PingingPool(AbstractSessionPool): :param database_role: (Optional) user-assigned database_role for the session. """ + SESSION_MAX_AGE = 28 * 24 * 60 * 60 + def __init__( self, size=10, @@ -448,8 +450,9 @@ def get(self, timeout=None): timeout = self.default_timeout ping_after, session = self._sessions.get(block=True, timeout=timeout) + session_age = (_NOW() - session._created_at).total_seconds() - if _NOW() > ping_after: + if _NOW() > ping_after or session_age >= self.SESSION_MAX_AGE: # Using session.exists() guarantees the returned session exists. # session.ping() uses a cached result in the backend which could # result in a recently deleted session being returned. @@ -481,6 +484,11 @@ def clear(self): else: session.delete() + def _new_session(self): + session = super()._new_session() + session._created_at = _NOW() + return session + def ping(self): """Refresh maybe-expired sessions in the pool. diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 23ed3e7251..f4085121d7 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -531,6 +531,48 @@ def test_get_hit_w_ping_expired(self): self.assertTrue(SESSIONS[0]._exists_checked) self.assertFalse(pool._sessions.full()) + def test_get_hit_w_created(self): + import datetime + + pool = self._make_one(size=4) + database = _Database("name") + SESSIONS = [_Session(database)] * 4 + database._sessions.extend(SESSIONS) + pool.bind(database) + + session_max_age = 28 * 24 * 60 * 60 + SESSIONS[0]._created_at = datetime.datetime.utcnow() - datetime.timedelta( + seconds=session_max_age + 10 + ) + + session = pool.get() + + self.assertIs(session, SESSIONS[0]) + self.assertTrue(session._exists_checked) + self.assertFalse(pool._sessions.full()) + + def test_get_hit_w_created_expired(self): + import datetime + + pool = self._make_one(size=4) + database = _Database("name") + SESSIONS = [_Session(database)] * 5 + database._sessions.extend(SESSIONS) + pool.bind(database) + + session_max_age = 28 * 24 * 60 * 60 + SESSIONS[0]._created_at = datetime.datetime.utcnow() - datetime.timedelta( + seconds=session_max_age + ) + SESSIONS[0]._exists = False + + session = pool.get() + + self.assertIs(session, SESSIONS[4]) + session.create.assert_called() + self.assertTrue(SESSIONS[0]._exists_checked) + self.assertFalse(pool._sessions.full()) + def test_get_empty_default_timeout(self): import queue