Skip to content

Commit 896a462

Browse files
Merge pull request #29 from piercefreeman/feature/branching-queries
Support branching of query builder
2 parents d056add + 3b27031 commit 896a462

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

iceaxe/__tests__/test_queries.py

+10
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,13 @@ def test_select_multiple_typehints():
267267
query = select((UserDemo, UserDemo.id, UserDemo.name))
268268
if TYPE_CHECKING:
269269
_: QueryBuilder[tuple[UserDemo, int, str], Literal["SELECT"]] = query
270+
271+
272+
def test_allow_branching():
273+
base_query = select(UserDemo)
274+
275+
query_1 = base_query.limit(1)
276+
query_2 = base_query.limit(2)
277+
278+
assert query_1.limit_value == 1
279+
assert query_2.limit_value == 2

iceaxe/queries.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from copy import copy
4+
from functools import wraps
35
from typing import Any, Generic, Literal, Type, TypeVar, TypeVarTuple, cast, overload
46

57
from iceaxe.base import (
@@ -53,6 +55,22 @@
5355
OrderDirection = Literal["ASC", "DESC"]
5456

5557

58+
def allow_branching(fn):
59+
"""
60+
Allows query method modifiers to implement their logic as if `self` is being
61+
modified, but in the background we'll actually return a new instance of the
62+
query builder to allow for branching of the same underlying query.
63+
64+
"""
65+
66+
@wraps(fn)
67+
def new_fn(self, *args, **kwargs):
68+
self = copy(self)
69+
return fn(self, *args, **kwargs)
70+
71+
return new_fn
72+
73+
5674
class QueryBuilder(Generic[P, QueryType]):
5775
"""
5876
The QueryBuilder owns all construction of the SQL string given
@@ -118,6 +136,7 @@ def select(
118136
self, fields: tuple[T | Type[T], T2 | Type[T2], T3 | Type[T3], *Ts]
119137
) -> QueryBuilder[tuple[T, T2, T3, *Ts], Literal["SELECT"]]: ...
120138

139+
@allow_branching
121140
def select(
122141
self,
123142
fields: (
@@ -212,6 +231,7 @@ def _select_inner(
212231
self.select_raw.append(field)
213232
self.select_aggregate_count += 1
214233

234+
@allow_branching
215235
def update(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["UPDATE"]]:
216236
"""
217237
Creates a new update query for the given model. Returns the same
@@ -222,6 +242,7 @@ def update(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["UPDATE"]
222242
self.main_model = model
223243
return self # type: ignore
224244

245+
@allow_branching
225246
def delete(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["DELETE"]]:
226247
"""
227248
Creates a new delete query for the given model. Returns the same
@@ -232,6 +253,7 @@ def delete(self, model: Type[TableBase]) -> QueryBuilder[None, Literal["DELETE"]
232253
self.main_model = model
233254
return self # type: ignore
234255

256+
@allow_branching
235257
def where(self, *conditions: bool):
236258
"""
237259
Adds a where condition to the query. The conditions are combined with
@@ -250,6 +272,7 @@ def where(self, *conditions: bool):
250272
self.where_conditions += validated_comparisons
251273
return self
252274

275+
@allow_branching
253276
def order_by(self, field: Any, direction: OrderDirection = "ASC"):
254277
"""
255278
Adds an order by clause to the query. The field must be a column.
@@ -265,6 +288,7 @@ def order_by(self, field: Any, direction: OrderDirection = "ASC"):
265288
self.order_by_clauses.append(f"{field_token} {direction}")
266289
return self
267290

291+
@allow_branching
268292
def join(self, table: Type[TableBase], on: bool, join_type: JoinType = "INNER"):
269293
"""
270294
Adds a join clause to the query. The `on` parameter should be a comparison
@@ -289,6 +313,7 @@ def join(self, table: Type[TableBase], on: bool, join_type: JoinType = "INNER"):
289313
self.join_clauses.append(join_sql)
290314
return self
291315

316+
@allow_branching
292317
def set(self, column: T, value: T | None):
293318
"""
294319
Sets a column to a specific value in an update query.
@@ -300,6 +325,7 @@ def set(self, column: T, value: T | None):
300325
self.update_values.append((column, value))
301326
return self
302327

328+
@allow_branching
303329
def limit(self, value: int):
304330
"""
305331
Limit the number of rows returned by the query. Useful in pagination
@@ -309,6 +335,7 @@ def limit(self, value: int):
309335
self.limit_value = value
310336
return self
311337

338+
@allow_branching
312339
def offset(self, value: int):
313340
"""
314341
Offset the number of rows returned by the query.
@@ -317,6 +344,7 @@ def offset(self, value: int):
317344
self.offset_value = value
318345
return self
319346

347+
@allow_branching
320348
def group_by(self, *fields: Any):
321349
"""
322350
Groups the results of the query by the given fields. This allows
@@ -334,6 +362,7 @@ def group_by(self, *fields: Any):
334362
self.group_by_fields = valid_fields
335363
return self
336364

365+
@allow_branching
337366
def having(self, *conditions: bool):
338367
"""
339368
Require the result of an aggregation query like func.sum(MyTable.column)
@@ -351,6 +380,7 @@ def having(self, *conditions: bool):
351380
self.having_conditions += valid_conditions
352381
return self
353382

383+
@allow_branching
354384
def text(self, query: str, *variables: Any):
355385
"""
356386
Override the ORM builder and use a raw SQL query instead.

0 commit comments

Comments
 (0)