Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/pytsql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib.metadata
import warnings

from .tsql import execute, executes
from .tsql import execute, executes, iter_executes_batches

try:
__version__ = importlib.metadata.version(__name__)
Expand All @@ -12,4 +12,4 @@
__version__ = "unknown"


__all__ = ["execute", "executes"]
__all__ = ["execute", "executes", "iter_executes_batches"]
51 changes: 50 additions & 1 deletion src/pytsql/tsql.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import re
import warnings
from collections.abc import Iterator
from pathlib import Path
from re import Match
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import antlr4.tree.Tree
import sqlalchemy
Expand Down Expand Up @@ -247,6 +248,54 @@ def executes(
_fetch_and_clear_prints(conn)


def iter_executes_batches(
code: str,
engine: sqlalchemy.engine.Engine,
parameters: Optional[dict[str, Any]] = None,
isolate_top_level_statements: bool = True,
) -> Iterator[tuple[str, Callable[[], None]]]:
"""
Yields (sql_batch_string, run) for each batch. Mimics executes() but returns a generator
where run() can be called to execute each batch.

Args
----
code T-SQL string to be executed
engine (sqlalchemy.engine.Engine): established mssql connection
parameters An optional dictionary of parameters to substituted in the sql script
isolate_top_level_statements: whether to execute statements one by one or in whole batches

Returns
-------
Iterator of (sql_batch_string, run) tuples where run() executes the batch.
"""
parametrized_code = _parameterize(code, parameters) if parameters else code

# I would love to use a context manager here, but that would close the connection
# before the caller has a chance to call run(). So we have to do it manually.
# Alternatively we could accept a connection instead of an engine, but that would
# not align with the interface of the other functions.
conn = engine.connect().execution_options(isolation_level="AUTOCOMMIT")
try:
conn.execute(_text(f"DROP TABLE IF EXISTS {_PRINTS_TABLE}"))
conn.execute(_text(f"CREATE TABLE {_PRINTS_TABLE} (p NVARCHAR(4000))"))

for batch in _split(parametrized_code, isolate_top_level_statements):
sql_batch = _text(batch)

def run(sql=sql_batch, _conn=conn):
_conn.execute(sql)
_fetch_and_clear_prints(_conn)

# Yield the raw string (or TextClause) and a bound runner
yield batch, run

finally:
# This is a bit ugly, but we have to close the connection and drop the temp table.
conn.execute(_text(f"DROP TABLE IF EXISTS {_PRINTS_TABLE}"))
conn.close()


def execute(
path: Union[str, Path],
engine: sqlalchemy.engine.Engine,
Expand Down
30 changes: 30 additions & 0 deletions tests/integration/test_iter_executes_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sqlalchemy as sa

from pytsql.tsql import iter_executes_batches


def test_executes_batches(engine):
seed = """
USE [tempdb]
GO
DROP TABLE IF EXISTS [test_table_batches]
CREATE TABLE [test_table_batches] (
col VARCHAR(3)
)
GO
INSERT INTO [test_table_batches] (col)
VALUES ('A'), ('AB'), ('ABC')
PRINT('Affected ' + CAST(@@ROWCOUNT AS VARCHAR) + ' rows')
"""
batches = []
for sql, run in iter_executes_batches(seed, engine, None):
run()
batches.append(sql)

assert len(batches) == 5
assert "USE [tempdb]" in batches[0]

with engine.connect() as conn:
result = conn.execute(sa.text("SELECT COUNT(*) FROM test_table_batches"))
count = result.scalar()
assert count == 3
16 changes: 8 additions & 8 deletions tests/integration/test_non_isolation_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ def test_semi_persistent_set(engine, caplog):
caplog.set_level(logging.INFO)

seed = """
DECLARE @A INT = 12
DECLARE @B INT = 34
SET @A = 56
SET @B = 78
DECLARE @A INT = 123
DECLARE @B INT = 345
SET @A = 567
SET @B = 789
PRINT(@A)
GO
PRINT(@B)
"""

executes(seed, engine, isolate_top_level_statements=False)

assert "56" in caplog.text
assert "34" in caplog.text
assert "12" not in caplog.text
assert "78" not in caplog.text
assert "567" in caplog.text
assert "345" in caplog.text
assert "123" not in caplog.text
assert "789" not in caplog.text
Loading