Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/database/deletion/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,20 @@ def create_deletion_trigger_many_to_many(
f"""
NOT EXISTS (
SELECT 1 FROM {link_name}
WHERE {link_name}.{link_to_identifier} = {delete_name}.{to_delete_identifier}
WHERE {link_name}.{link_to_identifier} = OLD.{link_to_identifier}
)
""" # noqa: S608 # never user input
for link_name in link_names
)
return DDL(
f"""
CREATE TRIGGER IF NOT EXISTS delete_{link_name}
AFTER DELETE ON {trigger_name}
CREATE TRIGGER IF NOT EXISTS delete_orphan_{link_name}
AFTER DELETE ON {link_name}
FOR EACH ROW
BEGIN
DELETE FROM {link_name}
WHERE {link_name}.{link_from_identifier} = OLD.{trigger_identifier};
DELETE FROM {delete_name}
WHERE {links_clause};
WHERE {delete_name}.{to_delete_identifier} = OLD.{link_to_identifier}
AND {links_clause};
END;
""" # noqa: S608 # never user input
)
1 change: 1 addition & 0 deletions src/database/model/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def many_to_many_link_factory(
onupdate="CASCADE",
),
primary_key=True,
index=True,
)
),
),
Expand Down
45 changes: 39 additions & 6 deletions src/database/model/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starlette.status import HTTP_404_NOT_FOUND

from authentication import KeycloakUser
from database.model.helper_functions import get_relationships, get_asset_by_identifier
from database.model.helper_functions import get_relationships
from database.model.named_relation import NamedRelation, Taxonomy
from database.model.ai_resource.resource_table import AIResourceORM
from database.session import DbSession
Expand Down Expand Up @@ -259,12 +259,45 @@ def create_getter_dict(attribute_serializers: Dict[str, Serializer]):
def is_soft_deleted(item) -> bool:
if hasattr(item, "date_deleted") and item.date_deleted:
return True
if isinstance(item, AIResourceORM):
with DbSession() as session:
clazz_, resource = get_asset_by_identifier(item.identifier, session)
return resource.date_deleted is not None
if isinstance(item, AIResourceORM) or item.__class__.__name__ in (
"AIAssetTable",
"AgentTable",
"KnowledgeAssetTable",
):
from sqlalchemy.orm import object_session
from database.model.helper_functions import get_asset_type_by_abbreviation

session = object_session(item)

return False # Not sure what cases are not covered here.
def check_deletion(sess):
asset_type_map = get_asset_type_by_abbreviation()
model_class = next(
(cls for cls in asset_type_map.values() if cls.__tablename__ == item.type),
None,
)
if model_class:
match item.__class__.__name__:
case "AIAssetTable":
id_field_name = "ai_asset_id"
case "AgentTable":
id_field_name = "agent_id"
case "KnowledgeAssetTable":
id_field_name = "knowledge_asset_id"
case _:
id_field_name = "ai_resource_id"

resource = sess.query(model_class).filter(
getattr(model_class, id_field_name) == item.identifier
).first()
return resource is None or resource.date_deleted is not None
return False

if session is not None:
return check_deletion(session)
else:
with DbSession() as new_session:
return check_deletion(new_session)
return False

class GetterDictSerializer(GetterDict):
def get(self, key: Any, default: Any = None) -> Any:
Expand Down
29 changes: 28 additions & 1 deletion src/routers/resource_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from typing import Annotated, Any, Literal, Sequence, Type, TypeVar, Union, Callable, cast
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
from sqlalchemy import and_, func
from sqlalchemy import and_, func, or_, delete
from sqlalchemy.sql.operators import is_
from sqlmodel import SQLModel, Session, select

Expand All @@ -23,6 +23,11 @@
from database.model.concept.concept import AIoDConcept
from database.model.platform.platform import Platform
from database.model.platform.platform_names import PlatformName
from database.model.ai_resource.resource_table import (
AIResourceORM,
AIResourcePartLink,
AIResourceRelevantLink,
)
from database.model.serializers import deserialize_resource_relationships
from database.review import Submission, SubmissionCreateV2, AssetReview
from database.session import DbSession
Expand Down Expand Up @@ -630,6 +635,28 @@ def delete_resource(
session.delete(resource)
else:
resource.date_deleted = datetime.datetime.utcnow()
if isinstance(resource, AIResource) and resource.ai_resource_id:
# Sever all relationships immediately on soft delete
session.execute(
delete(AIResourcePartLink).where(
or_(
AIResourcePartLink.parent_identifier
== resource.ai_resource_id,
AIResourcePartLink.child_identifier
== resource.ai_resource_id,
)
)
)
session.execute(
delete(AIResourceRelevantLink).where(
or_(
AIResourceRelevantLink.parent_identifier
== resource.ai_resource_id,
AIResourceRelevantLink.relevant_identifier
== resource.ai_resource_id,
)
)
)
session.add(resource)
session.commit()
return None
Expand Down
26 changes: 26 additions & 0 deletions verify_triggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys
from pathlib import Path

# Add src to path
sys.path.append(str(Path.cwd() / "src"))

from database.deletion.triggers import create_deletion_trigger_many_to_many
from database.model.dataset.dataset import Dataset
from database.model.ai_resource.keyword import Keyword
from database.model.helper_functions import many_to_many_link_factory

# Mock many-to-many relationship details
link_model = many_to_many_link_factory("dataset", "keyword", from_identifier_type=str)
other_links = ["model_keyword_link", "publication_keyword_link"]

ddl = create_deletion_trigger_many_to_many(
trigger=Dataset,
link=link_model,
to_delete=Keyword,
other_links=other_links
)

print("-" * 20)
print("Generated DDL:")
print(ddl.statement)
print("-" * 20)