@@ -54,9 +54,9 @@ class BaseQuerySet:
54
54
__dereference = False
55
55
_auto_dereference = True
56
56
57
- def __init__ (self , document , collection ):
57
+ def __init__ (self , document , db_alias = None ):
58
58
self ._document = document
59
- self ._collection_obj = collection
59
+ self ._db_alias = db_alias
60
60
self ._mongo_query = None
61
61
self ._query_obj = Q ()
62
62
self ._cls_query = {}
@@ -74,6 +74,8 @@ def __init__(self, document, collection):
74
74
self ._as_pymongo = False
75
75
self ._search_text = None
76
76
77
+ self .__init_using_collection ()
78
+
77
79
# If inheritance is allowed, only return instances and instances of
78
80
# subclasses of the class being used
79
81
if document ._meta .get ("allow_inheritance" ) is True :
@@ -100,6 +102,12 @@ def __init__(self, document, collection):
100
102
# it anytime we change _limit. Inspired by how it is done in pymongo.Cursor
101
103
self ._empty = False
102
104
105
+ def __init_using_collection (self ):
106
+ self ._using_collection = None
107
+ if self ._db_alias is not None :
108
+ with switch_db (self ._document , self ._db_alias ) as cls :
109
+ self ._using_collection = cls ._get_collection ()
110
+
103
111
def __call__ (self , q_obj = None , ** query ):
104
112
"""Filter the selected documents by calling the
105
113
:class:`~mongoengine.queryset.QuerySet` with a query.
@@ -137,9 +145,6 @@ def __getstate__(self):
137
145
138
146
obj_dict = self .__dict__ .copy ()
139
147
140
- # don't picke collection, instead pickle collection params
141
- obj_dict .pop ("_collection_obj" )
142
-
143
148
# don't pickle cursor
144
149
obj_dict ["_cursor_obj" ] = None
145
150
@@ -152,11 +157,11 @@ def __setstate__(self, obj_dict):
152
157
See https://github.com/MongoEngine/mongoengine/issues/442
153
158
"""
154
159
155
- obj_dict ["_collection_obj" ] = obj_dict ["_document" ]._get_collection ()
156
-
157
160
# update attributes
158
161
self .__dict__ .update (obj_dict )
159
162
163
+ self .__init_using_collection ()
164
+
160
165
# forse load cursor
161
166
# self._cursor
162
167
@@ -494,7 +499,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
494
499
if rule == CASCADE :
495
500
cascade_refs = set () if cascade_refs is None else cascade_refs
496
501
# Handle recursive reference
497
- if doc ._collection == document_cls ._collection :
502
+ if doc ._collection == document_cls ._get_collection () :
498
503
for ref in queryset :
499
504
cascade_refs .add (ref .id )
500
505
refs = document_cls .objects (
@@ -777,14 +782,11 @@ def using(self, alias):
777
782
:param alias: The database alias
778
783
"""
779
784
780
- with switch_db (self ._document , alias ) as cls :
781
- collection = cls ._get_collection ()
782
-
783
- return self ._clone_into (self .__class__ (self ._document , collection ))
785
+ return self ._clone_into (self .__class__ (self ._document , alias ))
784
786
785
787
def clone (self ):
786
788
"""Create a copy of the current queryset."""
787
- return self ._clone_into (self .__class__ (self ._document , self ._collection_obj ))
789
+ return self ._clone_into (self .__class__ (self ._document , self ._db_alias ))
788
790
789
791
def _clone_into (self , new_qs ):
790
792
"""Copy all of the relevant properties of this queryset to
@@ -1531,7 +1533,7 @@ def sum(self, field):
1531
1533
if isinstance (field_instances [- 1 ], ListField ):
1532
1534
pipeline .insert (1 , {"$unwind" : "$" + field })
1533
1535
1534
- result = tuple (self ._document . _get_collection () .aggregate (pipeline ))
1536
+ result = tuple (self ._collection .aggregate (pipeline ))
1535
1537
1536
1538
if result :
1537
1539
return result [0 ]["total" ]
@@ -1558,7 +1560,7 @@ def average(self, field):
1558
1560
if isinstance (field_instances [- 1 ], ListField ):
1559
1561
pipeline .insert (1 , {"$unwind" : "$" + field })
1560
1562
1561
- result = tuple (self ._document . _get_collection () .aggregate (pipeline ))
1563
+ result = tuple (self ._collection .aggregate (pipeline ))
1562
1564
if result :
1563
1565
return result [0 ]["total" ]
1564
1566
return 0
@@ -1620,7 +1622,8 @@ def _collection(self):
1620
1622
"""Property that returns the collection object. This allows us to
1621
1623
perform operations only if the collection is accessed.
1622
1624
"""
1623
- return self ._collection_obj
1625
+ return self ._document ._get_collection () \
1626
+ if self ._using_collection is None else self ._using_collection
1624
1627
1625
1628
@property
1626
1629
def _cursor_args (self ):
0 commit comments