Skip to content

Adds ability to capture all the db queries at once #1177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 35 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Compatibility

* Added official support for Django 5.2 (`PR #1179 <https://github.com/pytest-dev/pytest-django/pull/1179>`__).
* Dropped testing on MySQL’s MyISAM storage engine (`PR #1180 <https://github.com/pytest-dev/pytest-django/pull/1180>`__).
* Added fixtures :fixture:`django_assert_num_queries_all_connections` and
:fixture:`django_assert_max_num_queries_all_connections` to check all
your database connections at once.


Bugfixes
^^^^^^^^
Expand Down
69 changes: 69 additions & 0 deletions docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,75 @@ If you use type annotations, you can annotate the fixture like this::
...


.. fixture:: django_assert_num_queries_all_connections

``django_assert_num_queries_all_connections``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. py:function:: django_assert_num_queries_all_connections(num, info=None)

:param num: expected number of queries

This fixture allows to check for an expected number of DB queries on all
your database connections.

If the assertion failed, the executed queries can be shown by using
the verbose command line option.

It wraps ``django.test.utils.CaptureQueriesContext`` and yields the wrapped
``DjangoAssertNumAllConnectionsQueries`` instance.

Example usage::

def test_queries(django_assert_num_queries_all_connections):
with django_assert_num_queries_all_connections(3) as captured:
Item.objects.using("default").create('foo')
Item.objects.using("logs").create('bar')
Item.objects.using("finance").create('baz')

assert 'foo' in captured.captured_queries[0]['sql']

If you use type annotations, you can annotate the fixture like this::

from pytest_django import DjangoAssertNumAllConnectionsQueries

def test_num_queries(
django_assert_num_queries: DjangoAssertNumAllConnectionsQueries,
):
...


.. fixture:: django_assert_max_num_queries_all_connections

``django_assert_max_num_queries_all_connections``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. py:function:: django_assert_max_num_queries_all_connections(num, info=None)

:param num: expected maximum number of queries

This fixture allows to check for an expected maximum number of DB queries on all
your database connections.

It is a specialized version of :fixture:`django_assert_num_queries_all_connections`.

Example usage::

def test_max_queries(django_assert_max_num_queries_all_connections):
with django_assert_max_num_queries_all_connections(2):
Item.objects.using("logs").create('foo')
Item.objects.using("finance").create('bar')

If you use type annotations, you can annotate the fixture like this::

from pytest_django import DjangoAssertNumAllConnectionsQueries

def test_max_num_queries(
django_assert_max_num_queries_all_connections: DjangoAssertNumAllConnectionsQueries,
):
...


.. fixture:: django_capture_on_commit_callbacks

``django_capture_on_commit_callbacks``
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ skip_covered = true
exclude_lines = [
"pragma: no cover",
"if TYPE_CHECKING:",
"pass",
"...",
]

[tool.ruff]
Expand Down
7 changes: 6 additions & 1 deletion pytest_django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
__version__ = "unknown"


from .fixtures import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks
from .fixtures import (
DjangoAssertNumAllConnectionsQueries,
DjangoAssertNumQueries,
DjangoCaptureOnCommitCallbacks,
)
from .plugin import DjangoDbBlocker


__all__ = [
"DjangoAssertNumAllConnectionsQueries",
"DjangoAssertNumQueries",
"DjangoCaptureOnCommitCallbacks",
"DjangoDbBlocker",
Expand Down
118 changes: 115 additions & 3 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from collections.abc import Sized
from contextlib import contextmanager
from functools import partial
from typing import (
Expand All @@ -11,15 +12,19 @@
Any,
Callable,
ContextManager,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
runtime_checkable,
)

import pytest
Expand Down Expand Up @@ -51,7 +56,9 @@
"client",
"db",
"django_assert_max_num_queries",
"django_assert_max_num_queries_all_connections",
"django_assert_num_queries",
"django_assert_num_queries_all_connections",
"django_capture_on_commit_callbacks",
"django_db_reset_sequences",
"django_db_serialized_rollback",
Expand All @@ -65,6 +72,19 @@
]


@runtime_checkable
class QueryCaptureContextProtocol(Protocol, Sized):
@property
def captured_queries(self) -> List[Dict[str, Any]]: ...

def __enter__(self) -> QueryCaptureContextProtocol: ...

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...


_QueriesContext = TypeVar("_QueriesContext", bound=QueryCaptureContextProtocol)


@pytest.fixture(scope="session")
def django_db_modify_db_settings_tox_suffix() -> None:
skip_if_no_django()
Expand Down Expand Up @@ -654,6 +674,43 @@ def _live_server_helper(request: pytest.FixtureRequest) -> Generator[None, None,
live_server._live_server_modified_settings.disable()


class CaptureAllConnectionsQueriesContext:
"""
Context manager that captures all queries executed by Django ORM across all Databases in settings.DATABASES.
"""

def __init__(self) -> None:
from django.db import connections
from django.test.utils import CaptureQueriesContext

self.contexts = {alias: CaptureQueriesContext(connections[alias]) for alias in connections}

def __iter__(self) -> Iterable[dict[str, Any]]:
return iter(self.captured_queries)

def __getitem__(self, index: int) -> dict[str, Any]:
return self.captured_queries[index]

def __len__(self) -> int:
return len(self.captured_queries)

@property
def captured_queries(self) -> list[dict[str, Any]]:
queries = []
for context in self.contexts.values():
queries.extend(context.captured_queries)
return queries

def __enter__(self):
for context in self.contexts.values():
context.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
for context in self.contexts.values():
context.__exit__(exc_type, exc_val, exc_tb)


class DjangoAssertNumQueries(Protocol):
"""The type of the `django_assert_num_queries` and
`django_assert_max_num_queries` fixtures."""
Expand All @@ -665,8 +722,18 @@ def __call__(
info: str | None = ...,
*,
using: str | None = ...,
) -> django.test.utils.CaptureQueriesContext:
pass # pragma: no cover
) -> ContextManager[django.test.utils.CaptureQueriesContext]: ...


class DjangoAssertNumAllConnectionsQueries(Protocol):
"""The type of the `django_assert_num_queries_all_connections` and
`django_assert_max_num_queries_all_connections` fixtures."""

def __call__(
self,
num: int,
info: str | None = ...,
) -> ContextManager[CaptureAllConnectionsQueriesContext]: ...


@contextmanager
Expand All @@ -692,8 +759,37 @@ def _assert_num_queries(
else:
conn = default_conn

verbose = config.getoption("verbose") > 0
with CaptureQueriesContext(conn) as context:
yield from _assert_num_queries_context(
config=config, context=context, num=num, exact=exact, info=info
)


@contextmanager
def _assert_num_queries_all_db(
config,
num: int,
exact: bool = True,
info: str | None = None,
) -> Generator[CaptureAllConnectionsQueriesContext, None, None]:
"""A recreation of pytest-django's assert_num_queries that works with all databases in settings.Databases."""

with CaptureAllConnectionsQueriesContext() as context:
yield from _assert_num_queries_context(
config=config, context=context, num=num, exact=exact, info=info
)


def _assert_num_queries_context(
*,
config: pytest.Config,
context: _QueriesContext,
num: int,
exact: bool = True,
info: str | None = None,
) -> Iterator[_QueriesContext]:
verbose = config.getoption("verbose") > 0
with context:
yield context
num_performed = len(context)
if exact:
Expand Down Expand Up @@ -728,6 +824,22 @@ def django_assert_max_num_queries(pytestconfig: pytest.Config) -> DjangoAssertNu
return partial(_assert_num_queries, pytestconfig, exact=False)


@pytest.fixture(scope="function")
def django_assert_num_queries_all_connections(
pytestconfig: pytest.Config,
) -> DjangoAssertNumAllConnectionsQueries:
"""Asserts that the number of queries executed by Django ORM across all connections in settings.DATABASES is equal to the given number."""
return partial(_assert_num_queries_all_db, pytestconfig)


@pytest.fixture(scope="function")
def django_assert_max_num_queries_all_connections(
pytestconfig: pytest.Config,
) -> DjangoAssertNumAllConnectionsQueries:
"""Asserts that the number of queries executed by Django ORM across all connections in settings.DATABASES is less than or equal to the given number."""
return partial(_assert_num_queries_all_db, pytestconfig, exact=False)


class DjangoCaptureOnCommitCallbacks(Protocol):
"""The type of the `django_capture_on_commit_callbacks` fixture."""

Expand Down
2 changes: 2 additions & 0 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
client, # noqa: F401
db, # noqa: F401
django_assert_max_num_queries, # noqa: F401
django_assert_max_num_queries_all_connections, # noqa: F401
django_assert_num_queries, # noqa: F401
django_assert_num_queries_all_connections, # noqa: F401
django_capture_on_commit_callbacks, # noqa: F401
django_db_createdb, # noqa: F401
django_db_keepdb, # noqa: F401
Expand Down
41 changes: 40 additions & 1 deletion tests/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

from .helpers import DjangoPytester

from pytest_django import DjangoAssertNumQueries, DjangoCaptureOnCommitCallbacks, DjangoDbBlocker
from pytest_django import (
DjangoAssertNumAllConnectionsQueries,
DjangoAssertNumQueries,
DjangoCaptureOnCommitCallbacks,
DjangoDbBlocker,
)
from pytest_django_test.app.models import Item


Expand Down Expand Up @@ -259,6 +264,40 @@ def test_queries(django_assert_num_queries):
assert result.ret == 1


@pytest.mark.django_db(databases=["default", "replica", "second"])
def test_django_assert_num_queries_all_connections(
django_assert_num_queries_all_connections: DjangoAssertNumAllConnectionsQueries,
) -> None:
with django_assert_num_queries_all_connections(3):
Item.objects.count()
Item.objects.using("replica").count()
Item.objects.using("second").count()


@pytest.mark.django_db(databases=["default", "replica", "second"])
def test_django_assert_max_num_queries_all_connections(
request: pytest.FixtureRequest,
django_assert_max_num_queries_all_connections: DjangoAssertNumAllConnectionsQueries,
) -> None:
with nonverbose_config(request.config):
with django_assert_max_num_queries_all_connections(2):
Item.objects.create(name="1-foo")
Item.objects.using("second").create(name="2-bar")

with pytest.raises(pytest.fail.Exception) as excinfo: # noqa: PT012
with django_assert_max_num_queries_all_connections(2) as captured:
Item.objects.create(name="1-foo")
Item.objects.create(name="2-bar")
Item.objects.using("second").create(name="3-quux")

assert excinfo.value.args == (
"Expected to perform 2 queries or less but 3 were done "
"(add -v option to show queries)",
)
assert len(captured.captured_queries) == 3
assert "1-foo" in captured.captured_queries[0]["sql"]


@pytest.mark.django_db
def test_django_capture_on_commit_callbacks(
django_capture_on_commit_callbacks: DjangoCaptureOnCommitCallbacks,
Expand Down