diff --git a/.github/workflows/sqlguard-ci.yml b/.github/workflows/sqlguard-ci.yml index 4861230..33c3655 100644 --- a/.github/workflows/sqlguard-ci.yml +++ b/.github/workflows/sqlguard-ci.yml @@ -13,25 +13,40 @@ on: jobs: test-sqlguard: runs-on: ubuntu-latest - container: - image: ghcr.io/astral-sh/uv:debian-slim + strategy: + matrix: + python-version: + - "3.13" steps: - name: Checkout code uses: actions/checkout@v4 - + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Set up Rust + run: | + rustup toolchain install --profile minimal --no-self-update + - name: Install uv + uses: astral-sh/setup-uv@v5 + - name: Install dependencies + run: | + uv sync --dev - name: Run tests with uv run: | export UV_LINK_MODE=copy - uv sync --all-extras uv run pytest tests test-tox-sqlguard: runs-on: ubuntu-latest - container: - image: ghcr.io/astral-sh/uv:debian-slim steps: - name: Checkout code uses: actions/checkout@v4 + - name: Set up Rust + run: | + rustup toolchain install --profile minimal --no-self-update + - name: Install uv + uses: astral-sh/setup-uv@v5 - name: Install Tox & Tox UV run: | uv tool install tox --with tox-uv @@ -39,26 +54,53 @@ jobs: - name: Run Tox suite run: tox run -- tests - build-sqlguard: - needs: [test-sqlguard, test-tox-sqlguard] - if: startsWith(github.ref, 'refs/tags/') + build-sqlguard-sdist: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist --manifest-path rust/Cargo.toml + - name: Upload sdist + uses: actions/upload-artifact@v4 + with: + name: sqlguard-sdist + path: dist + + build-sqlguard-wheels: runs-on: ubuntu-latest - container: - image: ghcr.io/astral-sh/uv:debian-slim + strategy: + matrix: + platform: + - runner: ubuntu-24.04 + target: x86_64 steps: - name: Checkout code uses: actions/checkout@v4 - - name: Build package - run: uv build - - name: "Upload Artifact" + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter --manifest-path rust/Cargo.toml + sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} + manylinux: auto + - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: sqlguard-dist + name: wheels-sqlguard-${{ matrix.platform.target }} path: dist - retention-days: 5 publish-sqlguard: - needs: [build-sqlguard] - if: startsWith(github.ref, 'refs/tags/') + needs: + - build-sqlguard-sdist + - build-sqlguard-wheels + strategy: + matrix: + platform: + - runner: ubuntu-24.04 + target: x86_64 runs-on: ubuntu-latest permissions: # IMPORTANT: this permission is mandatory for Trusted Publishing @@ -66,10 +108,23 @@ jobs: # Needed for fetching the code contents: read steps: - - name: Download Artifacts + - name: Download Artifacts Sdist uses: actions/download-artifact@v4 with: - name: sqlguard-dist + name: sqlguard-sdist path: dist - - name: Publish package distributions to PyPI + - name: Download Artifacts Wheels + uses: actions/download-artifact@v4 + with: + name: wheels-sqlguard-${{ matrix.platform.target }} + path: dist + - name: Publish package distributions to TEST PyPI + if: "!startsWith(github.ref, 'refs/tags/')" + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + skip-existing: true + verbose: true + - name: Publish package distributions to PyPI Official (Tags Only) + if: startsWith(github.ref, 'refs/tags/') uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index f6561ee..2e874fc 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,6 @@ pyrightconfig.json # We don't version uv.lock uv.lock + +# We don't freeze the deps for Cargo neither +rust/Cargo.lock diff --git a/benchs/bench_sql_fingerprint.py b/benchs/bench_sql_fingerprint.py new file mode 100644 index 0000000..97c410f --- /dev/null +++ b/benchs/bench_sql_fingerprint.py @@ -0,0 +1,96 @@ +from pytest_sqlguard.sql import sql_fingerprint as sql_fingperprint_py + +QUERY = """ +-- CTE to find customers with minimum number of purchases +WITH FrequentCustomers AS ( + SELECT + customer_id, + COUNT(order_id) AS total_orders + FROM + Orders + GROUP BY + customer_id + HAVING + COUNT(order_id) >= 5 +), + +-- CTE to calculate the average order value per product category +CategoryAverage AS ( + SELECT + pc.category_id, + pc.category_name, + AVG(oi.price * oi.quantity) AS avg_category_value + FROM + OrderItems oi + JOIN + Products p ON oi.product_id = p.product_id + JOIN + ProductCategories pc ON p.category_id = pc.category_id + GROUP BY + pc.category_id, pc.category_name +) + +-- Main query joining the CTEs with other tables +SELECT + c.customer_id, + c.first_name, + c.last_name, + fc.total_orders, + o.order_id, + o.order_date, + p.product_name, + oi.quantity, + oi.price, + ca.category_name, + ca.avg_category_value, + (oi.price * oi.quantity) AS order_item_total, + CASE + WHEN (oi.price * oi.quantity) > ca.avg_category_value THEN 'Above Average' + WHEN (oi.price * oi.quantity) = ca.avg_category_value THEN 'Average' + ELSE 'Below Average' + END AS price_comparison +FROM + Customers c +JOIN + FrequentCustomers fc ON c.customer_id = fc.customer_id +JOIN + Orders o ON c.customer_id = o.customer_id +JOIN + OrderItems oi ON o.order_id = oi.order_id +JOIN + Products p ON oi.product_id = p.product_id +JOIN + CategoryAverage ca ON p.category_id = ca.category_id +WHERE + o.order_date >= DATEADD(MONTH, -6, GETDATE()) + AND oi.price > 10.00 + AND p.category_id IN (1, 3, 5) +ORDER BY + c.customer_id, o.order_date DESC; +""" + + +def print_sql_fingerprint_py(): + print(sql_fingperprint_py(query=QUERY)) + + +def test_normalize_python(benchmark): + """ + Benchmark our sql_fingerprint python version + It disables the lru_cache surrounding that function. + """ + + benchmark( + sql_fingperprint_py.__wrapped__, + query=QUERY, + ) + + +def test_normalize_rust(benchmark): + from pytest_sqlguard.sqlrs import normalize_sql + + benchmark(normalize_sql, QUERY) + + +if __name__ == "__main__": + print_sql_fingerprint_py() diff --git a/pyproject.toml b/pyproject.toml index 074e025..5cde0d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pytest-sqlguard" -version = "2025.311.0" +version = "2025.313.0" description = "Pytest fixture to record and check SQL Queries made by SQLAlchemy" authors = [{ name = "Manu", email = "manuel.vives@paylead.fr" }] readme = "README.md" @@ -40,6 +40,7 @@ dev = [ "bumpver>=2024.1130", "click", "pyright", + "pytest-benchmark>=4.0.0", "pytest-clarity", "pyyaml", "ruff", @@ -47,8 +48,17 @@ dev = [ ] [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +requires = ["maturin>=1.0"] +build-backend = "maturin" + +[tool.maturin] +python-source = "src" +module-name = "pytest_sqlguard.sqlrs" +features = ["pyo3/extension-module"] +include = [ + { path = "LICENSE", format = "sdist" }, + { path = "tox.ini", format = "sdist" }, +] [project.entry-points.pytest11] sqlguard = "pytest_sqlguard.sqlguard" @@ -91,7 +101,7 @@ pythonPlatform = "Linux" executionEnvironments = [{ root = "src" }] [tool.bumpver] -current_version = "2025.311.0" +current_version = "2025.313.0" version_pattern = "YYYY.MM0D.INC0[-TAG]" commit_message = "bump version {old_version} -> {new_version}" tag_message = "{new_version}" @@ -102,3 +112,7 @@ push = false [tool.bumpver.file_patterns] "pyproject.toml" = ['^version = "{version}"', '^current_version = "{version}"'] + + +[tool.uv] +cache-keys = [{file = "pyproject.toml"}, {file = "rust/Cargo.toml"}, {file = "**/*.rs"}] diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000..8633b53 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "sql-rs" +version = "0.1.0" +edition = "2021" +# Packaging stuff +description = "Rust library to use inside pytest-sqlguard" +authors = ["Manu PyResult { + let dialect = GenericDialect {}; + let normalized_queries = sql_insight::normalize(&dialect, sql).map_err(|e| { + PyErr::new::(format!("SQL normalization error: {}", e)) + })?; + let first_query = normalized_queries[0].clone(); + Ok(first_query) +} + +#[pyfunction] +fn format_sql(sql: &str) -> PyResult { + let dialect = GenericDialect {}; + let formatted_queries = sql_insight::format(&dialect, sql).map_err(|e| { + PyErr::new::(format!("SQL formatting error: {}", e)) + })?; + let first_query = formatted_queries[0].clone(); + Ok(first_query) +} + +#[pymodule] +fn sqlrs(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(normalize_sql, m)?)?; + m.add_function(wrap_pyfunction!(format_sql, m)?)?; + Ok(()) +} diff --git a/src/pytest_sqlguard/perf_rec.py b/src/pytest_sqlguard/perf_rec.py index 13ef607..c583549 100644 --- a/src/pytest_sqlguard/perf_rec.py +++ b/src/pytest_sqlguard/perf_rec.py @@ -8,7 +8,10 @@ from sqlalchemy import text from sqlparse import format as sql_format -from pytest_sqlguard.sql import sql_fingerprint +try: + from pytest_sqlguard.sqlrs import normalize_sql as sql_fingerprint +except ImportError: + from pytest_sqlguard.sql import sql_fingerprint class Query(NamedTuple): diff --git a/tests/test_file_is_created_and_contains_ok_stuff.py b/tests/test_file_is_created_and_contains_ok_stuff.py index 5002364..6e153e5 100644 --- a/tests/test_file_is_created_and_contains_ok_stuff.py +++ b/tests/test_file_is_created_and_contains_ok_stuff.py @@ -29,4 +29,4 @@ def test_file_is_created_and_contains_ok_stuff(session, sqlguard): queries = obj["test_file_is_created_and_contains_ok_stuff"]["queries"] assert queries assert len(queries) == 2 - assert "CREATE TABLE test_1(i int, t text)" == queries[0]["statement"] + assert "CREATE TABLE test_1 (i INT, t TEXT)" == queries[0]["statement"] diff --git a/tests/test_file_is_created_and_contains_ok_stuff.queries.yaml b/tests/test_file_is_created_and_contains_ok_stuff.queries.yaml index b8a3157..f49c087 100644 --- a/tests/test_file_is_created_and_contains_ok_stuff.queries.yaml +++ b/tests/test_file_is_created_and_contains_ok_stuff.queries.yaml @@ -1,6 +1,6 @@ test_file_is_created_and_contains_ok_stuff: queries: - - statement: CREATE TABLE test_1(i int, t text) + - statement: CREATE TABLE test_1 (i INT, t TEXT) - statement: |- SELECT * - from test_1 + FROM test_1 diff --git a/tests/test_sqlguard_options.queries.yaml b/tests/test_sqlguard_options.queries.yaml index 1147677..70a8a53 100644 --- a/tests/test_sqlguard_options.queries.yaml +++ b/tests/test_sqlguard_options.queries.yaml @@ -1,9 +1,9 @@ TestSQLGuardFailMissing::test_disabled: queries: - - statement: CREATE TABLE TestSQLGuardFailMissing_test_disabled(i int, t text) + - statement: CREATE TABLE TestSQLGuardFailMissing_test_disabled (i INT, t TEXT) TestSQLGuardOverwrite::test_disabled: queries: - statement: CREATE TABLE TestSQLGuardOverwrite_test_disabled(i int, t text) TestSQLGuardOverwrite::test_enabled: queries: - - statement: CREATE TABLE TestSQLGuardOverwrite_test_enabled(i int, t text) + - statement: CREATE TABLE TestSQLGuardOverwrite_test_enabled (i INT, t TEXT) diff --git a/tests/unit_tests/test_perf_rec.py b/tests/unit_tests/test_perf_rec.py index fb5e894..6e80c41 100644 --- a/tests/unit_tests/test_perf_rec.py +++ b/tests/unit_tests/test_perf_rec.py @@ -25,7 +25,7 @@ def test_simple_query_with_simplification(self, session): with record_queries(session) as ctx: session.execute(text("SELECT 1")) assert len(ctx.recorder.queries) == 1 - assert ctx.recorder.queries[0].statement == "SELECT #" + assert ctx.recorder.queries[0].statement == "SELECT ?" class TestSaveToFile: