diff --git a/setup.py b/setup.py index 66e210f..10ba9cb 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,10 @@ author_email='satosa-dev@lists.sunet.se', description='OpenID Connect Provider (OP) library in Python.', install_requires=[ - 'oic >= 0.15.0', - 'pymongo' + 'oic >= 0.15.0' + ] + extras_require=[ + 'pymongo', + 'redis' ] ) diff --git a/src/pyop/storage.py b/src/pyop/storage.py index 613903a..1040d79 100644 --- a/src/pyop/storage.py +++ b/src/pyop/storage.py @@ -1,23 +1,98 @@ # -*- coding: utf-8 -*- +from abc import ABC, abstractmethod import copy -import pymongo -from time import time +import json +from datetime import datetime +import sys +try: + import pymongo +except ImportError: + pass -class MongoWrapper(object): - def __init__(self, db_uri, db_name, collection): +try: + from redis.client import Redis +except ImportError: + pass + +class StorageBase(ABC): + _ttl = None + + @abstractmethod + def __setitem__(self, key, value): + pass + + @abstractmethod + def __getitem__(self, key): + pass + + @abstractmethod + def __delitem__(self, key): + pass + + @abstractmethod + def __contains__(self, key): + pass + + @abstractmethod + def items(self): + pass + + def pop(self, key, default=None): + try: + data = self[key] + except KeyError: + return default + del self[key] + return data + + @classmethod + def from_uri(cls, db_uri, collection, db_name=None, ttl=None): + if db_uri.startswith("mongodb"): + return MongoWrapper(db_uri, db_name, collection, ttl) + if db_uri.startswith("redis") or db_uri.startswith("unix"): + return RedisWrapper(db_uri, collection, ttl) + + return ValueError(f"Invalid DB URI: {db_uri}") + + @property + def ttl(self): + return self._ttl + + def ensure_dependency(self, dependencies): + for module in dependencies: + if not module in sys.modules: + raise ImportError( + f"'{module}' module is required but it is not available" + ) + + +class MongoWrapper(StorageBase): + def __init__(self, db_uri, db_name, collection, ttl=None): + self.ensure_dependency(["pymongo"]) self._db_uri = db_uri self._coll_name = collection self._db = MongoDB(db_uri, db_name=db_name) self._coll = self._db.get_collection(collection) self._coll.create_index('lookup_key', unique=True) + if ttl is None or (isinstance(ttl, int) and ttl >= 0): + self._ttl = ttl + else: + raise ValueError("TTL must be a non-negative integer or None") + if ttl is not None: + self._coll.create_index( + 'last_modified', + expireAfterSeconds=ttl, + name="expiry" + ) + def __setitem__(self, key, value): doc = { 'lookup_key': key, 'data': value, - 'modified_ts': time() + 'last_modified': datetime.utcnow() } self._coll.replace_one({'lookup_key': key}, doc, upsert=True) @@ -38,13 +113,53 @@ def items(self): for doc in self._coll.find(): yield (doc['lookup_key'], doc['data']) - def pop(self, key, default=None): - try: - data = self[key] - except KeyError: - return default - del self[key] - return data + +class RedisWrapper(StorageBase): + """ + Simple wrapper for a dict-like storage in Redis. + Supports JSON-serializable data types. + """ + + def __init__(self, db_uri, collection, ttl=None): + self.ensure_dependency(["redis.client"]) + self._db = Redis.from_url(db_uri, decode_responses=True) + self._collection = collection + if ttl is None or (isinstance(ttl, int) and ttl >= 0): + self._ttl = ttl + else: + raise ValueError("TTL must be a non-negative integer or None") + + def _make_key(self, key): + if not isinstance(key, str): + raise TypeError(f"Keys must be strings, {type(key).__name__} given") + + return ":".join([self._collection, key]) + + def __setitem__(self, key, value): + # Replacing the value of a key resets the ttl counter + encoded = json.dumps({ "value": value }) + self._db.set(self._make_key(key), encoded, ex=self.ttl) + + def __getitem__(self, key): + encoded = self._db.get(self._make_key(key)) + if encoded is None: + raise KeyError(key) + return json.loads(encoded).get("value") + + def __delitem__(self, key): + # Deleting a non-existent key is allowed + self._db.delete(self._make_key(key)) + + def __contains__(self, key): + return (self._db.get(self._make_key(key)) is not None) + + def items(self): + for key in self._db.keys(self._collection + "*"): + visible_key = key[len(self._collection) + 1 :] + try: + yield (visible_key, self[visible_key]) + except KeyError: + pass class MongoDB(object): @@ -56,14 +171,17 @@ def __init__(self, db_uri, db_name=None, if db_uri is None: raise ValueError('db_uri not supplied') - self._db_uri = db_uri - self._database_name = db_name self._sanitized_uri = None self._parsed_uri = pymongo.uri_parser.parse_uri(db_uri) if self._parsed_uri.get('database') is None: + if db_name is None: + raise ValueError( + "Database name must be provided either in the URI or as an argument" + ) self._parsed_uri['database'] = db_name + self._database_name = self._parsed_uri['database'] if 'replicaSet' in kwargs and kwargs['replicaSet'] is None: del kwargs['replicaSet'] diff --git a/tests/pyop/conftest.py b/tests/pyop/conftest.py deleted file mode 100644 index 15f46b8..0000000 --- a/tests/pyop/conftest.py +++ /dev/null @@ -1,84 +0,0 @@ -import atexit -import random -import shutil -import subprocess -import tempfile -import time - -import pymongo -import pytest - - -class MongoTemporaryInstance(object): - """Singleton to manage a temporary MongoDB instance - - Use this for testing purpose only. The instance is automatically destroyed - at the end of the program. - - """ - _instance = None - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls() - atexit.register(cls._instance.shutdown) - return cls._instance - - def __init__(self): - self._tmpdir = tempfile.mkdtemp() - self._port = 27017 - self._process = subprocess.Popen(['mongod', '--bind_ip', 'localhost', - '--port', str(self._port), - '--dbpath', self._tmpdir, - '--nojournal', '--nohttpinterface', - '--noauth', '--smallfiles', - '--syncdelay', '0', - '--nssize', '1', ], - stdout=open('/tmp/mongo-temp.log', 'wb'), - stderr=subprocess.STDOUT) - - # XXX: wait for the instance to be ready - # Mongo is ready in a glance, we just wait to be able to open a - # Connection. - for i in range(10): - time.sleep(0.2) - try: - self._conn = pymongo.MongoClient('localhost', self._port) - except pymongo.errors.ConnectionFailure: - continue - else: - break - else: - self.shutdown() - assert False, 'Cannot connect to the mongodb test instance' - - @property - def conn(self): - return self._conn - - @property - def port(self): - return self._port - - def shutdown(self): - if self._process: - self._process.terminate() - self._process.wait() - self._process = None - shutil.rmtree(self._tmpdir, ignore_errors=True) - - def get_uri(self): - """ - Convenience function to get a mongodb URI to the temporary database. - - :return: URI - """ - return 'mongodb://localhost:{port!s}'.format(port=self.port) - - -@pytest.yield_fixture -def mongodb_instance(): - tmp_db = MongoTemporaryInstance() - yield tmp_db - tmp_db.shutdown() diff --git a/tests/pyop/test_storage.py b/tests/pyop/test_storage.py index 30687e1..f36b2f6 100644 --- a/tests/pyop/test_storage.py +++ b/tests/pyop/test_storage.py @@ -2,15 +2,38 @@ import pytest -from pyop.storage import MongoWrapper +from abc import ABC, abstractmethod +from contextlib import contextmanager +from redis.client import Redis +import datetime +import fakeredis +import mongomock +import pymongo +import time +import sys + +from pyop.storage import StorageBase __author__ = 'lundberg' -class TestMongoStorage(object): - @pytest.fixture() - def db(self, mongodb_instance): - return MongoWrapper(mongodb_instance.get_uri(), 'pyop', 'test') +uri_list = ["mongodb://localhost:1234/pyop", "redis://localhost/0"] + +@pytest.fixture(autouse=True) +def mock_redis(monkeypatch): + def mockreturn(*args, **kwargs): + return fakeredis.FakeStrictRedis(decode_responses=True) + monkeypatch.setattr(Redis, "from_url", mockreturn) + +@pytest.fixture(autouse=True) +def mock_mongo(): + pymongo.MongoClient = mongomock.MongoClient + + +class TestStorage(object): + @pytest.fixture(params=uri_list) + def db(self, request): + return StorageBase.from_uri(request.param, db_name="pyop", collection='test') def test_write(self, db): db['foo'] = 'bar' @@ -41,3 +64,132 @@ def test_items(self, db): for key, item in db.items(): assert key assert item + + @pytest.mark.parametrize( + "args,kwargs", + [ + (["redis://localhost"], {"collection": "test"}), + (["redis://localhost", "test"], {}), + (["unix://localhost/0"], {"collection": "test", "ttl": 3}), + (["mongodb://localhost/pyop"], {"collection": "test", "ttl": 3}), + (["mongodb://localhost"], {"db_name": "pyop", "collection": "test"}), + (["mongodb://localhost", "test", "pyop"], {}), + (["mongodb://localhost/pyop", "test"], {}), + (["mongodb://localhost/pyop"], {"db_name": "other", "collection": "test"}), + (["redis://localhost/0"], {"db_name": "pyop", "collection": "test"}), + ], + ) + def test_from_uri(self, args, kwargs): + store = StorageBase.from_uri(*args, **kwargs) + store["test"] = "value" + assert store["test"] == "value" + + @pytest.mark.parametrize( + "error,args,kwargs", + [ + ( + TypeError, + ["redis://localhost", "ouch"], + {"db_name": 3, "collection": "test", "ttl": None}, + ), + ( + TypeError, + ["mongodb://localhost", "ouch"], + {"db_name": 3, "collection": "test", "ttl": None}, + ), + ( + TypeError, + ["mongodb://localhost", "ouch"], + {"db_name": "pyop", "collection": "test", "ttl": None}, + ), + ( + TypeError, + ["mongodb://localhost", "pyop"], + {"collection": "test", "ttl": None}, + ), + ( + TypeError, + ["mongodb://localhost"], + {"db_name": "pyop", "collection": "test", "ttl": None, "extra": True}, + ), + (TypeError, ["redis://localhost/0"], {}), + (TypeError, ["redis://localhost/0"], {"db_name": "pyop"}), + (ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}), + ], + ) + def test_from_uri_invalid_parameters(self, error, args, kwargs): + with pytest.raises(error): + StorageBase.from_uri(*args, **kwargs) + + +class StorageTTLTest(ABC): + def prepare_db(self, uri, ttl): + self.db = StorageBase.from_uri( + uri, + collection="test", + ttl=ttl, + ) + self.db["foo"] = {"bar": "baz"} + + @abstractmethod + def set_time(self, offset, monkey): + pass + + @contextmanager + def adjust_time(self, offset): + mp = pytest.MonkeyPatch() + try: + yield self.set_time(offset, mp) + finally: + mp.undo() + + def execute_ttl_test(self, uri, ttl): + self.prepare_db(uri, ttl) + assert self.db["foo"] + with self.adjust_time(offset=int(ttl / 2)): + assert self.db["foo"] + with self.adjust_time(offset=int(ttl * 2)): + with pytest.raises(KeyError): + self.db["foo"] + + @pytest.mark.parametrize("uri", uri_list) + @pytest.mark.parametrize("ttl", ["invalid", -1, 2.3, {}]) + def test_invalid_ttl(self, uri, ttl): + with pytest.raises(ValueError): + self.prepare_db(uri, ttl) + + +class TestRedisTTL(StorageTTLTest): + def set_time(self, offset, monkeypatch): + now = time.time() + def new_time(): + return now + offset + + monkeypatch.setattr(time, "time", new_time) + + def test_ttl(self): + self.execute_ttl_test("redis://localhost/0", 3600) + + def test_missing_module(self): + sys.modules.pop("redis.client") + with pytest.raises(ImportError): + self.prepare_db("redis://localhost/0", None) + from redis.client import Redis + + +class TestMongoTTL(StorageTTLTest): + def set_time(self, offset, monkeypatch): + now = datetime.datetime.utcnow() + def new_time(): + return now + datetime.timedelta(seconds=offset) + + monkeypatch.setattr(mongomock, "utcnow", new_time) + + def test_ttl(self): + self.execute_ttl_test("mongodb://localhost/pyop", 3600) + + def test_missing_module(self): + sys.modules.pop("pymongo") + with pytest.raises(ImportError): + self.prepare_db("mongodb://localhost/0", None) + import pymongo diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index 29a14fe..c92dad5 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -1,3 +1,6 @@ -pytest +pytest >= 6.2 +pip >= 19.0 responses pycryptodomex +fakeredis +mongomock