diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index 591a4a8a8..1410321ee 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -29,6 +29,8 @@ class DbContainer(DockerContainer): """ **DEPRECATED (for removal)** + Please use database-specific container classes or `SqlContainer` instead. + # from testcontainers.generic.sql import SqlContainer Generic database container. """ diff --git a/modules/generic/README.rst b/modules/generic/README.rst index 4497ec922..4b7281121 100644 --- a/modules/generic/README.rst +++ b/modules/generic/README.rst @@ -9,6 +9,7 @@ FastAPI container that is using :code:`ServerContainer` >>> from testcontainers.generic import ServerContainer >>> from testcontainers.core.waiting_utils import wait_for_logs + >>> from testcontainers.core.image import DockerImage >>> with DockerImage(path="./modules/generic/tests/samples/fastapi", tag="fastapi-test:latest") as image: ... with ServerContainer(port=80, image=image) as fastapi_server: @@ -50,3 +51,39 @@ A more advance use-case, where we are using a FastAPI container that is using Re ... response = client.get(f"/get/{test_data['key']}") ... assert response.status_code == 200, "Failed to get data" ... assert response.json() == {"key": test_data["key"], "value": test_data["value"]} + +.. autoclass:: testcontainers.generic.SqlContainer +.. title:: testcontainers.generic.SqlContainer + +Postgres container that is using :code:`SqlContainer` + +.. doctest:: + + >>> from testcontainers.generic import SqlContainer + >>> from testcontainers.generic.providers.sql_connection_wait_strategy import SqlAlchemyConnectWaitStrategy + >>> from sqlalchemy import text + >>> import sqlalchemy + + >>> class CustomPostgresContainer(SqlContainer): + ... def __init__(self, image="postgres:15-alpine", + ... port=5432, username="test", password="test", dbname="test"): + ... super().__init__(image=image, wait_strategy=SqlAlchemyConnectWaitStrategy()) + ... self.port_to_expose = port + ... self.username = username + ... self.password = password + ... self.dbname = dbname + ... def get_connection_url(self) -> str: + ... host = self.get_container_host_ip() + ... port = self.get_exposed_port(self.port_to_expose) + ... return f"postgresql://{self.username}:{self.password}@{host}:{port}/{self.dbname}" + ... def _configure(self) -> None: + ... self.with_exposed_ports(self.port_to_expose) + ... self.with_env("POSTGRES_USER", self.username) + ... self.with_env("POSTGRES_PASSWORD", self.password) + ... self.with_env("POSTGRES_DB", self.dbname) + + >>> with CustomPostgresContainer() as postgres: + ... engine = sqlalchemy.create_engine(postgres.get_connection_url()) + ... with engine.connect() as conn: + ... result = conn.execute(text("SELECT 1")) + ... assert result.scalar() == 1 diff --git a/modules/generic/testcontainers/generic/__init__.py b/modules/generic/testcontainers/generic/__init__.py index f239a80c6..ce6610a3c 100644 --- a/modules/generic/testcontainers/generic/__init__.py +++ b/modules/generic/testcontainers/generic/__init__.py @@ -1 +1,2 @@ from .server import ServerContainer # noqa: F401 +from .sql import SqlContainer # noqa: F401 diff --git a/modules/generic/testcontainers/generic/providers/__init__.py b/modules/generic/testcontainers/generic/providers/__init__.py new file mode 100644 index 000000000..5b5eb95a2 --- /dev/null +++ b/modules/generic/testcontainers/generic/providers/__init__.py @@ -0,0 +1 @@ +from .sql_connection_wait_strategy import SqlAlchemyConnectWaitStrategy # noqa: F401 diff --git a/modules/generic/testcontainers/generic/providers/sql_connection_wait_strategy.py b/modules/generic/testcontainers/generic/providers/sql_connection_wait_strategy.py new file mode 100644 index 000000000..bad46c743 --- /dev/null +++ b/modules/generic/testcontainers/generic/providers/sql_connection_wait_strategy.py @@ -0,0 +1,48 @@ +# This module provides a wait strategy for SQL database connectivity testing using SQLAlchemy. +# It includes handling for transient exceptions and connection retries. + +import logging + +from testcontainers.core.waiting_utils import WaitStrategy, WaitStrategyTarget + +logger = logging.getLogger(__name__) + +ADDITIONAL_TRANSIENT_ERRORS = [] +try: + from sqlalchemy.exc import DBAPIError + + ADDITIONAL_TRANSIENT_ERRORS.append(DBAPIError) +except ImportError: + logger.debug("SQLAlchemy not available, skipping DBAPIError handling") + + +class SqlAlchemyConnectWaitStrategy(WaitStrategy): + """Wait strategy for database connectivity testing using SQLAlchemy.""" + + def __init__(self): + super().__init__() + self.with_transient_exceptions(TimeoutError, ConnectionError, *ADDITIONAL_TRANSIENT_ERRORS) + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + """Test database connectivity with retry logic until success or timeout.""" + if not hasattr(container, "get_connection_url"): + raise AttributeError(f"Container {container} must have a get_connection_url method") + + try: + import sqlalchemy + except ImportError as e: + raise ImportError("SQLAlchemy is required for database containers") from e + + def _test_connection() -> bool: + """Test database connection, returning True if successful.""" + engine = sqlalchemy.create_engine(container.get_connection_url()) + try: + with engine.connect(): + logger.info("Database connection successful") + return True + finally: + engine.dispose() + + result = self._poll(_test_connection) + if not result: + raise TimeoutError(f"Database connection failed after {self._startup_timeout}s timeout") diff --git a/modules/generic/testcontainers/generic/server.py b/modules/generic/testcontainers/generic/server.py index 61e9c5eb9..fe990f179 100644 --- a/modules/generic/testcontainers/generic/server.py +++ b/modules/generic/testcontainers/generic/server.py @@ -9,8 +9,6 @@ from testcontainers.core.image import DockerImage from testcontainers.core.waiting_utils import wait_container_is_ready -# This comment can be removed (Used for testing) - class ServerContainer(DockerContainer): """ diff --git a/modules/generic/testcontainers/generic/sql.py b/modules/generic/testcontainers/generic/sql.py new file mode 100644 index 000000000..c7ed755ed --- /dev/null +++ b/modules/generic/testcontainers/generic/sql.py @@ -0,0 +1,139 @@ +import logging +from typing import Any, Optional +from urllib.parse import quote, urlencode + +from testcontainers.core.container import DockerContainer +from testcontainers.core.exceptions import ContainerStartException +from testcontainers.core.waiting_utils import WaitStrategy + +logger = logging.getLogger(__name__) + + +class SqlContainer(DockerContainer): + """ + Generic SQL database container providing common functionality. + + This class can serve as a base for database-specific container implementations. + It provides connection management, URL construction, and basic lifecycle methods. + Database connection readiness is automatically handled by the provided wait strategy. + + Note: `SqlAlchemyConnectWaitStrategy` from `sql_connection_wait_strategy` is a provided wait strategy for SQL databases. + """ + + def __init__(self, image: str, wait_strategy: WaitStrategy, **kwargs): + """ + Initialize SqlContainer with optional wait strategy. + + Args: + image: Docker image name + wait_strategy: Wait strategy for SQL database connectivity + **kwargs: Additional arguments passed to DockerContainer + """ + super().__init__(image, **kwargs) + self.wait_strategy = wait_strategy + + def _create_connection_url( + self, + dialect: str, + username: str, + password: str, + host: Optional[str] = None, + port: Optional[int] = None, + dbname: Optional[str] = None, + query_params: Optional[dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Create a database connection URL. + + Args: + dialect: Database dialect (e.g., 'postgresql', 'mysql') + username: Database username + password: Database password + host: Database host (defaults to container host) + port: Database port + dbname: Database name + query_params: Additional query parameters for the URL + **kwargs: Additional parameters (checked for deprecated usage) + + Returns: + str: Formatted database connection URL + + Raises: + ValueError: If unexpected arguments are provided or required parameters are missing + ContainerStartException: If container is not started + """ + + if self._container is None: + raise ContainerStartException("Container has not been started") + + host = host or self.get_container_host_ip() + exposed_port = self.get_exposed_port(port) + quoted_password = quote(password, safe="") + quoted_username = quote(username, safe="") + url = f"{dialect}://{quoted_username}:{quoted_password}@{host}:{exposed_port}" + + if dbname: + quoted_dbname = quote(dbname, safe="") + url = f"{url}/{quoted_dbname}" + + if query_params: + query_string = urlencode(query_params) + url = f"{url}?{query_string}" + + return url + + def start(self) -> "SqlContainer": + """ + Start the database container and perform initialization. + + Returns: + SqlContainer: Self for method chaining + + Raises: + ContainerStartException: If container fails to start + Exception: If configuration, seed transfer, or connection fails + """ + logger.info(f"Starting database container: {self.image}") + + try: + self._configure() + self.waiting_for(self.wait_strategy) + super().start() + self._transfer_seed() + logger.info("Database container started successfully") + except Exception as e: + logger.error(f"Failed to start database container: {e}") + raise + + return self + + def _configure(self) -> None: + """ + Configure the database container before starting. + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError("Subclasses must implement _configure()") + + def _transfer_seed(self) -> None: + """ + Transfer seed data to the database container. + + This method can be overridden by subclasses to provide + database-specific seeding functionality. + """ + logger.debug("No seed data to transfer") + + def get_connection_url(self) -> str: + """ + Get the database connection URL. + + Returns: + str: Database connection URL + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError("Subclasses must implement get_connection_url()") diff --git a/modules/generic/tests/test_generic.py b/modules/generic/tests/test_server.py similarity index 100% rename from modules/generic/tests/test_generic.py rename to modules/generic/tests/test_server.py diff --git a/modules/generic/tests/test_sql.py b/modules/generic/tests/test_sql.py new file mode 100644 index 000000000..69fff2427 --- /dev/null +++ b/modules/generic/tests/test_sql.py @@ -0,0 +1,238 @@ +import pytest +from unittest.mock import patch + +from testcontainers.core.exceptions import ContainerStartException +from testcontainers.generic.sql import SqlContainer +from testcontainers.generic.providers.sql_connection_wait_strategy import SqlAlchemyConnectWaitStrategy + + +class SimpleSqlContainer(SqlContainer): + """Simple concrete implementation for testing.""" + + def __init__(self, image: str = "postgres:13"): + super().__init__(image, wait_strategy=SqlAlchemyConnectWaitStrategy()) + self.username = "testuser" + self.password = "testpass" + self.dbname = "testdb" + self.port = 5432 + + def get_connection_url(self) -> str: + return self._create_connection_url( + dialect="postgresql", username=self.username, password=self.password, port=self.port, dbname=self.dbname + ) + + def _configure(self) -> None: + self.with_env("POSTGRES_USER", self.username) + self.with_env("POSTGRES_PASSWORD", self.password) + self.with_env("POSTGRES_DB", self.dbname) + self.with_exposed_ports(self.port) + + +class TestSqlContainer: + def test_abstract_methods_raise_not_implemented(self): + container = SqlContainer("test:latest", SqlAlchemyConnectWaitStrategy()) + + with pytest.raises(NotImplementedError): + container.get_connection_url() + + with pytest.raises(NotImplementedError): + container._configure() + + def test_transfer_seed_default_behavior(self): + container = SqlContainer("test:latest", SqlAlchemyConnectWaitStrategy()) + # Should not raise an exception + container._transfer_seed() + + def test_connection_url_creation_basic(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() # Simple mock + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: port + + url = container._create_connection_url(dialect="postgresql", username="user", password="pass", port=5432) + + assert url == "postgresql://user:pass@localhost:5432" + + def test_connection_url_with_database_name(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: port + + url = container._create_connection_url( + dialect="postgresql", username="user", password="pass", port=5432, dbname="mydb" + ) + + assert url == "postgresql://user:pass@localhost:5432/mydb" + + def test_connection_url_with_special_characters(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: port + + url = container._create_connection_url( + dialect="postgresql", username="user@domain", password="p@ss/word", port=5432 + ) + + # Check that special characters are URL encoded + assert "user%40domain" in url + assert "p%40ss%2Fword" in url + + def test_connection_url_with_query_params(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: port + + url = container._create_connection_url( + dialect="postgresql", + username="user", + password="pass", + port=5432, + query_params={"ssl": "require", "timeout": "30"}, + ) + + assert "?" in url + assert "ssl=require" in url + assert "timeout=30" in url + + def test_connection_url_type_errors(self): + """Test that _create_connection_url raises TypeError with invalid types""" + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {"id": "test-id"})() + + # Mock get_exposed_port to simulate what happens with None port + with patch.object(container, "get_exposed_port") as mock_get_port: + # Simulate the TypeError that would occur when int(None) is called + mock_get_port.side_effect = TypeError( + "int() argument must be a string, a bytes-like object or a real number, not 'NoneType'" + ) + + with pytest.raises(TypeError, match="int\\(\\) argument must be a string"): + container._create_connection_url("postgresql", "user", "pass", port=None) + + def test_connection_url_container_not_started(self): + container = SimpleSqlContainer() + container._container = None + + with pytest.raises(ContainerStartException, match="Container has not been started"): + container._create_connection_url("postgresql", "user", "pass", port=5432) + + def test_container_configuration(self): + container = SimpleSqlContainer("postgres:13") + + # Test that configuration sets up environment + container._configure() + + assert container.env["POSTGRES_USER"] == "testuser" + assert container.env["POSTGRES_PASSWORD"] == "testpass" + assert container.env["POSTGRES_DB"] == "testdb" + + def test_concrete_container_connection_url(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: 5432 + + url = container.get_connection_url() + + assert url.startswith("postgresql://") + assert "testuser" in url + assert "testpass" in url + assert "testdb" in url + assert "localhost:5432" in url + + def test_container_inheritance(self): + container = SimpleSqlContainer() + + assert isinstance(container, SqlContainer) + assert hasattr(container, "get_connection_url") + assert hasattr(container, "_configure") + assert hasattr(container, "_transfer_seed") + assert hasattr(container, "start") + + def test_additional_transient_errors_list(self): + from testcontainers.generic.providers.sql_connection_wait_strategy import ADDITIONAL_TRANSIENT_ERRORS + + assert isinstance(ADDITIONAL_TRANSIENT_ERRORS, list) + # List may be empty if SQLAlchemy not available, or contain DBAPIError if it is + + def test_empty_password_handling(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: port + + url = container._create_connection_url(dialect="postgresql", username="user", password="", port=5432) + + assert url == "postgresql://user:@localhost:5432" + + def test_unicode_characters_in_credentials(self): + container = SimpleSqlContainer() + container._container = type("MockContainer", (), {})() + container.get_container_host_ip = lambda: "localhost" + container.get_exposed_port = lambda port: port + + url = container._create_connection_url( + dialect="postgresql", username="usér", password="päss", port=5432, dbname="tëstdb" + ) + + assert "us%C3%A9r" in url + assert "p%C3%A4ss" in url + assert "t%C3%ABstdb" in url + + def test_start_postgres_container_integration(self): + """Integration test that actually starts a PostgreSQL container.""" + container = SimpleSqlContainer() + + # This will start the container and test the connection + container.start() + + # Verify the container is running + assert container._container is not None + + # Test that we can get a connection URL + url = container.get_connection_url() + assert url.startswith("postgresql://") + assert "testuser" in url + assert "testdb" in url + + # Verify environment variables are set + assert container.env["POSTGRES_USER"] == "testuser" + assert container.env["POSTGRES_PASSWORD"] == "testpass" + assert container.env["POSTGRES_DB"] == "testdb" + + # check logs + logs = container.get_logs() + assert "database system is ready to accept connections" in logs[0].decode("utf-8").lower() + + def test_sql_postgres_container_integration(self): + """Integration test for SqlContainer with PostgreSQL.""" + container = SimpleSqlContainer() + + # This will start the container and test the connection + container.start() + + # Verify the container is running + assert container._container is not None + + # Test that we can get a connection URL + url = container.get_connection_url() + + # check sql operations + import sqlalchemy + + engine = sqlalchemy.create_engine(url) + with engine.connect() as conn: + # Create a test table + conn.execute( + sqlalchemy.text("CREATE TABLE IF NOT EXISTS test_table (id SERIAL PRIMARY KEY, name VARCHAR(50));") + ) + # Insert a test record + conn.execute(sqlalchemy.text("INSERT INTO test_table (name) VALUES ('test_name');")) + # Query the test record + result = conn.execute(sqlalchemy.text("SELECT name FROM test_table WHERE name='test_name';")) + fetched = result.fetchone() + assert fetched is not None + assert fetched[0] == "test_name" diff --git a/poetry.lock b/poetry.lock index 7411ad744..5d60ab8b0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6797,7 +6797,7 @@ cockroachdb = [] cosmosdb = ["azure-cosmos"] db2 = ["ibm_db_sa", "sqlalchemy"] elasticsearch = [] -generic = ["httpx", "redis"] +generic = ["httpx", "redis", "sqlalchemy"] google = ["google-cloud-datastore", "google-cloud-pubsub"] influxdb = ["influxdb", "influxdb-client"] k3s = ["kubernetes", "pyyaml"] @@ -6836,4 +6836,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.1" python-versions = ">=3.9.2,<4.0" -content-hash = "241e8b6ba610907adea4496fdeaef4c3fdc3315d222ab87004692aa9371698fa" +content-hash = "c5a635d3fa182f964fe96d5685efdfb4bd5fcbab829388a5d5a5be90fb81eaee" diff --git a/pyproject.toml b/pyproject.toml index 59f2f5a85..c5eb5b4c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,6 +136,7 @@ elasticsearch = [] generic = [ "httpx", "redis", + "sqlalchemy", ] # The advance doctests for ServerContainer require redis test_module_import = ["httpx"] google = ["google-cloud-pubsub", "google-cloud-datastore"] @@ -230,6 +231,9 @@ exclude_lines = [ "pass", "raise NotImplementedError", # TODO: used in core/generic.py, not sure we need DbContainer ] +omit = [ + "core/testcontainers/core/generic.py", # Marked for deprecation +] [tool.ruff] target-version = "py39"