diff --git a/pyproject.toml b/pyproject.toml index e7ecca35de..0c6fb2ae8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ include = ["src/zenml", "*.txt", "*.sh", "*.md"] zenml = "zenml.cli.cli:cli" [tool.poetry.dependencies] -alembic = { version = "~1.8.1" } +alembic = { version = ">=1.8.1,<=1.15.2" } bcrypt = { version = "4.0.1" } click = "^8.0.1,<8.1.8" cloudpickle = ">=2.0.0,<3" diff --git a/src/zenml/zen_stores/migrations/alembic.py b/src/zenml/zen_stores/migrations/alembic.py index 65430cafc6..a46e6481dd 100644 --- a/src/zenml/zen_stores/migrations/alembic.py +++ b/src/zenml/zen_stores/migrations/alembic.py @@ -19,7 +19,15 @@ """ from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Union, +) from alembic.config import Config from alembic.runtime.environment import EnvironmentContext @@ -39,7 +47,11 @@ def include_object( - object: Any, name: str, type_: str, *args: Any, **kwargs: Any + object: Any, # Use Any for backward compatibility + name: Optional[str], + type_: str, # Use string instead of Literal for backward compatibility + reflected: bool = False, + compare_to: Optional[Any] = None, ) -> bool: """Function used to exclude tables from the migration scripts. @@ -47,8 +59,8 @@ def include_object( object: The schema item object to check. name: The name of the object to check. type_: The type of the object to check. - *args: Additional arguments. - **kwargs: Additional keyword arguments. + reflected: Whether this object is being reflected. + compare_to: The object being compared against. Returns: True if the object should be included, False otherwise. @@ -135,6 +147,7 @@ def run_migrations( fn_context_args["fn"] = fn with self.engine.connect() as connection: + # Configure the context with our metadata self.environment_context.configure( connection=connection, target_metadata=self.metadata, @@ -180,9 +193,15 @@ def current_revisions(self) -> List[str]: def do_get_current_rev(rev: _RevIdType, context: Any) -> List[Any]: nonlocal current_revisions - for r in self.script_directory.get_all_current( - rev # type:ignore [arg-type] - ): + # Handle rev parameter in a way that's compatible with different alembic versions + rev_input: Any + if isinstance(rev, str): + rev_input = rev + else: + rev_input = tuple(str(r) for r in rev) + + # Get current revision(s) + for r in self.script_directory.get_all_current(rev_input): if r is None: continue current_revisions.append(r.revision) @@ -200,7 +219,13 @@ def stamp(self, revision: str) -> None: """ def do_stamp(rev: _RevIdType, context: Any) -> List[Any]: - return self.script_directory._stamp_revs(revision, rev) + # Handle rev parameter in a way that's compatible with different alembic versions + if isinstance(rev, str): + return self.script_directory._stamp_revs(revision, rev) + else: + # Convert to tuple for compatibility + rev_tuple = tuple(str(r) for r in rev) + return self.script_directory._stamp_revs(revision, rev_tuple) self.run_migrations(do_stamp) @@ -212,10 +237,16 @@ def upgrade(self, revision: str = "heads") -> None: """ def do_upgrade(rev: _RevIdType, context: Any) -> List[Any]: - return self.script_directory._upgrade_revs( - revision, - rev, # type:ignore [arg-type] - ) + # Handle rev parameter in a way that's compatible with different alembic versions + if isinstance(rev, str): + return self.script_directory._upgrade_revs(revision, rev) + else: + if rev: + # Use first element or revs for compatibility + return self.script_directory._upgrade_revs( + revision, str(rev[0]) + ) + return [] self.run_migrations(do_upgrade) @@ -227,9 +258,14 @@ def downgrade(self, revision: str) -> None: """ def do_downgrade(rev: _RevIdType, context: Any) -> List[Any]: - return self.script_directory._downgrade_revs( - revision, - rev, # type:ignore [arg-type] - ) + # Handle rev parameter in a way that's compatible with different alembic versions + if isinstance(rev, str): + return self.script_directory._downgrade_revs(revision, rev) + else: + if rev: + return self.script_directory._downgrade_revs( + revision, str(rev[0]) + ) + return self.script_directory._downgrade_revs(revision, None) self.run_migrations(do_downgrade)