Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions lib/galaxy/model/migrations/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,43 @@ def log_check_not_passed(self) -> None:
self._log_object_does_not_exist_message(name)


class CreateForeignKey(DDLAlterOperation):
"""Wraps alembic's create_foreign_key directive."""

def __init__(
self,
foreign_key_name: str,
table_name: str,
referent_table: str,
local_cols: List[str],
remote_cols: List[str],
**kw: Any,
) -> None:
super().__init__(table_name)
self.foreign_key_name = foreign_key_name
self.referent_table = referent_table
self.local_cols = local_cols
self.remote_cols = remote_cols
self.kw = kw

def batch_execute(self, batch_op) -> None:
batch_op.create_foreign_key(
self.foreign_key_name, self.referent_table, self.local_cols, self.remote_cols, **self.kw
)

def non_batch_execute(self) -> None:
op.create_foreign_key(
self.foreign_key_name, self.table_name, self.referent_table, self.local_cols, self.remote_cols, **self.kw
)

def pre_execute_check(self) -> bool:
return not foreign_key_exists(self.foreign_key_name, self.table_name, False)

def log_check_not_passed(self) -> None:
name = _table_object_description(self.foreign_key_name, self.table_name)
self._log_object_exists_message(name)


class CreateUniqueConstraint(DDLAlterOperation):
"""Wraps alembic's create_unique_constraint directive."""

Expand Down Expand Up @@ -294,6 +331,17 @@ def drop_index(index_name, table_name) -> None:
DropIndex(index_name, table_name).run()


def create_foreign_key(
foreign_key_name: str,
table_name: str,
referent_table: str,
local_cols: List[str],
remote_cols: List[str],
**kw: Any,
) -> None:
CreateForeignKey(foreign_key_name, table_name, referent_table, local_cols, remote_cols, **kw).run()


def create_unique_constraint(constraint_name: str, table_name: str, columns: List[str]) -> None:
CreateUniqueConstraint(constraint_name, table_name, columns).run()

Expand Down Expand Up @@ -328,6 +376,15 @@ def index_exists(index_name: str, table_name: str, default: bool) -> bool:
return any(index["name"] == index_name for index in indexes)


def foreign_key_exists(constraint_name: str, table_name: str, default: bool) -> bool:
"""Check if unique constraint exists. If running in offline mode, return default."""
if context.is_offline_mode():
_log_offline_mode_message(foreign_key_exists.__name__, default)
return default
constraints = _inspector().get_foreign_keys(table_name)
return any(c["name"] == constraint_name for c in constraints)


def unique_constraint_exists(constraint_name: str, table_name: str, default: bool) -> bool:
"""Check if unique constraint exists. If running in offline mode, return default."""
if context.is_offline_mode():
Expand Down