Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 76 additions & 21 deletions .github/workflows/sqlguard-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,63 +13,118 @@ on:
jobs:
test-sqlguard:
runs-on: ubuntu-latest
container:
image: ghcr.io/astral-sh/uv:debian-slim
strategy:
matrix:
python-version:
- "3.13"
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set up Rust
run: |
rustup toolchain install --profile minimal --no-self-update
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Install dependencies
run: |
uv sync --dev
- name: Run tests with uv
run: |
export UV_LINK_MODE=copy
uv sync --all-extras
uv run pytest tests

test-tox-sqlguard:
runs-on: ubuntu-latest
container:
image: ghcr.io/astral-sh/uv:debian-slim
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Rust
run: |
rustup toolchain install --profile minimal --no-self-update
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Install Tox & Tox UV
run: |
uv tool install tox --with tox-uv
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Run Tox suite
run: tox run -- tests

build-sqlguard:
needs: [test-sqlguard, test-tox-sqlguard]
if: startsWith(github.ref, 'refs/tags/')
build-sqlguard-sdist:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Build sdist
uses: PyO3/maturin-action@v1
with:
command: sdist
args: --out dist --manifest-path rust/Cargo.toml
- name: Upload sdist
uses: actions/upload-artifact@v4
with:
name: sqlguard-sdist
path: dist

build-sqlguard-wheels:
runs-on: ubuntu-latest
container:
image: ghcr.io/astral-sh/uv:debian-slim
strategy:
matrix:
platform:
- runner: ubuntu-24.04
target: x86_64
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Build package
run: uv build
- name: "Upload Artifact"
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.platform.target }}
args: --release --out dist --find-interpreter --manifest-path rust/Cargo.toml
sccache: ${{ !startsWith(github.ref, 'refs/tags/') }}
manylinux: auto
- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: sqlguard-dist
name: wheels-sqlguard-${{ matrix.platform.target }}
path: dist
retention-days: 5
publish-sqlguard:
needs: [build-sqlguard]
if: startsWith(github.ref, 'refs/tags/')
needs:
- build-sqlguard-sdist
- build-sqlguard-wheels
strategy:
matrix:
platform:
- runner: ubuntu-24.04
target: x86_64
runs-on: ubuntu-latest
permissions:
# IMPORTANT: this permission is mandatory for Trusted Publishing
id-token: write
# Needed for fetching the code
contents: read
steps:
- name: Download Artifacts
- name: Download Artifacts Sdist
uses: actions/download-artifact@v4
with:
name: sqlguard-dist
name: sqlguard-sdist
path: dist
- name: Publish package distributions to PyPI
- name: Download Artifacts Wheels
uses: actions/download-artifact@v4
with:
name: wheels-sqlguard-${{ matrix.platform.target }}
path: dist
- name: Publish package distributions to TEST PyPI
if: "!startsWith(github.ref, 'refs/tags/')"
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
skip-existing: true
verbose: true
- name: Publish package distributions to PyPI Official (Tags Only)
if: startsWith(github.ref, 'refs/tags/')
uses: pypa/gh-action-pypi-publish@release/v1
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,6 @@ pyrightconfig.json

# We don't version uv.lock
uv.lock

# We don't freeze the deps for Cargo neither
rust/Cargo.lock
96 changes: 96 additions & 0 deletions benchs/bench_sql_fingerprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from pytest_sqlguard.sql import sql_fingerprint as sql_fingperprint_py

QUERY = """
-- CTE to find customers with minimum number of purchases
WITH FrequentCustomers AS (
SELECT
customer_id,
COUNT(order_id) AS total_orders
FROM
Orders
GROUP BY
customer_id
HAVING
COUNT(order_id) >= 5
),

-- CTE to calculate the average order value per product category
CategoryAverage AS (
SELECT
pc.category_id,
pc.category_name,
AVG(oi.price * oi.quantity) AS avg_category_value
FROM
OrderItems oi
JOIN
Products p ON oi.product_id = p.product_id
JOIN
ProductCategories pc ON p.category_id = pc.category_id
GROUP BY
pc.category_id, pc.category_name
)

-- Main query joining the CTEs with other tables
SELECT
c.customer_id,
c.first_name,
c.last_name,
fc.total_orders,
o.order_id,
o.order_date,
p.product_name,
oi.quantity,
oi.price,
ca.category_name,
ca.avg_category_value,
(oi.price * oi.quantity) AS order_item_total,
CASE
WHEN (oi.price * oi.quantity) > ca.avg_category_value THEN 'Above Average'
WHEN (oi.price * oi.quantity) = ca.avg_category_value THEN 'Average'
ELSE 'Below Average'
END AS price_comparison
FROM
Customers c
JOIN
FrequentCustomers fc ON c.customer_id = fc.customer_id
JOIN
Orders o ON c.customer_id = o.customer_id
JOIN
OrderItems oi ON o.order_id = oi.order_id
JOIN
Products p ON oi.product_id = p.product_id
JOIN
CategoryAverage ca ON p.category_id = ca.category_id
WHERE
o.order_date >= DATEADD(MONTH, -6, GETDATE())
AND oi.price > 10.00
AND p.category_id IN (1, 3, 5)
ORDER BY
c.customer_id, o.order_date DESC;
"""


def print_sql_fingerprint_py():
print(sql_fingperprint_py(query=QUERY))


def test_normalize_python(benchmark):
"""
Benchmark our sql_fingerprint python version
It disables the lru_cache surrounding that function.
"""

benchmark(
sql_fingperprint_py.__wrapped__,
query=QUERY,
)


def test_normalize_rust(benchmark):
from pytest_sqlguard.sqlrs import normalize_sql

benchmark(normalize_sql, QUERY)


if __name__ == "__main__":
print_sql_fingerprint_py()
22 changes: 18 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pytest-sqlguard"
version = "2025.311.0"
version = "2025.313.0"
description = "Pytest fixture to record and check SQL Queries made by SQLAlchemy"
authors = [{ name = "Manu", email = "[email protected]" }]
readme = "README.md"
Expand Down Expand Up @@ -40,15 +40,25 @@ dev = [
"bumpver>=2024.1130",
"click",
"pyright",
"pytest-benchmark>=4.0.0",
"pytest-clarity",
"pyyaml",
"ruff",
"termcolor",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
requires = ["maturin>=1.0"]
build-backend = "maturin"

[tool.maturin]
python-source = "src"
module-name = "pytest_sqlguard.sqlrs"
features = ["pyo3/extension-module"]
include = [
{ path = "LICENSE", format = "sdist" },
{ path = "tox.ini", format = "sdist" },
]

[project.entry-points.pytest11]
sqlguard = "pytest_sqlguard.sqlguard"
Expand Down Expand Up @@ -91,7 +101,7 @@ pythonPlatform = "Linux"
executionEnvironments = [{ root = "src" }]

[tool.bumpver]
current_version = "2025.311.0"
current_version = "2025.313.0"
version_pattern = "YYYY.MM0D.INC0[-TAG]"
commit_message = "bump version {old_version} -> {new_version}"
tag_message = "{new_version}"
Expand All @@ -102,3 +112,7 @@ push = false

[tool.bumpver.file_patterns]
"pyproject.toml" = ['^version = "{version}"', '^current_version = "{version}"']


[tool.uv]
cache-keys = [{file = "pyproject.toml"}, {file = "rust/Cargo.toml"}, {file = "**/*.rs"}]
21 changes: 21 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[package]
name = "sql-rs"
version = "0.1.0"
edition = "2021"
# Packaging stuff
description = "Rust library to use inside pytest-sqlguard"
authors = ["Manu <[email protected]"]
readme = "../README.md"
license = "MIT"
homepage = "https://github.com/PayLead/pytest-sqlguard"
documentation = "https://github.com/PayLead/pytest-sqlguard"
repository = "https://github.com/PayLead/pytest-sqlguard"

[lib]
name = "sqlrs"
crate-type = ["cdylib"]
path = "src/lib.rs"

[dependencies]
pyo3 = { version = "0.24.1", features = ["extension-module"] }
sql-insight = "0.2.0"
30 changes: 30 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use sql_insight::sqlparser::dialect::GenericDialect;

#[pyfunction]
fn normalize_sql(sql: &str) -> PyResult<String> {
let dialect = GenericDialect {};
let normalized_queries = sql_insight::normalize(&dialect, sql).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("SQL normalization error: {}", e))
})?;
let first_query = normalized_queries[0].clone();
Ok(first_query)
}

#[pyfunction]
fn format_sql(sql: &str) -> PyResult<String> {
let dialect = GenericDialect {};
let formatted_queries = sql_insight::format(&dialect, sql).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("SQL formatting error: {}", e))
})?;
let first_query = formatted_queries[0].clone();
Ok(first_query)
}

#[pymodule]
fn sqlrs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(normalize_sql, m)?)?;
m.add_function(wrap_pyfunction!(format_sql, m)?)?;
Ok(())
}
5 changes: 4 additions & 1 deletion src/pytest_sqlguard/perf_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from sqlalchemy import text
from sqlparse import format as sql_format

from pytest_sqlguard.sql import sql_fingerprint
try:
from pytest_sqlguard.sqlrs import normalize_sql as sql_fingerprint
except ImportError:
from pytest_sqlguard.sql import sql_fingerprint


class Query(NamedTuple):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_file_is_created_and_contains_ok_stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ def test_file_is_created_and_contains_ok_stuff(session, sqlguard):
queries = obj["test_file_is_created_and_contains_ok_stuff"]["queries"]
assert queries
assert len(queries) == 2
assert "CREATE TABLE test_1(i int, t text)" == queries[0]["statement"]
assert "CREATE TABLE test_1 (i INT, t TEXT)" == queries[0]["statement"]
4 changes: 2 additions & 2 deletions tests/test_file_is_created_and_contains_ok_stuff.queries.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_file_is_created_and_contains_ok_stuff:
queries:
- statement: CREATE TABLE test_1(i int, t text)
- statement: CREATE TABLE test_1 (i INT, t TEXT)
- statement: |-
SELECT *
from test_1
FROM test_1
4 changes: 2 additions & 2 deletions tests/test_sqlguard_options.queries.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
TestSQLGuardFailMissing::test_disabled:
queries:
- statement: CREATE TABLE TestSQLGuardFailMissing_test_disabled(i int, t text)
- statement: CREATE TABLE TestSQLGuardFailMissing_test_disabled (i INT, t TEXT)
TestSQLGuardOverwrite::test_disabled:
queries:
- statement: CREATE TABLE TestSQLGuardOverwrite_test_disabled(i int, t text)
TestSQLGuardOverwrite::test_enabled:
queries:
- statement: CREATE TABLE TestSQLGuardOverwrite_test_enabled(i int, t text)
- statement: CREATE TABLE TestSQLGuardOverwrite_test_enabled (i INT, t TEXT)
2 changes: 1 addition & 1 deletion tests/unit_tests/test_perf_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_simple_query_with_simplification(self, session):
with record_queries(session) as ctx:
session.execute(text("SELECT 1"))
assert len(ctx.recorder.queries) == 1
assert ctx.recorder.queries[0].statement == "SELECT #"
assert ctx.recorder.queries[0].statement == "SELECT ?"


class TestSaveToFile:
Expand Down
Loading