@@ -54,9 +54,9 @@ class BaseQuerySet:
5454 __dereference = False
5555 _auto_dereference = True
5656
57- def __init__ (self , document , collection ):
57+ def __init__ (self , document , db_alias = None ):
5858 self ._document = document
59- self ._collection_obj = collection
59+ self ._db_alias = db_alias
6060 self ._mongo_query = None
6161 self ._query_obj = Q ()
6262 self ._cls_query = {}
@@ -74,6 +74,8 @@ def __init__(self, document, collection):
7474 self ._as_pymongo = False
7575 self ._search_text = None
7676
77+ self .__init_using_collection ()
78+
7779 # If inheritance is allowed, only return instances and instances of
7880 # subclasses of the class being used
7981 if document ._meta .get ("allow_inheritance" ) is True :
@@ -100,6 +102,12 @@ def __init__(self, document, collection):
100102 # it anytime we change _limit. Inspired by how it is done in pymongo.Cursor
101103 self ._empty = False
102104
105+ def __init_using_collection (self ):
106+ self ._using_collection = None
107+ if self ._db_alias is not None :
108+ with switch_db (self ._document , self ._db_alias ) as cls :
109+ self ._using_collection = cls ._get_collection ()
110+
103111 def __call__ (self , q_obj = None , ** query ):
104112 """Filter the selected documents by calling the
105113 :class:`~mongoengine.queryset.QuerySet` with a query.
@@ -137,9 +145,6 @@ def __getstate__(self):
137145
138146 obj_dict = self .__dict__ .copy ()
139147
140- # don't picke collection, instead pickle collection params
141- obj_dict .pop ("_collection_obj" )
142-
143148 # don't pickle cursor
144149 obj_dict ["_cursor_obj" ] = None
145150
@@ -152,11 +157,11 @@ def __setstate__(self, obj_dict):
152157 See https://github.com/MongoEngine/mongoengine/issues/442
153158 """
154159
155- obj_dict ["_collection_obj" ] = obj_dict ["_document" ]._get_collection ()
156-
157160 # update attributes
158161 self .__dict__ .update (obj_dict )
159162
163+ self .__init_using_collection ()
164+
160165 # forse load cursor
161166 # self._cursor
162167
@@ -494,7 +499,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
494499 if rule == CASCADE :
495500 cascade_refs = set () if cascade_refs is None else cascade_refs
496501 # Handle recursive reference
497- if doc ._collection == document_cls ._collection :
502+ if doc ._collection == document_cls ._get_collection () :
498503 for ref in queryset :
499504 cascade_refs .add (ref .id )
500505 refs = document_cls .objects (
@@ -777,14 +782,11 @@ def using(self, alias):
777782 :param alias: The database alias
778783 """
779784
780- with switch_db (self ._document , alias ) as cls :
781- collection = cls ._get_collection ()
782-
783- return self ._clone_into (self .__class__ (self ._document , collection ))
785+ return self ._clone_into (self .__class__ (self ._document , alias ))
784786
785787 def clone (self ):
786788 """Create a copy of the current queryset."""
787- return self ._clone_into (self .__class__ (self ._document , self ._collection_obj ))
789+ return self ._clone_into (self .__class__ (self ._document , self ._db_alias ))
788790
789791 def _clone_into (self , new_qs ):
790792 """Copy all of the relevant properties of this queryset to
@@ -1531,7 +1533,7 @@ def sum(self, field):
15311533 if isinstance (field_instances [- 1 ], ListField ):
15321534 pipeline .insert (1 , {"$unwind" : "$" + field })
15331535
1534- result = tuple (self ._document . _get_collection () .aggregate (pipeline ))
1536+ result = tuple (self ._collection .aggregate (pipeline ))
15351537
15361538 if result :
15371539 return result [0 ]["total" ]
@@ -1558,7 +1560,7 @@ def average(self, field):
15581560 if isinstance (field_instances [- 1 ], ListField ):
15591561 pipeline .insert (1 , {"$unwind" : "$" + field })
15601562
1561- result = tuple (self ._document . _get_collection () .aggregate (pipeline ))
1563+ result = tuple (self ._collection .aggregate (pipeline ))
15621564 if result :
15631565 return result [0 ]["total" ]
15641566 return 0
@@ -1620,7 +1622,8 @@ def _collection(self):
16201622 """Property that returns the collection object. This allows us to
16211623 perform operations only if the collection is accessed.
16221624 """
1623- return self ._collection_obj
1625+ return self ._document ._get_collection () \
1626+ if self ._using_collection is None else self ._using_collection
16241627
16251628 @property
16261629 def _cursor_args (self ):
0 commit comments