1
1
from __future__ import annotations
2
2
3
+ from copy import copy
4
+ from functools import wraps
3
5
from typing import Any , Generic , Literal , Type , TypeVar , TypeVarTuple , cast , overload
4
6
5
7
from iceaxe .base import (
53
55
OrderDirection = Literal ["ASC" , "DESC" ]
54
56
55
57
58
+ def allow_branching (fn ):
59
+ """
60
+ Allows query method modifiers to implement their logic as if `self` is being
61
+ modified, but in the background we'll actually return a new instance of the
62
+ query builder to allow for branching of the same underlying query.
63
+
64
+ """
65
+
66
+ @wraps (fn )
67
+ def new_fn (self , * args , ** kwargs ):
68
+ self = copy (self )
69
+ return fn (self , * args , ** kwargs )
70
+
71
+ return new_fn
72
+
73
+
56
74
class QueryBuilder (Generic [P , QueryType ]):
57
75
"""
58
76
The QueryBuilder owns all construction of the SQL string given
@@ -118,6 +136,7 @@ def select(
118
136
self , fields : tuple [T | Type [T ], T2 | Type [T2 ], T3 | Type [T3 ], * Ts ]
119
137
) -> QueryBuilder [tuple [T , T2 , T3 , * Ts ], Literal ["SELECT" ]]: ...
120
138
139
+ @allow_branching
121
140
def select (
122
141
self ,
123
142
fields : (
@@ -212,6 +231,7 @@ def _select_inner(
212
231
self .select_raw .append (field )
213
232
self .select_aggregate_count += 1
214
233
234
+ @allow_branching
215
235
def update (self , model : Type [TableBase ]) -> QueryBuilder [None , Literal ["UPDATE" ]]:
216
236
"""
217
237
Creates a new update query for the given model. Returns the same
@@ -222,6 +242,7 @@ def update(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["UPDATE"]
222
242
self .main_model = model
223
243
return self # type: ignore
224
244
245
+ @allow_branching
225
246
def delete (self , model : Type [TableBase ]) -> QueryBuilder [None , Literal ["DELETE" ]]:
226
247
"""
227
248
Creates a new delete query for the given model. Returns the same
@@ -232,6 +253,7 @@ def delete(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["DELETE"]
232
253
self .main_model = model
233
254
return self # type: ignore
234
255
256
+ @allow_branching
235
257
def where (self , * conditions : bool ):
236
258
"""
237
259
Adds a where condition to the query. The conditions are combined with
@@ -250,6 +272,7 @@ def where(self, *conditions: bool):
250
272
self .where_conditions += validated_comparisons
251
273
return self
252
274
275
+ @allow_branching
253
276
def order_by (self , field : Any , direction : OrderDirection = "ASC" ):
254
277
"""
255
278
Adds an order by clause to the query. The field must be a column.
@@ -265,6 +288,7 @@ def order_by(self, field: Any, direction: OrderDirection = "ASC"):
265
288
self .order_by_clauses .append (f"{ field_token } { direction } " )
266
289
return self
267
290
291
+ @allow_branching
268
292
def join (self , table : Type [TableBase ], on : bool , join_type : JoinType = "INNER" ):
269
293
"""
270
294
Adds a join clause to the query. The `on` parameter should be a comparison
@@ -289,6 +313,7 @@ def join(self, table: Type[TableBase], on: bool, join_type: JoinType = "INNER"):
289
313
self .join_clauses .append (join_sql )
290
314
return self
291
315
316
+ @allow_branching
292
317
def set (self , column : T , value : T | None ):
293
318
"""
294
319
Sets a column to a specific value in an update query.
@@ -300,6 +325,7 @@ def set(self, column: T, value: T | None):
300
325
self .update_values .append ((column , value ))
301
326
return self
302
327
328
+ @allow_branching
303
329
def limit (self , value : int ):
304
330
"""
305
331
Limit the number of rows returned by the query. Useful in pagination
@@ -309,6 +335,7 @@ def limit(self, value: int):
309
335
self .limit_value = value
310
336
return self
311
337
338
+ @allow_branching
312
339
def offset (self , value : int ):
313
340
"""
314
341
Offset the number of rows returned by the query.
@@ -317,6 +344,7 @@ def offset(self, value: int):
317
344
self .offset_value = value
318
345
return self
319
346
347
+ @allow_branching
320
348
def group_by (self , * fields : Any ):
321
349
"""
322
350
Groups the results of the query by the given fields. This allows
@@ -334,6 +362,7 @@ def group_by(self, *fields: Any):
334
362
self .group_by_fields = valid_fields
335
363
return self
336
364
365
+ @allow_branching
337
366
def having (self , * conditions : bool ):
338
367
"""
339
368
Require the result of an aggregation query like func.sum(MyTable.column)
@@ -351,6 +380,7 @@ def having(self, *conditions: bool):
351
380
self .having_conditions += valid_conditions
352
381
return self
353
382
383
+ @allow_branching
354
384
def text (self , query : str , * variables : Any ):
355
385
"""
356
386
Override the ORM builder and use a raw SQL query instead.
0 commit comments