From 88c5599fac92a8a00ab91905046a99df8741b5be Mon Sep 17 00:00:00 2001 From: Jesse McLaughlin Date: Fri, 19 Aug 2022 16:09:49 +0100 Subject: [PATCH 1/4] [#2685] implement a thread-safe switch db context manager --- mongoengine/connection.py | 28 ++++++++++++++++++++++++++++ mongoengine/context_managers.py | 33 ++++++++++++++++++++++++++++++++- mongoengine/errors.py | 5 +++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 49b665f5b..17da7135a 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,8 +1,10 @@ import warnings +from threading import local from pymongo import MongoClient, ReadPreference, uri_parser from pymongo.database import _check_name +from mongoengine.errors import DatabaseAliasError from mongoengine.pymongo_support import PYMONGO_VERSION __all__ = [ @@ -15,6 +17,8 @@ "get_connection", "get_db", "register_connection", + "set_local_db_alias", + "del_local_db_alias" ] @@ -26,6 +30,8 @@ _connection_settings = {} _connections = {} _dbs = {} +_local = local() +_local.db_alias = {} READ_PREFERENCE = ReadPreference.PRIMARY @@ -372,7 +378,29 @@ def _clean_settings(settings_dict): return _connections[db_alias] +def set_local_db_alias(local_alias, alias=DEFAULT_CONNECTION_NAME): + if not alias or not local_alias: + raise DatabaseAliasError(f"db alias and local_alias cannot be empty") + + if alias not in _local.db_alias: + _local.db_alias[alias] = local_alias + else: + raise DatabaseAliasError(f"local db alias already set: {alias}") + + +def del_local_db_alias(alias): + if not alias: + raise DatabaseAliasError(f"db alias cannot be empty") + if alias in _local.db_alias: + del _local.db_alias[alias] + else: + raise DatabaseAliasError(f"local db alias not set: {alias}") + + def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): + if alias in _local.db_alias: + alias = _local.db_alias[alias] + if reconnect: disconnect(alias) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index eb9c99622..c6423e4c8 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -4,10 +4,11 @@ from pymongo.write_concern import WriteConcern from mongoengine.common import _import_class -from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db +from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, set_local_db_alias, del_local_db_alias from mongoengine.pymongo_support import count_documents __all__ = ( + "switch_db_local", "switch_db", "switch_collection", "no_dereference", @@ -18,6 +19,36 @@ ) +class switch_db_local: + """switch_db_local alias context manager. + + Switches a db alias in a thread-safe way. + + Example :: + register_connection('testdb-1', 'mongoenginetest1') + register_connection('testdb-2', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + # The following two calls to save() could be run concurrently + with switch_db_local('testdb-1'): + Group(name='test').save() + with switch_db_local('testdb-2'): + Group(name='test').save() + """ + + def __init__(self, local_alias, alias=DEFAULT_CONNECTION_NAME): + self.local_alias = local_alias + self.alias = alias + + def __enter__(self): + set_local_db_alias(self.local_alias, self.alias) + + def __exit__(self, t, value, traceback): + del_local_db_alias(self.alias) + + class switch_db: """switch_db alias context manager. diff --git a/mongoengine/errors.py b/mongoengine/errors.py index d789b2a10..4d1a219df 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -1,6 +1,7 @@ from collections import defaultdict __all__ = ( + "DatabaseAliasError", "NotRegistered", "InvalidDocumentError", "LookUpError", @@ -21,6 +22,10 @@ class MongoEngineException(Exception): pass +class DatabaseAliasError(MongoEngineException): + pass + + class NotRegistered(MongoEngineException): pass From 60965a61c59cec0f7743117ff6e1fb6e225b8635 Mon Sep 17 00:00:00 2001 From: Jesse McLaughlin Date: Tue, 30 Aug 2022 15:54:20 +0100 Subject: [PATCH 2/4] [#2685] support a stack of db switches --- mongoengine/connection.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 17da7135a..649c44949 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -383,23 +383,24 @@ def set_local_db_alias(local_alias, alias=DEFAULT_CONNECTION_NAME): raise DatabaseAliasError(f"db alias and local_alias cannot be empty") if alias not in _local.db_alias: - _local.db_alias[alias] = local_alias - else: - raise DatabaseAliasError(f"local db alias already set: {alias}") + _local.db_alias[alias] = [] + + _local.db_alias[alias].append(local_alias) def del_local_db_alias(alias): if not alias: raise DatabaseAliasError(f"db alias cannot be empty") - if alias in _local.db_alias: - del _local.db_alias[alias] - else: + + if alias not in _local.db_alias or not _local.db_alias[alias]: raise DatabaseAliasError(f"local db alias not set: {alias}") + _local.db_alias[alias].pop() + def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): - if alias in _local.db_alias: - alias = _local.db_alias[alias] + if alias in _local.db_alias and _local.db_alias[alias]: + alias = _local.db_alias[alias][-1] if reconnect: disconnect(alias) From 514bf14d5438e0495dd8ed69dd7b4b9f0a88fdbe Mon Sep 17 00:00:00 2001 From: Jesse McLaughlin Date: Sat, 3 Sep 2022 17:08:43 +0100 Subject: [PATCH 3/4] [#2685] document classes now cache a collection db alias that can be different per thread --- mongoengine/base/metaclasses.py | 2 +- mongoengine/connection.py | 10 ++++++++-- mongoengine/context_managers.py | 12 +++++++---- mongoengine/document.py | 29 +++++++++++++++++---------- mongoengine/queryset/base.py | 35 ++++++++++++++++++--------------- mongoengine/queryset/manager.py | 2 +- 6 files changed, 56 insertions(+), 34 deletions(-) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 072b3aada..fc19136b4 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -166,7 +166,7 @@ def __new__(mcs, name, bases, attrs): ) = mcs._import_classes() if issubclass(new_class, Document): - new_class._collection = None + new_class._collections = {} # Add class to the _document_registry _document_registry[new_class._class_name] = new_class diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 649c44949..cd9586df1 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -18,7 +18,8 @@ "get_db", "register_connection", "set_local_db_alias", - "del_local_db_alias" + "del_local_db_alias", + "get_local_db_alias" ] @@ -398,9 +399,14 @@ def del_local_db_alias(alias): _local.db_alias[alias].pop() -def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): +def get_local_db_alias(alias): if alias in _local.db_alias and _local.db_alias[alias]: alias = _local.db_alias[alias][-1] + return alias + + +def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): + alias = get_local_db_alias(alias) if reconnect: disconnect(alias) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index c6423e4c8..0b4d4d20e 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -81,18 +81,22 @@ def __init__(self, cls, db_alias): def __enter__(self): """Change the db_alias and clear the cached collection.""" self.cls._meta["db_alias"] = self.db_alias - self.cls._collection = None + self.cls._set_collection(None) return self.cls def __exit__(self, t, value, traceback): """Reset the db_alias and collection.""" self.cls._meta["db_alias"] = self.ori_db_alias - self.cls._collection = self.collection + self.cls._set_collection(self.collection) class switch_collection: """switch_collection alias context manager. + Warning :: + + ### This is NOT completely thread-safe ### + Example :: class Group(Document): @@ -123,12 +127,12 @@ def _get_collection_name(cls): return self.collection_name self.cls._get_collection_name = _get_collection_name - self.cls._collection = None + self.cls._set_collection(None) return self.cls def __exit__(self, t, value, traceback): """Reset the collection.""" - self.cls._collection = self.ori_collection + self.cls._set_collection(self.ori_collection) self.cls._get_collection_name = self.ori_get_collection_name diff --git a/mongoengine/document.py b/mongoengine/document.py index e7a1938f2..b007c0d29 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -15,7 +15,7 @@ get_document, ) from mongoengine.common import _import_class -from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db +from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, get_local_db_alias from mongoengine.context_managers import ( set_write_concern, switch_collection, @@ -196,15 +196,23 @@ def __hash__(self): return hash(self.pk) + @classmethod + def _get_local_db_alias(cls): + return get_local_db_alias(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) + @classmethod def _get_db(cls): """Some Model using other db_alias""" - return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) + return get_db(cls._get_local_db_alias()) @classmethod def _disconnect(cls): - """Detach the Document class from the (cached) database collection""" - cls._collection = None + """Detach the Document class from all (cached) database collections""" + cls._collections = {} + + @classmethod + def _set_collection(cls, collection): + cls._collections[cls._get_local_db_alias()] = collection @classmethod def _get_collection(cls): @@ -216,14 +224,15 @@ def _get_collection(cls): 2. Creates indexes defined in this document's :attr:`meta` dictionary. This happens only if `auto_create_index` is True. """ - if not hasattr(cls, "_collection") or cls._collection is None: + local_db_alias = cls._get_local_db_alias() + if local_db_alias not in cls._collections: # Get the collection, either capped or regular. if cls._meta.get("max_size") or cls._meta.get("max_documents"): - cls._collection = cls._get_capped_collection() + cls._collections[local_db_alias] = cls._get_capped_collection() else: db = cls._get_db() collection_name = cls._get_collection_name() - cls._collection = db[collection_name] + cls._collections[local_db_alias] = db[collection_name] # Ensure indexes on the collection unless auto_create_index was # set to False. @@ -232,7 +241,7 @@ def _get_collection(cls): if cls._meta.get("auto_create_index", True) and db.client.is_primary: cls.ensure_indexes() - return cls._collection + return cls._collections[local_db_alias] @classmethod def _get_capped_collection(cls): @@ -260,7 +269,7 @@ def _get_capped_collection(cls): if options.get("max") != max_documents or options.get("size") != max_size: raise InvalidCollectionError( 'Cannot create collection "{}" as a capped ' - "collection as it already exists".format(cls._collection) + "collection as it already exists".format(collection_name) ) return collection @@ -837,7 +846,7 @@ def drop_collection(cls): raise OperationError( "Document %s has no collection defined (is it abstract ?)" % cls ) - cls._collection = None + cls._set_collection(None) db = cls._get_db() db.drop_collection(coll_name) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 703ae2def..6c6966d63 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -54,9 +54,9 @@ class BaseQuerySet: __dereference = False _auto_dereference = True - def __init__(self, document, collection): + def __init__(self, document, db_alias=None): self._document = document - self._collection_obj = collection + self._db_alias = db_alias self._mongo_query = None self._query_obj = Q() self._cls_query = {} @@ -74,6 +74,8 @@ def __init__(self, document, collection): self._as_pymongo = False self._search_text = None + self.__init_using_collection() + # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get("allow_inheritance") is True: @@ -100,6 +102,12 @@ def __init__(self, document, collection): # it anytime we change _limit. Inspired by how it is done in pymongo.Cursor self._empty = False + def __init_using_collection(self): + self._using_collection = None + if self._db_alias is not None: + with switch_db(self._document, self._db_alias) as cls: + self._using_collection = cls._get_collection() + def __call__(self, q_obj=None, **query): """Filter the selected documents by calling the :class:`~mongoengine.queryset.QuerySet` with a query. @@ -137,9 +145,6 @@ def __getstate__(self): obj_dict = self.__dict__.copy() - # don't picke collection, instead pickle collection params - obj_dict.pop("_collection_obj") - # don't pickle cursor obj_dict["_cursor_obj"] = None @@ -152,11 +157,11 @@ def __setstate__(self, obj_dict): See https://github.com/MongoEngine/mongoengine/issues/442 """ - obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() - # update attributes self.__dict__.update(obj_dict) + self.__init_using_collection() + # forse load cursor # self._cursor @@ -494,7 +499,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): if rule == CASCADE: cascade_refs = set() if cascade_refs is None else cascade_refs # Handle recursive reference - if doc._collection == document_cls._collection: + if doc._collection == document_cls._get_collection(): for ref in queryset: cascade_refs.add(ref.id) refs = document_cls.objects( @@ -777,14 +782,11 @@ def using(self, alias): :param alias: The database alias """ - with switch_db(self._document, alias) as cls: - collection = cls._get_collection() - - return self._clone_into(self.__class__(self._document, collection)) + return self._clone_into(self.__class__(self._document, alias)) def clone(self): """Create a copy of the current queryset.""" - return self._clone_into(self.__class__(self._document, self._collection_obj)) + return self._clone_into(self.__class__(self._document, self._db_alias)) def _clone_into(self, new_qs): """Copy all of the relevant properties of this queryset to @@ -1531,7 +1533,7 @@ def sum(self, field): if isinstance(field_instances[-1], ListField): pipeline.insert(1, {"$unwind": "$" + field}) - result = tuple(self._document._get_collection().aggregate(pipeline)) + result = tuple(self._collection.aggregate(pipeline)) if result: return result[0]["total"] @@ -1558,7 +1560,7 @@ def average(self, field): if isinstance(field_instances[-1], ListField): pipeline.insert(1, {"$unwind": "$" + field}) - result = tuple(self._document._get_collection().aggregate(pipeline)) + result = tuple(self._collection.aggregate(pipeline)) if result: return result[0]["total"] return 0 @@ -1620,7 +1622,8 @@ def _collection(self): """Property that returns the collection object. This allows us to perform operations only if the collection is accessed. """ - return self._collection_obj + return self._document._get_collection() \ + if self._using_collection is None else self._using_collection @property def _cursor_args(self): diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py index 46f137a27..9fa880f94 100644 --- a/mongoengine/queryset/manager.py +++ b/mongoengine/queryset/manager.py @@ -35,7 +35,7 @@ def __get__(self, instance, owner): # owner is the document that contains the QuerySetManager queryset_class = owner._meta.get("queryset_class", self.default) - queryset = queryset_class(owner, owner._get_collection()) + queryset = queryset_class(owner) if self.get_queryset: arg_count = self.get_queryset.__code__.co_argcount if arg_count == 1: From 86c20e3f741dbdf07d1adaf9f254799c8deb38a5 Mon Sep 17 00:00:00 2001 From: Jesse McLaughlin Date: Fri, 23 Sep 2022 15:12:32 +0100 Subject: [PATCH 4/4] [#2685] properly initialise db_alias map for each thread --- mongoengine/connection.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index cd9586df1..a385072eb 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -32,7 +32,6 @@ _connections = {} _dbs = {} _local = local() -_local.db_alias = {} READ_PREFERENCE = ReadPreference.PRIMARY @@ -379,29 +378,35 @@ def _clean_settings(settings_dict): return _connections[db_alias] +def __local_db_alias(): + if getattr(_local, "db_alias", None) is None: + _local.db_alias = {} + return _local.db_alias + + def set_local_db_alias(local_alias, alias=DEFAULT_CONNECTION_NAME): if not alias or not local_alias: raise DatabaseAliasError(f"db alias and local_alias cannot be empty") - if alias not in _local.db_alias: - _local.db_alias[alias] = [] + if alias not in __local_db_alias(): + __local_db_alias()[alias] = [] - _local.db_alias[alias].append(local_alias) + __local_db_alias()[alias].append(local_alias) def del_local_db_alias(alias): if not alias: raise DatabaseAliasError(f"db alias cannot be empty") - if alias not in _local.db_alias or not _local.db_alias[alias]: + if alias not in __local_db_alias() or not __local_db_alias()[alias]: raise DatabaseAliasError(f"local db alias not set: {alias}") - _local.db_alias[alias].pop() + __local_db_alias()[alias].pop() def get_local_db_alias(alias): - if alias in _local.db_alias and _local.db_alias[alias]: - alias = _local.db_alias[alias][-1] + if alias in __local_db_alias() and __local_db_alias()[alias]: + alias = __local_db_alias()[alias][-1] return alias