Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#2685] implement a thread-safe switch db context manager #2686

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
@@ -166,7 +166,7 @@ def __new__(mcs, name, bases, attrs):
) = mcs._import_classes()

if issubclass(new_class, Document):
new_class._collection = None
new_class._collections = {}

# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
40 changes: 40 additions & 0 deletions mongoengine/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import warnings
from threading import local

from pymongo import MongoClient, ReadPreference, uri_parser
from pymongo.database import _check_name

from mongoengine.errors import DatabaseAliasError
from mongoengine.pymongo_support import PYMONGO_VERSION

__all__ = [
@@ -15,6 +17,9 @@
"get_connection",
"get_db",
"register_connection",
"set_local_db_alias",
"del_local_db_alias",
"get_local_db_alias"
]


@@ -26,6 +31,7 @@
_connection_settings = {}
_connections = {}
_dbs = {}
_local = local()

READ_PREFERENCE = ReadPreference.PRIMARY

@@ -372,7 +378,41 @@ def _clean_settings(settings_dict):
return _connections[db_alias]


def __local_db_alias():
if getattr(_local, "db_alias", None) is None:
_local.db_alias = {}
return _local.db_alias


def set_local_db_alias(local_alias, alias=DEFAULT_CONNECTION_NAME):
if not alias or not local_alias:
raise DatabaseAliasError(f"db alias and local_alias cannot be empty")

if alias not in __local_db_alias():
__local_db_alias()[alias] = []

__local_db_alias()[alias].append(local_alias)


def del_local_db_alias(alias):
if not alias:
raise DatabaseAliasError(f"db alias cannot be empty")

if alias not in __local_db_alias() or not __local_db_alias()[alias]:
raise DatabaseAliasError(f"local db alias not set: {alias}")

__local_db_alias()[alias].pop()


def get_local_db_alias(alias):
if alias in __local_db_alias() and __local_db_alias()[alias]:
alias = __local_db_alias()[alias][-1]
return alias


def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
alias = get_local_db_alias(alias)

if reconnect:
disconnect(alias)

45 changes: 40 additions & 5 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@
from pymongo.write_concern import WriteConcern

from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, set_local_db_alias, del_local_db_alias
from mongoengine.pymongo_support import count_documents

__all__ = (
"switch_db_local",
"switch_db",
"switch_collection",
"no_dereference",
@@ -18,6 +19,36 @@
)


class switch_db_local:
"""switch_db_local alias context manager.

Switches a db alias in a thread-safe way.

Example ::
register_connection('testdb-1', 'mongoenginetest1')
register_connection('testdb-2', 'mongoenginetest2')

class Group(Document):
name = StringField()

# The following two calls to save() could be run concurrently
with switch_db_local('testdb-1'):
Group(name='test').save()
with switch_db_local('testdb-2'):
Group(name='test').save()
"""

def __init__(self, local_alias, alias=DEFAULT_CONNECTION_NAME):
self.local_alias = local_alias
self.alias = alias

def __enter__(self):
set_local_db_alias(self.local_alias, self.alias)

def __exit__(self, t, value, traceback):
del_local_db_alias(self.alias)


class switch_db:
"""switch_db alias context manager.

@@ -50,18 +81,22 @@ def __init__(self, cls, db_alias):
def __enter__(self):
"""Change the db_alias and clear the cached collection."""
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
self.cls._set_collection(None)
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the db_alias and collection."""
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
self.cls._set_collection(self.collection)


class switch_collection:
"""switch_collection alias context manager.

Warning ::

### This is NOT completely thread-safe ###

Example ::

class Group(Document):
@@ -92,12 +127,12 @@ def _get_collection_name(cls):
return self.collection_name

self.cls._get_collection_name = _get_collection_name
self.cls._collection = None
self.cls._set_collection(None)
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the collection."""
self.cls._collection = self.ori_collection
self.cls._set_collection(self.ori_collection)
self.cls._get_collection_name = self.ori_get_collection_name


29 changes: 19 additions & 10 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
get_document,
)
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, get_local_db_alias
from mongoengine.context_managers import (
set_write_concern,
switch_collection,
@@ -196,15 +196,23 @@ def __hash__(self):

return hash(self.pk)

@classmethod
def _get_local_db_alias(cls):
return get_local_db_alias(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))

@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
return get_db(cls._get_local_db_alias())

@classmethod
def _disconnect(cls):
"""Detach the Document class from the (cached) database collection"""
cls._collection = None
"""Detach the Document class from all (cached) database collections"""
cls._collections = {}

@classmethod
def _set_collection(cls, collection):
cls._collections[cls._get_local_db_alias()] = collection

@classmethod
def _get_collection(cls):
@@ -216,14 +224,15 @@ def _get_collection(cls):
2. Creates indexes defined in this document's :attr:`meta` dictionary.
This happens only if `auto_create_index` is True.
"""
if not hasattr(cls, "_collection") or cls._collection is None:
local_db_alias = cls._get_local_db_alias()
if local_db_alias not in cls._collections:
# Get the collection, either capped or regular.
if cls._meta.get("max_size") or cls._meta.get("max_documents"):
cls._collection = cls._get_capped_collection()
cls._collections[local_db_alias] = cls._get_capped_collection()
else:
db = cls._get_db()
collection_name = cls._get_collection_name()
cls._collection = db[collection_name]
cls._collections[local_db_alias] = db[collection_name]

# Ensure indexes on the collection unless auto_create_index was
# set to False.
@@ -232,7 +241,7 @@ def _get_collection(cls):
if cls._meta.get("auto_create_index", True) and db.client.is_primary:
cls.ensure_indexes()

return cls._collection
return cls._collections[local_db_alias]

@classmethod
def _get_capped_collection(cls):
@@ -260,7 +269,7 @@ def _get_capped_collection(cls):
if options.get("max") != max_documents or options.get("size") != max_size:
raise InvalidCollectionError(
'Cannot create collection "{}" as a capped '
"collection as it already exists".format(cls._collection)
"collection as it already exists".format(collection_name)
)

return collection
@@ -837,7 +846,7 @@ def drop_collection(cls):
raise OperationError(
"Document %s has no collection defined (is it abstract ?)" % cls
)
cls._collection = None
cls._set_collection(None)
db = cls._get_db()
db.drop_collection(coll_name)

5 changes: 5 additions & 0 deletions mongoengine/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict

__all__ = (
"DatabaseAliasError",
"NotRegistered",
"InvalidDocumentError",
"LookUpError",
@@ -21,6 +22,10 @@ class MongoEngineException(Exception):
pass


class DatabaseAliasError(MongoEngineException):
pass


class NotRegistered(MongoEngineException):
pass

35 changes: 19 additions & 16 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
@@ -54,9 +54,9 @@ class BaseQuerySet:
__dereference = False
_auto_dereference = True

def __init__(self, document, collection):
def __init__(self, document, db_alias=None):
self._document = document
self._collection_obj = collection
self._db_alias = db_alias
self._mongo_query = None
self._query_obj = Q()
self._cls_query = {}
@@ -74,6 +74,8 @@ def __init__(self, document, collection):
self._as_pymongo = False
self._search_text = None

self.__init_using_collection()

# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get("allow_inheritance") is True:
@@ -100,6 +102,12 @@ def __init__(self, document, collection):
# it anytime we change _limit. Inspired by how it is done in pymongo.Cursor
self._empty = False

def __init_using_collection(self):
self._using_collection = None
if self._db_alias is not None:
with switch_db(self._document, self._db_alias) as cls:
self._using_collection = cls._get_collection()

def __call__(self, q_obj=None, **query):
"""Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query.
@@ -137,9 +145,6 @@ def __getstate__(self):

obj_dict = self.__dict__.copy()

# don't picke collection, instead pickle collection params
obj_dict.pop("_collection_obj")

# don't pickle cursor
obj_dict["_cursor_obj"] = None

@@ -152,11 +157,11 @@ def __setstate__(self, obj_dict):
See https://github.com/MongoEngine/mongoengine/issues/442
"""

obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection()

# update attributes
self.__dict__.update(obj_dict)

self.__init_using_collection()

# forse load cursor
# self._cursor

@@ -494,7 +499,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
if rule == CASCADE:
cascade_refs = set() if cascade_refs is None else cascade_refs
# Handle recursive reference
if doc._collection == document_cls._collection:
if doc._collection == document_cls._get_collection():
for ref in queryset:
cascade_refs.add(ref.id)
refs = document_cls.objects(
@@ -777,14 +782,11 @@ def using(self, alias):
:param alias: The database alias
"""

with switch_db(self._document, alias) as cls:
collection = cls._get_collection()

return self._clone_into(self.__class__(self._document, collection))
return self._clone_into(self.__class__(self._document, alias))

def clone(self):
"""Create a copy of the current queryset."""
return self._clone_into(self.__class__(self._document, self._collection_obj))
return self._clone_into(self.__class__(self._document, self._db_alias))

def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
@@ -1531,7 +1533,7 @@ def sum(self, field):
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {"$unwind": "$" + field})

result = tuple(self._document._get_collection().aggregate(pipeline))
result = tuple(self._collection.aggregate(pipeline))

if result:
return result[0]["total"]
@@ -1558,7 +1560,7 @@ def average(self, field):
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {"$unwind": "$" + field})

result = tuple(self._document._get_collection().aggregate(pipeline))
result = tuple(self._collection.aggregate(pipeline))
if result:
return result[0]["total"]
return 0
@@ -1620,7 +1622,8 @@ def _collection(self):
"""Property that returns the collection object. This allows us to
perform operations only if the collection is accessed.
"""
return self._collection_obj
return self._document._get_collection() \
if self._using_collection is None else self._using_collection

@property
def _cursor_args(self):
2 changes: 1 addition & 1 deletion mongoengine/queryset/manager.py
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ def __get__(self, instance, owner):

# owner is the document that contains the QuerySetManager
queryset_class = owner._meta.get("queryset_class", self.default)
queryset = queryset_class(owner, owner._get_collection())
queryset = queryset_class(owner)
if self.get_queryset:
arg_count = self.get_queryset.__code__.co_argcount
if arg_count == 1: