Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
author_email='[email protected]',
description='OpenID Connect Provider (OP) library in Python.',
install_requires=[
'oic >= 0.15.0',
'pymongo'
'oic >= 0.15.0'
]
extras_require=[
'pymongo',
'redis'
]
)
146 changes: 132 additions & 14 deletions src/pyop/storage.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,98 @@
# -*- coding: utf-8 -*-

from abc import ABC, abstractmethod
import copy
import pymongo
from time import time
import json
from datetime import datetime
import sys

try:
import pymongo
except ImportError:
pass

class MongoWrapper(object):
def __init__(self, db_uri, db_name, collection):
try:
from redis.client import Redis
except ImportError:
pass

class StorageBase(ABC):
_ttl = None

@abstractmethod
def __setitem__(self, key, value):
pass

@abstractmethod
def __getitem__(self, key):
pass

@abstractmethod
def __delitem__(self, key):
pass

@abstractmethod
def __contains__(self, key):
pass

@abstractmethod
def items(self):
pass

def pop(self, key, default=None):
try:
data = self[key]
except KeyError:
return default
del self[key]
return data

@classmethod
def from_uri(cls, db_uri, collection, db_name=None, ttl=None):
if db_uri.startswith("mongodb"):
return MongoWrapper(db_uri, db_name, collection, ttl)
if db_uri.startswith("redis") or db_uri.startswith("unix"):
return RedisWrapper(db_uri, collection, ttl)

return ValueError(f"Invalid DB URI: {db_uri}")

@property
def ttl(self):
return self._ttl

def ensure_dependency(self, dependencies):
for module in dependencies:
if not module in sys.modules:
raise ImportError(
f"'{module}' module is required but it is not available"
)


class MongoWrapper(StorageBase):
def __init__(self, db_uri, db_name, collection, ttl=None):
self.ensure_dependency(["pymongo"])
self._db_uri = db_uri
self._coll_name = collection
self._db = MongoDB(db_uri, db_name=db_name)
self._coll = self._db.get_collection(collection)
self._coll.create_index('lookup_key', unique=True)

if ttl is None or (isinstance(ttl, int) and ttl >= 0):
self._ttl = ttl
else:
raise ValueError("TTL must be a non-negative integer or None")
if ttl is not None:
self._coll.create_index(
'last_modified',
expireAfterSeconds=ttl,
name="expiry"
)

def __setitem__(self, key, value):
doc = {
'lookup_key': key,
'data': value,
'modified_ts': time()
'last_modified': datetime.utcnow()
}
self._coll.replace_one({'lookup_key': key}, doc, upsert=True)

Expand All @@ -38,13 +113,53 @@ def items(self):
for doc in self._coll.find():
yield (doc['lookup_key'], doc['data'])

def pop(self, key, default=None):
try:
data = self[key]
except KeyError:
return default
del self[key]
return data

class RedisWrapper(StorageBase):
"""
Simple wrapper for a dict-like storage in Redis.
Supports JSON-serializable data types.
"""

def __init__(self, db_uri, collection, ttl=None):
self.ensure_dependency(["redis.client"])
self._db = Redis.from_url(db_uri, decode_responses=True)
self._collection = collection
if ttl is None or (isinstance(ttl, int) and ttl >= 0):
self._ttl = ttl
else:
raise ValueError("TTL must be a non-negative integer or None")

def _make_key(self, key):
if not isinstance(key, str):
raise TypeError(f"Keys must be strings, {type(key).__name__} given")

return ":".join([self._collection, key])

def __setitem__(self, key, value):
# Replacing the value of a key resets the ttl counter
encoded = json.dumps({ "value": value })
self._db.set(self._make_key(key), encoded, ex=self.ttl)

def __getitem__(self, key):
encoded = self._db.get(self._make_key(key))
if encoded is None:
raise KeyError(key)
return json.loads(encoded).get("value")

def __delitem__(self, key):
# Deleting a non-existent key is allowed
self._db.delete(self._make_key(key))

def __contains__(self, key):
return (self._db.get(self._make_key(key)) is not None)

def items(self):
for key in self._db.keys(self._collection + "*"):
visible_key = key[len(self._collection) + 1 :]
try:
yield (visible_key, self[visible_key])
except KeyError:
pass


class MongoDB(object):
Expand All @@ -56,14 +171,17 @@ def __init__(self, db_uri, db_name=None,
if db_uri is None:
raise ValueError('db_uri not supplied')

self._db_uri = db_uri
self._database_name = db_name
self._sanitized_uri = None

self._parsed_uri = pymongo.uri_parser.parse_uri(db_uri)

if self._parsed_uri.get('database') is None:
if db_name is None:
raise ValueError(
"Database name must be provided either in the URI or as an argument"
)
self._parsed_uri['database'] = db_name
self._database_name = self._parsed_uri['database']

if 'replicaSet' in kwargs and kwargs['replicaSet'] is None:
del kwargs['replicaSet']
Expand Down
84 changes: 0 additions & 84 deletions tests/pyop/conftest.py

This file was deleted.

Loading