diff --git a/src/pytsql/__init__.py b/src/pytsql/__init__.py
index 5622157..8ab51da 100644
--- a/src/pytsql/__init__.py
+++ b/src/pytsql/__init__.py
@@ -3,7 +3,7 @@
 import importlib.metadata
 import warnings
 
-from .tsql import execute, executes
+from .tsql import execute, executes, iter_executes_batches
 
 try:
     __version__ = importlib.metadata.version(__name__)
@@ -12,4 +12,4 @@
     __version__ = "unknown"
 
 
-__all__ = ["execute", "executes"]
+__all__ = ["execute", "executes", "iter_executes_batches"]
diff --git a/src/pytsql/tsql.py b/src/pytsql/tsql.py
index a2d0e85..5db67c1 100644
--- a/src/pytsql/tsql.py
+++ b/src/pytsql/tsql.py
@@ -1,9 +1,10 @@
 import logging
 import re
 import warnings
+from collections.abc import Iterator
 from pathlib import Path
 from re import Match
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import antlr4.tree.Tree
 import sqlalchemy
@@ -247,6 +248,56 @@ def executes(
             _fetch_and_clear_prints(conn)
 
 
+def iter_executes_batches(
+    code: str,
+    engine: sqlalchemy.engine.Engine,
+    parameters: Optional[dict[str, Any]] = None,
+    isolate_top_level_statements: bool = True,
+) -> Iterator[tuple[str, Callable[[], None]]]:
+    """
+    Yields (sql_batch_string, run) for each batch. Mimics executes() but returns a generator
+    where run() can be called to execute each batch.
+
+    The connection is opened when iteration starts and is closed once the generator is
+    exhausted, closed or garbage collected, so call each run() while still iterating.
+
+    Args
+    ----
+    code T-SQL string to be executed
+    engine (sqlalchemy.engine.Engine): established mssql connection
+    parameters An optional dictionary of parameters to be substituted in the sql script
+    isolate_top_level_statements: whether to execute statements one by one or in whole batches
+
+    Returns
+    -------
+    Iterator of (sql_batch_string, run) tuples where run() executes the batch.
+    """
+    parametrized_code = _parameterize(code, parameters) if parameters else code
+
+    # A with-block inside a generator stays open while the generator is suspended at a
+    # yield, so the connection outlives every run() handed out during iteration and is
+    # still closed when iteration ends - even if the cleanup statement below fails.
+    with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+        try:
+            conn.execute(_text(f"DROP TABLE IF EXISTS {_PRINTS_TABLE}"))
+            conn.execute(_text(f"CREATE TABLE {_PRINTS_TABLE} (p NVARCHAR(4000))"))
+
+            for batch in _split(parametrized_code, isolate_top_level_statements):
+                sql_batch = _text(batch)
+
+                # Bind statement and connection as defaults so every runner keeps its
+                # own batch rather than whatever the loop assigned last.
+                def run(sql=sql_batch, _conn=conn):
+                    _conn.execute(sql)
+                    _fetch_and_clear_prints(_conn)
+
+                # Yield the raw batch string and a bound runner
+                yield batch, run
+        finally:
+            # Drop the prints table whether iteration completed or was abandoned.
+            conn.execute(_text(f"DROP TABLE IF EXISTS {_PRINTS_TABLE}"))
+
+
 def execute(
     path: Union[str, Path],
     engine: sqlalchemy.engine.Engine,
diff --git a/tests/integration/test_iter_executes_batches.py b/tests/integration/test_iter_executes_batches.py
new file mode 100644
index 0000000..21a0865
--- /dev/null
+++ b/tests/integration/test_iter_executes_batches.py
@@ -0,0 +1,32 @@
+import sqlalchemy as sa
+
+from pytsql.tsql import iter_executes_batches
+
+
+def test_executes_batches(engine):
+    """Run every yielded batch and check both the batch count and the inserted rows."""
+    seed = """
+    USE [tempdb]
+    GO
+    DROP TABLE IF EXISTS [test_table_batches]
+    CREATE TABLE [test_table_batches] (
+        col VARCHAR(3)
+    )
+    GO
+    INSERT INTO [test_table_batches] (col)
+    VALUES ('A'), ('AB'), ('ABC')
+    PRINT('Affected ' + CAST(@@ROWCOUNT AS VARCHAR) + ' rows')
+    """
+    batches = []
+    for sql, run in iter_executes_batches(seed, engine, None):
+        run()
+        batches.append(sql)
+
+    # Top-level statements are isolated by default: five statements, five batches.
+    assert len(batches) == 5
+    assert "USE [tempdb]" in batches[0]
+
+    with engine.connect() as conn:
+        result = conn.execute(sa.text("SELECT COUNT(*) FROM test_table_batches"))
+        count = result.scalar()
+        assert count == 3
diff --git a/tests/integration/test_non_isolation_mode.py b/tests/integration/test_non_isolation_mode.py
index b271fde..754b328 100644
--- a/tests/integration/test_non_isolation_mode.py
+++ b/tests/integration/test_non_isolation_mode.py
@@ -28,10 +28,10 @@ def test_semi_persistent_set(engine, caplog):
     caplog.set_level(logging.INFO)
 
     seed = """
-    DECLARE @A INT = 12
-    DECLARE @B INT = 34
-    SET @A = 56
-    SET @B = 78
+    DECLARE @A INT = 123
+    DECLARE @B INT = 345
+    SET @A = 567
+    SET @B = 789
     PRINT(@A)
     GO
     PRINT(@B)
@@ -39,7 +39,7 @@ def test_semi_persistent_set(engine, caplog):
 
     executes(seed, engine, isolate_top_level_statements=False)
 
-    assert "56" in caplog.text
-    assert "34" in caplog.text
-    assert "12" not in caplog.text
-    assert "78" not in caplog.text
+    assert "567" in caplog.text
+    assert "345" in caplog.text
+    assert "123" not in caplog.text
+    assert "789" not in caplog.text