Skip to content

Commit 8d77bd3

Browse files
authored
fix(mysql): add dialect parameter instead of hardcoded mysql dialect (#739)
closes #727 * add parameter `dialect`; * tests fixing and add some assertions
1 parent 3436cbf commit 8d77bd3

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

modules/mysql/testcontainers/mysql/__init__.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ class MySqlContainer(DbContainer):
3131
The example will spin up a MySql database to which you can connect with the credentials
3232
passed in the constructor. Alternatively, you may use the :code:`get_connection_url()`
3333
method which returns a sqlalchemy-compatible url in format
34-
:code:`dialect+driver://username:password@host:port/database`.
34+
:code:`mysql+dialect://username:password@host:port/database`.
3535
3636
.. doctest::
3737
3838
>>> import sqlalchemy
3939
>>> from testcontainers.mysql import MySqlContainer
4040
41-
>>> with MySqlContainer('mysql:5.7.17') as mysql:
41+
>>> with MySqlContainer("mysql:5.7.17", dialect="pymysql") as mysql:
4242
... engine = sqlalchemy.create_engine(mysql.get_connection_url())
4343
... with engine.begin() as connection:
4444
... result = connection.execute(sqlalchemy.text("select version()"))
@@ -64,6 +64,7 @@ class MySqlContainer(DbContainer):
6464
def __init__(
6565
self,
6666
image: str = "mysql:latest",
67+
dialect: Optional[str] = None,
6768
username: Optional[str] = None,
6869
root_password: Optional[str] = None,
6970
password: Optional[str] = None,
@@ -72,6 +73,10 @@ def __init__(
7273
seed: Optional[str] = None,
7374
**kwargs,
7475
) -> None:
76+
if dialect is not None and dialect.startswith("mysql+"):
77+
msg = "Please remove 'mysql+' prefix from dialect parameter"
78+
raise ValueError(msg)
79+
7580
raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username")
7681
raise_for_deprecated_parameter(kwargs, "MYSQL_ROOT_PASSWORD", "root_password")
7782
raise_for_deprecated_parameter(kwargs, "MYSQL_PASSWORD", "password")
@@ -85,6 +90,9 @@ def __init__(
8590
self.password = password or environ.get("MYSQL_PASSWORD", "test")
8691
self.dbname = dbname or environ.get("MYSQL_DATABASE", "test")
8792

93+
self.dialect = dialect or environ.get("MYSQL_DIALECT", None)
94+
self._db_url_dialect_part = "mysql" if self.dialect is None else f"mysql+{self.dialect}"
95+
8896
if self.username == "root":
8997
self.root_password = self.password
9098
self.seed = seed
@@ -105,7 +113,11 @@ def _connect(self) -> None:
105113

106114
def get_connection_url(self) -> str:
107115
return super()._create_connection_url(
108-
dialect="mysql+pymysql", username=self.username, password=self.password, dbname=self.dbname, port=self.port
116+
dialect=self._db_url_dialect_part,
117+
username=self.username,
118+
password=self.password,
119+
dbname=self.dbname,
120+
port=self.port,
109121
)
110122

111123
def _transfer_seed(self) -> None:

modules/mysql/tests/test_mysql.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111

1212
@pytest.mark.inside_docker_check
1313
def test_docker_run_mysql():
14-
config = MySqlContainer("mysql:8.3.0")
14+
config = MySqlContainer("mysql:8.3.0", dialect="pymysql")
1515
with config as mysql:
16-
engine = sqlalchemy.create_engine(mysql.get_connection_url())
16+
connection_url = mysql.get_connection_url()
17+
18+
assert mysql.dialect == "pymysql"
19+
assert connection_url.startswith("mysql+pymysql://")
20+
21+
engine = sqlalchemy.create_engine(connection_url)
1722
with engine.begin() as connection:
1823
result = connection.execute(sqlalchemy.text("select version()"))
1924
for row in result:
@@ -22,7 +27,7 @@ def test_docker_run_mysql():
2227

2328
@pytest.mark.skipif(is_arm(), reason="mysql container not available for ARM")
2429
def test_docker_run_legacy_mysql():
25-
config = MySqlContainer("mysql:5.7.44")
30+
config = MySqlContainer("mysql:5.7.44", dialect="pymysql")
2631
with config as mysql:
2732
engine = sqlalchemy.create_engine(mysql.get_connection_url())
2833
with engine.begin() as connection:
@@ -35,7 +40,7 @@ def test_docker_run_legacy_mysql():
3540
def test_docker_run_mysql_8_seed():
3641
# Avoid pytest CWD path issues
3742
SEEDS_PATH = (Path(__file__).parent / "seeds").absolute()
38-
config = MySqlContainer("mysql:8", seed=SEEDS_PATH)
43+
config = MySqlContainer("mysql:8", dialect="pymysql", seed=str(SEEDS_PATH))
3944
with config as mysql:
4045
engine = sqlalchemy.create_engine(mysql.get_connection_url())
4146
with engine.begin() as connection:
@@ -45,7 +50,7 @@ def test_docker_run_mysql_8_seed():
4550

4651
@pytest.mark.parametrize("version", ["11.3.2", "10.11.7"])
4752
def test_docker_run_mariadb(version: str):
48-
with MySqlContainer(f"mariadb:{version}") as mariadb:
53+
with MySqlContainer(f"mariadb:{version}", dialect="pymysql") as mariadb:
4954
engine = sqlalchemy.create_engine(mariadb.get_connection_url())
5055
with engine.begin() as connection:
5156
result = connection.execute(sqlalchemy.text("select version()"))
@@ -55,14 +60,29 @@ def test_docker_run_mariadb(version: str):
5560

5661
def test_docker_env_variables():
5762
with (
58-
mock.patch.dict("os.environ", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"),
63+
mock.patch.dict("os.environ", MYSQL_DIALECT="pymysql", MYSQL_USER="demo", MYSQL_DATABASE="custom_db"),
5964
MySqlContainer("mariadb:10.6.5").with_bind_ports(3306, 32785) as container,
6065
):
6166
url = container.get_connection_url()
6267
pattern = r"mysql\+pymysql:\/\/demo:test@[\w,.]+:(3306|32785)\/custom_db"
6368
assert re.match(pattern, url)
6469

6570

71+
@pytest.mark.parametrize(
72+
"dialect",
73+
[
74+
"mysql+pymysql",
75+
"mysql+mariadb",
76+
"mysql+mysqldb",
77+
],
78+
)
79+
def test_mysql_dialect_expecting_error_on_mysql_prefix(dialect: str):
80+
match = f"Please remove *.* prefix from dialect parameter"
81+
82+
with pytest.raises(ValueError, match=match):
83+
_ = MySqlContainer("mariadb:10.6.5", dialect=dialect)
84+
85+
6686
# This is a feature in the generic DbContainer class
6787
# but it can't be tested on its own
6888
# so is tested in various database modules:
@@ -75,18 +95,18 @@ def test_quoted_password():
7595
user = "root"
7696
password = "p@$%25+0&%rd :/!=?"
7797
quoted_password = "p%40%24%2525+0%26%25rd %3A%2F%21%3D%3F"
78-
driver = "pymysql"
79-
with MySqlContainer("mariadb:10.6.5", username=user, password=password) as container:
98+
dialect = "pymysql"
99+
with MySqlContainer("mariadb:10.6.5", dialect=dialect, username=user, password=password) as container:
80100
host = container.get_container_host_ip()
81101
port = container.get_exposed_port(3306)
82-
expected_url = f"mysql+{driver}://{user}:{quoted_password}@{host}:{port}/test"
102+
expected_url = f"mysql+{dialect}://{user}:{quoted_password}@{host}:{port}/test"
83103
url = container.get_connection_url()
84104
assert url == expected_url
85105

86106
with sqlalchemy.create_engine(expected_url).begin() as connection:
87107
connection.execute(sqlalchemy.text("select version()"))
88108

89-
raw_pass_url = f"mysql+{driver}://{user}:{password}@{host}:{port}/test"
109+
raw_pass_url = f"mysql+{dialect}://{user}:{password}@{host}:{port}/test"
90110
with pytest.raises(Exception):
91111
with sqlalchemy.create_engine(raw_pass_url).begin() as connection:
92112
connection.execute(sqlalchemy.text("select version()"))

0 commit comments

Comments
 (0)