Skip to content

Commit 514bf14

Browse files
author
Jesse McLaughlin
committed
[#2685] document classes now cache a collection db alias that can be different per thread
1 parent 60965a6 commit 514bf14

File tree

6 files changed

+56
-34
lines changed

6 files changed

+56
-34
lines changed

mongoengine/base/metaclasses.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __new__(mcs, name, bases, attrs):
166166
) = mcs._import_classes()
167167

168168
if issubclass(new_class, Document):
169-
new_class._collection = None
169+
new_class._collections = {}
170170

171171
# Add class to the _document_registry
172172
_document_registry[new_class._class_name] = new_class

mongoengine/connection.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
"get_db",
1919
"register_connection",
2020
"set_local_db_alias",
21-
"del_local_db_alias"
21+
"del_local_db_alias",
22+
"get_local_db_alias"
2223
]
2324

2425

@@ -398,9 +399,14 @@ def del_local_db_alias(alias):
398399
_local.db_alias[alias].pop()
399400

400401

401-
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
402+
def get_local_db_alias(alias):
402403
if alias in _local.db_alias and _local.db_alias[alias]:
403404
alias = _local.db_alias[alias][-1]
405+
return alias
406+
407+
408+
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
409+
alias = get_local_db_alias(alias)
404410

405411
if reconnect:
406412
disconnect(alias)

mongoengine/context_managers.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,22 @@ def __init__(self, cls, db_alias):
8181
def __enter__(self):
8282
"""Change the db_alias and clear the cached collection."""
8383
self.cls._meta["db_alias"] = self.db_alias
84-
self.cls._collection = None
84+
self.cls._set_collection(None)
8585
return self.cls
8686

8787
def __exit__(self, t, value, traceback):
8888
"""Reset the db_alias and collection."""
8989
self.cls._meta["db_alias"] = self.ori_db_alias
90-
self.cls._collection = self.collection
90+
self.cls._set_collection(self.collection)
9191

9292

9393
class switch_collection:
9494
"""switch_collection alias context manager.
9595
96+
Warning ::
97+
98+
### This is NOT completely thread-safe ###
99+
96100
Example ::
97101
98102
class Group(Document):
@@ -123,12 +127,12 @@ def _get_collection_name(cls):
123127
return self.collection_name
124128

125129
self.cls._get_collection_name = _get_collection_name
126-
self.cls._collection = None
130+
self.cls._set_collection(None)
127131
return self.cls
128132

129133
def __exit__(self, t, value, traceback):
130134
"""Reset the collection."""
131-
self.cls._collection = self.ori_collection
135+
self.cls._set_collection(self.ori_collection)
132136
self.cls._get_collection_name = self.ori_get_collection_name
133137

134138

mongoengine/document.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
get_document,
1616
)
1717
from mongoengine.common import _import_class
18-
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
18+
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, get_local_db_alias
1919
from mongoengine.context_managers import (
2020
set_write_concern,
2121
switch_collection,
@@ -196,15 +196,23 @@ def __hash__(self):
196196

197197
return hash(self.pk)
198198

199+
@classmethod
200+
def _get_local_db_alias(cls):
201+
return get_local_db_alias(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
202+
199203
@classmethod
200204
def _get_db(cls):
201205
"""Some Model using other db_alias"""
202-
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
206+
return get_db(cls._get_local_db_alias())
203207

204208
@classmethod
205209
def _disconnect(cls):
206-
"""Detach the Document class from the (cached) database collection"""
207-
cls._collection = None
210+
"""Detach the Document class from all (cached) database collections"""
211+
cls._collections = {}
212+
213+
@classmethod
214+
def _set_collection(cls, collection):
215+
cls._collections[cls._get_local_db_alias()] = collection
208216

209217
@classmethod
210218
def _get_collection(cls):
@@ -216,14 +224,15 @@ def _get_collection(cls):
216224
2. Creates indexes defined in this document's :attr:`meta` dictionary.
217225
This happens only if `auto_create_index` is True.
218226
"""
219-
if not hasattr(cls, "_collection") or cls._collection is None:
227+
local_db_alias = cls._get_local_db_alias()
228+
if local_db_alias not in cls._collections:
220229
# Get the collection, either capped or regular.
221230
if cls._meta.get("max_size") or cls._meta.get("max_documents"):
222-
cls._collection = cls._get_capped_collection()
231+
cls._collections[local_db_alias] = cls._get_capped_collection()
223232
else:
224233
db = cls._get_db()
225234
collection_name = cls._get_collection_name()
226-
cls._collection = db[collection_name]
235+
cls._collections[local_db_alias] = db[collection_name]
227236

228237
# Ensure indexes on the collection unless auto_create_index was
229238
# set to False.
@@ -232,7 +241,7 @@ def _get_collection(cls):
232241
if cls._meta.get("auto_create_index", True) and db.client.is_primary:
233242
cls.ensure_indexes()
234243

235-
return cls._collection
244+
return cls._collections[local_db_alias]
236245

237246
@classmethod
238247
def _get_capped_collection(cls):
@@ -260,7 +269,7 @@ def _get_capped_collection(cls):
260269
if options.get("max") != max_documents or options.get("size") != max_size:
261270
raise InvalidCollectionError(
262271
'Cannot create collection "{}" as a capped '
263-
"collection as it already exists".format(cls._collection)
272+
"collection as it already exists".format(collection_name)
264273
)
265274

266275
return collection
@@ -837,7 +846,7 @@ def drop_collection(cls):
837846
raise OperationError(
838847
"Document %s has no collection defined (is it abstract ?)" % cls
839848
)
840-
cls._collection = None
849+
cls._set_collection(None)
841850
db = cls._get_db()
842851
db.drop_collection(coll_name)
843852

mongoengine/queryset/base.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ class BaseQuerySet:
5454
__dereference = False
5555
_auto_dereference = True
5656

57-
def __init__(self, document, collection):
57+
def __init__(self, document, db_alias=None):
5858
self._document = document
59-
self._collection_obj = collection
59+
self._db_alias = db_alias
6060
self._mongo_query = None
6161
self._query_obj = Q()
6262
self._cls_query = {}
@@ -74,6 +74,8 @@ def __init__(self, document, collection):
7474
self._as_pymongo = False
7575
self._search_text = None
7676

77+
self.__init_using_collection()
78+
7779
# If inheritance is allowed, only return instances and instances of
7880
# subclasses of the class being used
7981
if document._meta.get("allow_inheritance") is True:
@@ -100,6 +102,12 @@ def __init__(self, document, collection):
100102
# it anytime we change _limit. Inspired by how it is done in pymongo.Cursor
101103
self._empty = False
102104

105+
def __init_using_collection(self):
106+
self._using_collection = None
107+
if self._db_alias is not None:
108+
with switch_db(self._document, self._db_alias) as cls:
109+
self._using_collection = cls._get_collection()
110+
103111
def __call__(self, q_obj=None, **query):
104112
"""Filter the selected documents by calling the
105113
:class:`~mongoengine.queryset.QuerySet` with a query.
@@ -137,9 +145,6 @@ def __getstate__(self):
137145

138146
obj_dict = self.__dict__.copy()
139147

140-
# don't picke collection, instead pickle collection params
141-
obj_dict.pop("_collection_obj")
142-
143148
# don't pickle cursor
144149
obj_dict["_cursor_obj"] = None
145150

@@ -152,11 +157,11 @@ def __setstate__(self, obj_dict):
152157
See https://github.com/MongoEngine/mongoengine/issues/442
153158
"""
154159

155-
obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection()
156-
157160
# update attributes
158161
self.__dict__.update(obj_dict)
159162

163+
self.__init_using_collection()
164+
160165
# forse load cursor
161166
# self._cursor
162167

@@ -494,7 +499,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
494499
if rule == CASCADE:
495500
cascade_refs = set() if cascade_refs is None else cascade_refs
496501
# Handle recursive reference
497-
if doc._collection == document_cls._collection:
502+
if doc._collection == document_cls._get_collection():
498503
for ref in queryset:
499504
cascade_refs.add(ref.id)
500505
refs = document_cls.objects(
@@ -777,14 +782,11 @@ def using(self, alias):
777782
:param alias: The database alias
778783
"""
779784

780-
with switch_db(self._document, alias) as cls:
781-
collection = cls._get_collection()
782-
783-
return self._clone_into(self.__class__(self._document, collection))
785+
return self._clone_into(self.__class__(self._document, alias))
784786

785787
def clone(self):
786788
"""Create a copy of the current queryset."""
787-
return self._clone_into(self.__class__(self._document, self._collection_obj))
789+
return self._clone_into(self.__class__(self._document, self._db_alias))
788790

789791
def _clone_into(self, new_qs):
790792
"""Copy all of the relevant properties of this queryset to
@@ -1531,7 +1533,7 @@ def sum(self, field):
15311533
if isinstance(field_instances[-1], ListField):
15321534
pipeline.insert(1, {"$unwind": "$" + field})
15331535

1534-
result = tuple(self._document._get_collection().aggregate(pipeline))
1536+
result = tuple(self._collection.aggregate(pipeline))
15351537

15361538
if result:
15371539
return result[0]["total"]
@@ -1558,7 +1560,7 @@ def average(self, field):
15581560
if isinstance(field_instances[-1], ListField):
15591561
pipeline.insert(1, {"$unwind": "$" + field})
15601562

1561-
result = tuple(self._document._get_collection().aggregate(pipeline))
1563+
result = tuple(self._collection.aggregate(pipeline))
15621564
if result:
15631565
return result[0]["total"]
15641566
return 0
@@ -1620,7 +1622,8 @@ def _collection(self):
16201622
"""Property that returns the collection object. This allows us to
16211623
perform operations only if the collection is accessed.
16221624
"""
1623-
return self._collection_obj
1625+
return self._document._get_collection() \
1626+
if self._using_collection is None else self._using_collection
16241627

16251628
@property
16261629
def _cursor_args(self):

mongoengine/queryset/manager.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __get__(self, instance, owner):
3535

3636
# owner is the document that contains the QuerySetManager
3737
queryset_class = owner._meta.get("queryset_class", self.default)
38-
queryset = queryset_class(owner, owner._get_collection())
38+
queryset = queryset_class(owner)
3939
if self.get_queryset:
4040
arg_count = self.get_queryset.__code__.co_argcount
4141
if arg_count == 1:

0 commit comments

Comments
 (0)