diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index b58aadeb..5b7adf11 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -5,6 +5,7 @@ from functools import partial import pytest +from django.utils.module_loading import import_string from . import live_server_helper from .django_compat import is_django_unittest @@ -19,7 +20,6 @@ _DjangoDbDatabases = Optional[Union["Literal['__all__']", Iterable[str]]] _DjangoDb = Tuple[bool, bool, _DjangoDbDatabases] - __all__ = [ "django_db_setup", "db", @@ -42,6 +42,18 @@ ] +def import_from_string(val, setting_name): + """ + Attempt to import a class from a string representation. + """ + try: + return import_string(val) + except ImportError as e: + msg = "Could not import '%s' for API setting '%s'. %s: %s." \ + % (val, setting_name, e.__class__.__name__, e) + raise ImportError(msg) + + @pytest.fixture(scope="session") def django_db_modify_db_settings_tox_suffix() -> None: skip_if_no_django() @@ -64,15 +76,15 @@ def django_db_modify_db_settings_xdist_suffix(request) -> None: @pytest.fixture(scope="session") def django_db_modify_db_settings_parallel_suffix( - django_db_modify_db_settings_tox_suffix: None, - django_db_modify_db_settings_xdist_suffix: None, + django_db_modify_db_settings_tox_suffix: None, + django_db_modify_db_settings_xdist_suffix: None, ) -> None: skip_if_no_django() @pytest.fixture(scope="session") def django_db_modify_db_settings( - django_db_modify_db_settings_parallel_suffix: None, + django_db_modify_db_settings_parallel_suffix: None, ) -> None: skip_if_no_django() @@ -94,13 +106,13 @@ def django_db_createdb(request) -> bool: @pytest.fixture(scope="session") def django_db_setup( - request, - django_test_environment: None, - django_db_blocker, - django_db_use_migrations: bool, - django_db_keepdb: bool, - django_db_createdb: bool, - django_db_modify_db_settings: None, + request, + django_test_environment: None, + django_db_blocker, + django_db_use_migrations: bool, + django_db_keepdb: bool, + django_db_createdb: bool, + django_db_modify_db_settings: None, ) -> None: """Top level fixture to ensure test databases are available""" from django.test.utils import setup_databases, teardown_databases @@ -136,11 +148,12 @@ def teardown_database() -> None: def _django_db_fixture_helper( - request, - django_db_blocker, - transactional: bool = False, - reset_sequences: bool = False, + request, + django_db_blocker, + transactional: bool = False, + reset_sequences: bool = False, ) -> None: + if is_django_unittest(request): return @@ -155,13 +168,16 @@ def _django_db_fixture_helper( django_db_blocker.unblock() request.addfinalizer(django_db_blocker.restore) - import django.test - import django.db - if transactional: - test_case_class = django.test.TransactionTestCase + test_case_classname = request.config.getvalue("transaction_testcase_class") or os.getenv( + "DJANGO_TRANSACTION_TEST_CASE_CLASS" + ) or "django.test.TransactionTestCase" else: - test_case_class = django.test.TestCase + test_case_classname = request.config.getvalue("testcase_class") or os.getenv( + "DJANGO_TEST_CASE_CLASS" + ) or "django.test.TestCase" + + test_case_class = import_string(test_case_classname) _reset_sequences = reset_sequences @@ -223,9 +239,9 @@ def _set_suffix_to_test_databases(suffix: str) -> None: @pytest.fixture(scope="function") def db( - request, - django_db_setup: None, - django_db_blocker, + request, + django_db_setup: None, + django_db_blocker, ) -> None: """Require a django test database. @@ -243,8 +259,8 @@ def db( if "django_db_reset_sequences" in request.fixturenames: request.getfixturevalue("django_db_reset_sequences") if ( - "transactional_db" in request.fixturenames - or "live_server" in request.fixturenames + "transactional_db" in request.fixturenames + or "live_server" in request.fixturenames ): request.getfixturevalue("transactional_db") else: @@ -253,9 +269,9 @@ def db( @pytest.fixture(scope="function") def transactional_db( - request, - django_db_setup: None, - django_db_blocker, + request, + django_db_setup: None, + django_db_blocker, ) -> None: """Require a django test database with transaction support. @@ -276,9 +292,9 @@ def transactional_db( @pytest.fixture(scope="function") def django_db_reset_sequences( - request, - django_db_setup: None, - django_db_blocker, + request, + django_db_setup: None, + django_db_blocker, ) -> None: """Require a transactional test database with sequence reset support. @@ -332,9 +348,9 @@ def django_username_field(django_user_model) -> str: @pytest.fixture() def admin_user( - db: None, - django_user_model, - django_username_field: str, + db: None, + django_user_model, + django_username_field: str, ): """A Django admin user. @@ -363,8 +379,8 @@ def admin_user( @pytest.fixture() def admin_client( - db: None, - admin_user, + db: None, + admin_user, ) -> "django.test.client.Client": """A Django test client logged in as an admin user.""" from django.test.client import Client @@ -496,11 +512,11 @@ def _live_server_helper(request) -> None: @contextmanager def _assert_num_queries( - config, - num: int, - exact: bool = True, - connection=None, - info=None, + config, + num: int, + exact: bool = True, + connection=None, + info=None, ) -> Generator["django.test.utils.CaptureQueriesContext", None, None]: from django.test.utils import CaptureQueriesContext @@ -547,9 +563,9 @@ def django_assert_max_num_queries(pytestconfig): @contextmanager def _capture_on_commit_callbacks( - *, - using: Optional[str] = None, - execute: bool = False + *, + using: Optional[str] = None, + execute: bool = False ): from django.db import DEFAULT_DB_ALIAS, connections from django.test import TestCase diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 3e9dd9c6..75f2c12f 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -120,6 +120,17 @@ def pytest_addoption(parser) -> None: default=None, help="Address and port for the live_server fixture.", ) + group.addoption( + "--testcase-class", + default=None, + help="The base TestCase class to patch for use with django. Useful for hypothesis users", + ) + group.addoption( + "--transaction-testcase-class", + default=None, + help="The base TransactionTestCase class to patch for use with django. " + "Useful for hypothesis users", + ) parser.addini( SETTINGS_MODULE_ENV, "Django settings module to use by pytest-django." )