Skip to content

Commit 47ebf40

Browse files
committed
New QueryBuilder and EngineDispatcher classes. Improved create_initial_migration()
1 parent 3c83afa commit 47ebf40

File tree

13 files changed

+169
-30
lines changed

13 files changed

+169
-30
lines changed

docs/engine.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,33 @@ Connections are pooled and re-used by default. You can disabled this behavior by
5757
`max_pool_conns` can also be used to define the maximum number of connections to start.
5858

5959
Use `engine.disconnect_all()` to close all connections.
60+
61+
## Engine dispatcher
62+
63+
Multiple engines can be used at the same time. `EngineDispatcher` makes it easy to select engines based on matching tags.
64+
65+
Register engines on the dispatcher using `dispatcher.register(engine, ["tag1", "tag2"])`.
66+
67+
Select one engine using `dispatcher.tag_name`. If multiple engines matche a tag, one is randomly selected.
68+
69+
```py
70+
from sqlorm import EngineDispatcher
71+
72+
dispatcher = EngineDispatcher()
73+
dispatcher.register(Engine.from_uri("postgresql://primary"), default=True) # there can be multiple default engines
74+
dispatcher.register(Engine.from_uri("postgresql://replica1"), ["readonly"])
75+
dispatcher.register(Engine.from_uri("postgresql://replica2"), ["readonly"])
76+
77+
with dispatcher:
78+
# uses default engine (ie. primary)
79+
80+
with dispatcher.readonly:
81+
# uses a randomly selected engines from the one matching the readonly tag (ie. replica1 or replica2)
82+
```
83+
84+
The context behaves the same as a context from an Engine, ie. starting a transaction. To start a session instead, use `dispatcher.session(tag_name)`.
85+
86+
By default, if no engines match the given tag, it will fallback on the default. This can be changed using the `fallback` argument of the `EngineDispatcher` constructor.
87+
Pass `False` to disable fallback or a tag name to fallback to a specific tag.
88+
89+
Use `dispatcher.disconnect_all()` to close all connections across all engines.

docs/executing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ There are a few shortcut methods on the transaction:
6161
```python
6262
with engine as tx:
6363
task_row = tx.fetch("SELECT * FROM tasks WHERE id = ?", [1]).first()
64-
task_row = tx.fetchone("SELECT * FROM tasks WHERE id = ?", [1]) # same as line before
64+
task_row = tx.fetchone("SELECT * FROM tasks WHERE id = ?", [1]) # same as previous line
6565

6666
for row in tx.fetch("SELECT * FROM tasks"):
6767
print(row)

docs/integrations.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Integrations
2+
3+
## Flask
4+
5+
See [Flask-SQLORM](https://github.com/hyperflask/flask-sqlorm).

docs/sql-utilities.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,21 @@ Useful shortcuts:
7474
- `SQL.update()` instead of `SQL().update()`
7575
- `SQL.delete_from()` instead of `SQL().delete_from()`
7676

77+
## Query builder
78+
79+
A query builder for SELECT statements, built on top of the `SQL` class, is provided.
80+
81+
```py
82+
from sqlorm import QueryBuilder
83+
query = QueryBuilder().from_("table").where(col="value").where(SQL("col") >= 1)
84+
```
85+
86+
It is aware of the different parts of the query and will treat them differently.
87+
88+
- `builder.select()`, `.join()`, `.where()` will add expressions, all other parts will be replaced when called
89+
- If `.select()` is not called, `SELECT *` is used
90+
- Keyword arguments can be used with `.where()` to generate equal comparisons
91+
7792
## Handling list of SQL pieces
7893

7994
Use `SQL.List` to manage lists of `SQL` objects. A few subclasses exists:

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ nav:
1616
- schema.md
1717
- drivers.md
1818
- instrumentation.md
19+
- integrations.md
1920

2021
theme:
2122
name: material

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlorm-py"
3-
version = "0.3.1"
3+
version = "0.4.0"
44
description = "A new kind or ORM that do not abstract away your database or SQL queries."
55
authors = [
66
{"name" = "Maxime Bouroumeau-Fuseau", email = "[email protected]"}

src/sqlorm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .engine import (
22
Engine,
33
EngineError,
4+
EngineDispatcher,
45
Session,
56
SessionError,
67
Transaction,

src/sqlorm/builder.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ class QueryBuilderError(Exception):
1111

1212

1313
class QueryBuilder(SQLStr):
14-
components = ("SELECT", "FROM", "JOIN+", "WHERE+", "GROUP BY", "HAVING", "ORDER BY", "LIMIT", "OFFSET")
14+
components = ("SELECT+", "FROM", "JOIN+", "WHERE+", "GROUP BY", "HAVING", "ORDER BY", "LIMIT", "OFFSET")
1515
wrappers = {"SELECT": List, "FROM": List, "WHERE": Conditions, "HAVING": Conditions, "ORDER BY": List}
1616

1717
def __init__(self):
18-
self.parts = {}
18+
self.parts = {"SELECT": QueryPartBuilder(self, "*")}
19+
self._selected = False
1920

2021
def _render(self, params):
2122
stmt = []
@@ -38,12 +39,23 @@ def __getattr__(self, name):
3839

3940
self.parts[name] = QueryPartBuilder(self)
4041
return self.parts[name]
42+
43+
def select(self, *parts):
44+
if not self._selected:
45+
self._selected = True
46+
self.parts["SELECT"] = QueryPartBuilder(self)
47+
return self.parts["SELECT"](*parts)
48+
49+
def where(self, *parts, **filters):
50+
for k, v in filters.items():
51+
parts = list(parts) + [SQL(k) == v]
52+
return self.__getattr__("where")(*parts)
4153

4254

4355
class QueryPartBuilder:
44-
def __init__(self, builder):
56+
def __init__(self, builder, *parts):
4557
self.builder = builder
46-
self.parts = []
58+
self.parts = list(parts)
4759

4860
def __call__(self, *parts):
4961
if len(parts) == 1 and parts[0] is None:

src/sqlorm/engine.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import inspect
66
import urllib.parse
7+
import random
78
from blinker import Namespace
89
from .sql import render, ParametrizedStmt
910
from .resultset import ResultSet, CompositeResultSet, CompositionMap
@@ -152,16 +153,68 @@ def session(self, **kwargs):
152153
session.close()
153154

154155
def __enter__(self):
155-
if session_context.top:
156+
if session_context.top and session_context.top.engine == self:
156157
return session_context.top.__enter__()
157-
session = self.make_session()
158+
session = self.make_session(close_after_tx=True)
158159
return session.__enter__()
159160

160161
def __exit__(self, exc_type, exc_value, exc_tb):
161-
session = session_context.top
162-
session.__exit__(exc_type, exc_value, exc_tb)
163-
if not session.transactions:
164-
session.close()
162+
session_context.top.__exit__(exc_type, exc_value, exc_tb)
163+
164+
165+
class EngineDispatcher:
166+
def __init__(self, fallback=None, default_tag='default'):
167+
self._engines = []
168+
self.default_tag = default_tag
169+
self.fallback = default_tag if fallback is None else fallback
170+
171+
def register(self, engine, tags=None, default=False):
172+
tags = tags or []
173+
if default:
174+
tags.append(self.default_tag)
175+
self._engines.append((engine, tags))
176+
177+
@property
178+
def engines(self):
179+
return [e for e, tags in self.engines]
180+
181+
@property
182+
def tags(self):
183+
return list(set([t for e, tags in self.engines for t in tags]))
184+
185+
def select_all(self, tag=None):
186+
if not tag and self.default_tag:
187+
tag = self.default_tag
188+
found = []
189+
for engine, tags in self._engines:
190+
if not tag or tag in tags:
191+
found.append(engine)
192+
if found:
193+
return found
194+
if self.fallback is True:
195+
return self.engines
196+
if self.fallback:
197+
return self.select_all(self.fallback)
198+
raise EngineError("No engines found")
199+
200+
def select(self, tag=None):
201+
return random.choice(self.select_all(tag))
202+
203+
def session(self, tag=None, **kwargs):
204+
return self.select(tag).session(**kwargs)
205+
206+
def disconnect_all(self):
207+
for engine in self.engines:
208+
engine.disconnect_all()
209+
210+
def __getattr__(self, tag):
211+
return self.select(tag)
212+
213+
def __enter__(self):
214+
return self.select().__enter__()
215+
216+
def __exit__(self, exc_type, exc_value, exc_tb):
217+
session_context.top.__exit__(exc_type, exc_value, exc_tb)
165218

166219

167220
class EngineError(Exception):
@@ -258,6 +311,7 @@ def __init__(
258311
dbapi_conn=None,
259312
auto_close_conn=None,
260313
virtual_tx=False,
314+
close_after_tx=False,
261315
logger=None,
262316
logger_level=None,
263317
engine=None,
@@ -274,6 +328,7 @@ def __init__(
274328
self.conn = dbapi_conn
275329
self.auto_close_conn = not bool(dbapi_conn) if auto_close_conn is None else auto_close_conn
276330
self.virtual_tx = virtual_tx
331+
self.close_after_tx = close_after_tx
277332
self.logger = logger
278333
self.logger_level = logger_level
279334
self.transactions = []
@@ -344,6 +399,8 @@ def __exit__(self, exc_type, exc_value, exc_tb):
344399
tx.__exit__(exc_type, exc_value, exc_tb)
345400
if not self.transactions:
346401
session_context.pop()
402+
if self.close_after_tx:
403+
self.close()
347404

348405
@property
349406
def transaction(self):

src/sqlorm/schema.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def create_table(mapper, default_type="varchar"):
4343
return SQL(
4444
"CREATE TABLE",
4545
mapper.table,
46-
SQL.Tuple(
46+
SQL.List(
4747
[
4848
SQL(c.schema_def) if c.schema_def else SQL(
4949
c.name,
@@ -61,7 +61,10 @@ def create_table(mapper, default_type="varchar"):
6161
else "",
6262
)
6363
for c in mapper.columns
64-
]
64+
],
65+
startstr="(\n ",
66+
joinstr=",\n ",
67+
endstr="\n)"
6568
),
6669
)
6770

@@ -149,14 +152,29 @@ def set_schema_version(version, engine=None):
149152
)
150153

151154

152-
def create_initial_migration(model_registry=None, path="migrations", version="000", **kwargs):
155+
def create_initial_migration(model_registry=None, models=None, path="migrations", version=None, **kwargs):
153156
if not model_registry:
154157
model_registry = BaseModel.__model_registry__
155158

159+
if version is None:
160+
migrations = create_migrations_from_dir(path)
161+
version = '%03d' % (migrations[-1][0] + 1) if migrations else "000"
162+
163+
if not models:
164+
models = model_registry.values()
165+
header = "Initial creation of the database"
166+
name = "initial"
167+
else:
168+
models = [model_registry[m] if isinstance(m, str) else m for m in models]
169+
header = "Creation of tables: " + ", ".join(m.__mapper__.table for m in models)
170+
name = "create_" + "_".join(m.__mapper__.table for m in models)
171+
156172
stmts = []
157-
for model in model_registry.values():
158-
stmts.append(create_table.sql(model.__mapper__, **kwargs))
173+
for model in models:
174+
stmts.append(str(create_table.sql(model.__mapper__, **kwargs)))
175+
176+
sql = f"-- {header} (auto-generated by sqlorm)\n\n" + ";\n\n".join(stmts) + ";\n"
177+
with open(os.path.join(path, f"{version}_{name}.sql"), "w") as f:
178+
f.write(sql)
159179

160-
sql = "-- Initial creation of the database (auto-generated by sqlorm)\n\n" + ";\n\n".join(stmts)
161-
with open(os.path.join(path, f"{version}_initial.sql"), "w") as f:
162-
f.write(sql)
180+
return (header, version, name)

0 commit comments

Comments
 (0)