diff --git a/src/database/deletion/triggers.py b/src/database/deletion/triggers.py index 70c3910f1..b330f3f96 100644 --- a/src/database/deletion/triggers.py +++ b/src/database/deletion/triggers.py @@ -191,21 +191,20 @@ def create_deletion_trigger_many_to_many( f""" NOT EXISTS ( SELECT 1 FROM {link_name} - WHERE {link_name}.{link_to_identifier} = {delete_name}.{to_delete_identifier} + WHERE {link_name}.{link_to_identifier} = OLD.{link_to_identifier} ) """ # noqa: S608 # never user input for link_name in link_names ) return DDL( f""" - CREATE TRIGGER IF NOT EXISTS delete_{link_name} - AFTER DELETE ON {trigger_name} + CREATE TRIGGER IF NOT EXISTS delete_orphan_{link_name} + AFTER DELETE ON {link_name} FOR EACH ROW BEGIN - DELETE FROM {link_name} - WHERE {link_name}.{link_from_identifier} = OLD.{trigger_identifier}; DELETE FROM {delete_name} - WHERE {links_clause}; + WHERE {delete_name}.{to_delete_identifier} = OLD.{link_to_identifier} + AND {links_clause}; END; """ # noqa: S608 # never user input ) diff --git a/src/database/model/helper_functions.py b/src/database/model/helper_functions.py index ad877a445..7562419c5 100644 --- a/src/database/model/helper_functions.py +++ b/src/database/model/helper_functions.py @@ -58,6 +58,7 @@ def many_to_many_link_factory( onupdate="CASCADE", ), primary_key=True, + index=True, ) ), ), diff --git a/src/database/model/serializers.py b/src/database/model/serializers.py index 55569a3ed..882d5623a 100644 --- a/src/database/model/serializers.py +++ b/src/database/model/serializers.py @@ -8,7 +8,7 @@ from starlette.status import HTTP_404_NOT_FOUND from authentication import KeycloakUser -from database.model.helper_functions import get_relationships, get_asset_by_identifier +from database.model.helper_functions import get_relationships from database.model.named_relation import NamedRelation, Taxonomy from database.model.ai_resource.resource_table import AIResourceORM from database.session import DbSession @@ -259,12 +259,45 @@ def create_getter_dict(attribute_serializers: Dict[str, Serializer]): def is_soft_deleted(item) -> bool: if hasattr(item, "date_deleted") and item.date_deleted: return True - if isinstance(item, AIResourceORM): - with DbSession() as session: - clazz_, resource = get_asset_by_identifier(item.identifier, session) - return resource.date_deleted is not None + if isinstance(item, AIResourceORM) or item.__class__.__name__ in ( + "AIAssetTable", + "AgentTable", + "KnowledgeAssetTable", + ): + from sqlalchemy.orm import object_session + from database.model.helper_functions import get_asset_type_by_abbreviation + + session = object_session(item) - return False # Not sure what cases are not covered here. + def check_deletion(sess): + asset_type_map = get_asset_type_by_abbreviation() + model_class = next( + (cls for cls in asset_type_map.values() if cls.__tablename__ == item.type), + None, + ) + if model_class: + match item.__class__.__name__: + case "AIAssetTable": + id_field_name = "ai_asset_id" + case "AgentTable": + id_field_name = "agent_id" + case "KnowledgeAssetTable": + id_field_name = "knowledge_asset_id" + case _: + id_field_name = "ai_resource_id" + + resource = sess.query(model_class).filter( + getattr(model_class, id_field_name) == item.identifier + ).first() + return resource is None or resource.date_deleted is not None + return False + + if session is not None: + return check_deletion(session) + else: + with DbSession() as new_session: + return check_deletion(new_session) + return False class GetterDictSerializer(GetterDict): def get(self, key: Any, default: Any = None) -> Any: diff --git a/src/routers/resource_router.py b/src/routers/resource_router.py index b6ec00a25..27ebb8f09 100644 --- a/src/routers/resource_router.py +++ b/src/routers/resource_router.py @@ -4,7 +4,7 @@ from functools import partial from typing import Annotated, Any, Literal, Sequence, Type, TypeVar, Union, Callable, cast from fastapi import APIRouter, Depends, HTTPException, status, Query, Path -from sqlalchemy import and_, func +from sqlalchemy import and_, func, or_, delete from sqlalchemy.sql.operators import is_ from sqlmodel import SQLModel, Session, select @@ -23,6 +23,11 @@ from database.model.concept.concept import AIoDConcept from database.model.platform.platform import Platform from database.model.platform.platform_names import PlatformName +from database.model.ai_resource.resource_table import ( + AIResourceORM, + AIResourcePartLink, + AIResourceRelevantLink, +) from database.model.serializers import deserialize_resource_relationships from database.review import Submission, SubmissionCreateV2, AssetReview from database.session import DbSession @@ -630,6 +635,28 @@ def delete_resource( session.delete(resource) else: resource.date_deleted = datetime.datetime.utcnow() + if isinstance(resource, AIResource) and resource.ai_resource_id: + # Sever all relationships immediately on soft delete + session.execute( + delete(AIResourcePartLink).where( + or_( + AIResourcePartLink.parent_identifier + == resource.ai_resource_id, + AIResourcePartLink.child_identifier + == resource.ai_resource_id, + ) + ) + ) + session.execute( + delete(AIResourceRelevantLink).where( + or_( + AIResourceRelevantLink.parent_identifier + == resource.ai_resource_id, + AIResourceRelevantLink.relevant_identifier + == resource.ai_resource_id, + ) + ) + ) session.add(resource) session.commit() return None diff --git a/verify_triggers.py b/verify_triggers.py new file mode 100644 index 000000000..ab74f534c --- /dev/null +++ b/verify_triggers.py @@ -0,0 +1,26 @@ +import sys +from pathlib import Path + +# Add src to path +sys.path.append(str(Path.cwd() / "src")) + +from database.deletion.triggers import create_deletion_trigger_many_to_many +from database.model.dataset.dataset import Dataset +from database.model.ai_resource.keyword import Keyword +from database.model.helper_functions import many_to_many_link_factory + +# Mock many-to-many relationship details +link_model = many_to_many_link_factory("dataset", "keyword", from_identifier_type=str) +other_links = ["model_keyword_link", "publication_keyword_link"] + +ddl = create_deletion_trigger_many_to_many( + trigger=Dataset, + link=link_model, + to_delete=Keyword, + other_links=other_links +) + +print("-" * 20) +print("Generated DDL:") +print(ddl.statement) +print("-" * 20)