Skip to content

Commit bb26dd7

Browse files
committed
Add type annotations to cqlengine models and query modules
Based on VincentRPS/python-driver#fix-engine-typing with fixes: - Add `from __future__ import annotations` to both files - Fix `__ne__` infinite recursion bug (`not (self != q)` -> `not (self == q)`) - Fix `all()` and `filter()` return types to return `ModelQuerySet` not `list` - Drop incorrect `__getitem__` override on BaseModel (would break column access) - Drop wrong `objects: query.ModelQuerySet` annotation (it's a descriptor) - Drop unused `T = TypeVar` with forward reference issue
1 parent 0b1802b commit bb26dd7

2 files changed

Lines changed: 36 additions & 23 deletions

File tree

cassandra/cqlengine/models.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import logging
1618
import re
19+
from typing import TYPE_CHECKING, TypeVar
1720
from warnings import warn
1821

1922
from cassandra.cqlengine import CQLEngineException, ValidationError
@@ -329,6 +332,9 @@ def __delete__(self, instance):
329332
raise AttributeError('cannot delete {0} columns'.format(self.column.column_name))
330333

331334

335+
M = TypeVar('M', bound='BaseModel')
336+
337+
332338
class BaseModel(object):
333339
"""
334340
The base model class, don't inherit from this, inherit from Model, defined below
@@ -657,7 +663,7 @@ def _as_dict(self):
657663
return values
658664

659665
@classmethod
660-
def create(cls, **kwargs):
666+
def create(cls: type[M], **kwargs) -> M:
661667
"""
662668
Create an instance of this model in the database.
663669
@@ -672,7 +678,7 @@ def create(cls, **kwargs):
672678
return cls.objects.create(**kwargs)
673679

674680
@classmethod
675-
def all(cls):
681+
def all(cls: type[M]) -> query.ModelQuerySet:
676682
"""
677683
Returns a queryset representing all stored objects
678684
@@ -681,7 +687,7 @@ def all(cls):
681687
return cls.objects.all()
682688

683689
@classmethod
684-
def filter(cls, *args, **kwargs):
690+
def filter(cls: type[M], *args, **kwargs) -> query.ModelQuerySet:
685691
"""
686692
Returns a queryset based on filter parameters.
687693
@@ -690,15 +696,15 @@ def filter(cls, *args, **kwargs):
690696
return cls.objects.filter(*args, **kwargs)
691697

692698
@classmethod
693-
def get(cls, *args, **kwargs):
699+
def get(cls: type[M], *args, **kwargs) -> M:
694700
"""
695701
Returns a single object based on the passed filter constraints.
696702
697703
This is a pass-through to the model objects().:method:`~cqlengine.queries.get`.
698704
"""
699705
return cls.objects.get(*args, **kwargs)
700706

701-
def timeout(self, timeout):
707+
def timeout(self: M, timeout: float | None) -> M:
702708
"""
703709
Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete`
704710
operations
@@ -707,7 +713,7 @@ def timeout(self, timeout):
707713
self._timeout = timeout
708714
return self
709715

710-
def save(self):
716+
def save(self: M) -> M:
711717
"""
712718
Saves an object to the database.
713719
@@ -743,7 +749,7 @@ def save(self):
743749

744750
return self
745751

746-
def update(self, **values):
752+
def update(self: M, **values) -> M:
747753
"""
748754
Performs an update on the model instance. You can pass in values to set on the model
749755
for updating, or you can call without values to execute an update against any modified

cassandra/cqlengine/query.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import copy
1618
from datetime import datetime, timedelta
1719
from functools import partial
1820
import time
21+
from typing import TYPE_CHECKING, TypeVar
1922
from warnings import warn
2023

2124
from cassandra.query import SimpleStatement, BatchType as CBatchType, BatchStatement
@@ -336,9 +339,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
336339
return
337340

338341

342+
if TYPE_CHECKING:
343+
from cassandra.cqlengine.models import M
344+
345+
339346
class AbstractQuerySet(object):
340347

341-
def __init__(self, model):
348+
def __init__(self, model: type[M]):
342349
super(AbstractQuerySet, self).__init__()
343350
self.model = model
344351

@@ -528,7 +535,7 @@ def __iter__(self):
528535

529536
idx += 1
530537

531-
def __getitem__(self, s):
538+
def __getitem__(self, s: slice | int) -> M | list[M]:
532539
self._execute_query()
533540

534541
if isinstance(s, slice):
@@ -601,7 +608,7 @@ def batch(self, batch_obj):
601608
clone._batch = batch_obj
602609
return clone
603610

604-
def first(self):
611+
def first(self) -> M | None:
605612
try:
606613
return next(iter(self))
607614
except StopIteration:
@@ -618,7 +625,7 @@ def all(self):
618625
"""
619626
return copy.deepcopy(self)
620627

621-
def consistency(self, consistency):
628+
def consistency(self, consistency: int):
622629
"""
623630
Sets the consistency level for the operation. See :class:`.ConsistencyLevel`.
624631
@@ -742,7 +749,7 @@ def filter(self, *args, **kwargs):
742749

743750
return clone
744751

745-
def get(self, *args, **kwargs):
752+
def get(self, *args, **kwargs) -> M:
746753
"""
747754
Returns a single instance matching this query, optionally with additional filter kwargs.
748755
@@ -783,7 +790,7 @@ def _get_ordering_condition(self, colname):
783790

784791
return colname, order_type
785792

786-
def order_by(self, *colnames):
793+
def order_by(self, *colnames: str):
787794
"""
788795
Sets the column(s) to be used for ordering
789796
@@ -827,7 +834,7 @@ class Comment(Model):
827834
clone._order.extend(conditions)
828835
return clone
829836

830-
def count(self):
837+
def count(self) -> int:
831838
"""
832839
Returns the number of rows matched by this query.
833840
@@ -880,7 +887,7 @@ class Automobile(Model):
880887

881888
return clone
882889

883-
def limit(self, v):
890+
def limit(self, v: int):
884891
"""
885892
Limits the number of results returned by Cassandra. Use *0* or *None* to disable.
886893
@@ -912,7 +919,7 @@ def limit(self, v):
912919
clone._limit = v
913920
return clone
914921

915-
def fetch_size(self, v):
922+
def fetch_size(self, v: int):
916923
"""
917924
Sets the number of rows that are fetched at a time.
918925
@@ -968,15 +975,15 @@ def _only_or_defer(self, action, fields):
968975

969976
return clone
970977

971-
def only(self, fields):
978+
def only(self, fields: list[str]):
972979
""" Load only these fields for the returned query """
973980
return self._only_or_defer('only', fields)
974981

975-
def defer(self, fields):
982+
def defer(self, fields: list[str]):
976983
""" Don't load these fields for the returned query """
977984
return self._only_or_defer('defer', fields)
978985

979-
def create(self, **kwargs):
986+
def create(self, **kwargs) -> M:
980987
return self.model(**kwargs) \
981988
.batch(self._batch) \
982989
.ttl(self._ttl) \
@@ -1011,9 +1018,9 @@ def __eq__(self, q):
10111018
return False
10121019

10131020
def __ne__(self, q):
1014-
return not (self != q)
1021+
return not (self == q)
10151022

1016-
def timeout(self, timeout):
1023+
def timeout(self, timeout: float | None):
10171024
"""
10181025
:param timeout: Timeout for the query (in seconds)
10191026
:type timeout: float or None
@@ -1156,7 +1163,7 @@ def values_list(self, *fields, **kwargs):
11561163
clone._flat_values_list = flat
11571164
return clone
11581165

1159-
def ttl(self, ttl):
1166+
def ttl(self, ttl: int):
11601167
"""
11611168
Sets the ttl (in seconds) for modified data.
11621169
@@ -1166,7 +1173,7 @@ def ttl(self, ttl):
11661173
clone._ttl = ttl
11671174
return clone
11681175

1169-
def timestamp(self, timestamp):
1176+
def timestamp(self, timestamp: datetime):
11701177
"""
11711178
Allows for custom timestamps to be saved with the record.
11721179
"""

0 commit comments

Comments
 (0)