Skip to content

Commit 39d995b

Browse files
authored
Merge pull request #81 from mts-ai/fix-cascade-changes
Fix cascade changes
2 parents 08d1f7f + 715d10a commit 39d995b

File tree

5 files changed

+196
-3
lines changed

5 files changed

+196
-3
lines changed

Diff for: fastapi_jsonapi/data_layers/sqla_orm.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
44

55
from sqlalchemy import delete, func, select
6-
from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
6+
from sqlalchemy.exc import DBAPIError, IntegrityError, MissingGreenlet, NoResultFound
77
from sqlalchemy.ext.asyncio import AsyncSession, AsyncSessionTransaction
88
from sqlalchemy.inspection import inspect
99
from sqlalchemy.orm import joinedload, selectinload
@@ -185,6 +185,18 @@ async def apply_relationships(self, obj: TypeModel, data_create: BaseJSONAPIItem
185185
related_id_field=relationship_info.id_field_name,
186186
id_value=relationship_in.data.id,
187187
)
188+
189+
try:
190+
hasattr(obj, relation_name)
191+
except MissingGreenlet:
192+
raise InternalServerError(
193+
detail=(
194+
f"Error of loading the {relation_name!r} relationship. "
195+
f"Please add this relationship to include query parameter explicitly."
196+
),
197+
parameter="include",
198+
)
199+
188200
# todo: relation name may be different?
189201
setattr(obj, relation_name, related_data)
190202

@@ -386,8 +398,10 @@ async def delete_object(self, obj: TypeModel, view_kwargs: dict):
386398
:param view_kwargs: kwargs from the resource view.
387399
"""
388400
await self.before_delete_object(obj, view_kwargs)
401+
stmt = delete(self.model).where(self.model.id == obj.id)
402+
389403
try:
390-
await self.session.delete(obj)
404+
await self.session.execute(stmt)
391405
await self.save()
392406
except DBAPIError as e:
393407
await self.session.rollback()

Diff for: tests/common.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from os import getenv
22
from pathlib import Path
33

4+
from sqlalchemy import event
5+
from sqlalchemy.engine import Engine
6+
47

58
def sqla_uri():
69
testing_db_url = getenv("TESTING_DB_URL")
@@ -12,3 +15,18 @@ def sqla_uri():
1215

1316
def is_postgres_tests() -> bool:
1417
return "postgres" in sqla_uri()
18+
19+
20+
def is_sqlite_tests() -> bool:
21+
return "sqlite" in sqla_uri()
22+
23+
24+
@event.listens_for(Engine, "connect")
25+
def set_sqlite_pragma(dbapi_connection, connection_record):
26+
"""
27+
https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#foreign-key-support
28+
"""
29+
if is_sqlite_tests():
30+
cursor = dbapi_connection.cursor()
31+
cursor.execute("PRAGMA foreign_keys=ON")
32+
cursor.close()

Diff for: tests/models.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from sqlalchemy import JSON, Column, DateTime, ForeignKey, Index, Integer, String, Text
55
from sqlalchemy.ext.declarative import declarative_base
6-
from sqlalchemy.orm import declared_attr, relationship
6+
from sqlalchemy.orm import backref, declared_attr, relationship
77
from sqlalchemy.types import CHAR, TypeDecorator
88

99
from tests.common import is_postgres_tests, sqla_uri
@@ -390,3 +390,25 @@ class Delta(Base):
390390
name = Column(String)
391391
gammas: List["Gamma"] = relationship("Gamma", back_populates="delta", lazy="noload")
392392
betas: List["Beta"] = relationship("Beta", secondary="beta_delta_binding", back_populates="deltas", lazy="noload")
393+
394+
395+
class CascadeCase(Base):
396+
__tablename__ = "cascade_case"
397+
398+
id = Column(Integer, primary_key=True, autoincrement=True)
399+
parent_item_id = Column(
400+
Integer,
401+
ForeignKey(
402+
"cascade_case.id",
403+
onupdate="CASCADE",
404+
ondelete="CASCADE",
405+
),
406+
nullable=True,
407+
)
408+
sub_items = relationship(
409+
"CascadeCase",
410+
backref=backref("parent_item", remote_side=[id]),
411+
)
412+
413+
if TYPE_CHECKING:
414+
parent_item: Optional["CascadeCase"]

Diff for: tests/schemas.py

+14
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,20 @@ class SelfRelationshipSchema(BaseModel):
409409
)
410410

411411

412+
class CascadeCaseSchema(BaseModel):
413+
parent_item: Optional["CascadeCaseSchema"] = Field(
414+
relationship=RelationshipInfo(
415+
resource_type="cascade_case",
416+
),
417+
)
418+
sub_items: Optional[list["CascadeCaseSchema"]] = Field(
419+
relationship=RelationshipInfo(
420+
resource_type="cascade_case",
421+
many=True,
422+
),
423+
)
424+
425+
412426
class CustomUserAttributesSchema(UserBaseSchema):
413427
spam: str
414428
eggs: str

Diff for: tests/test_api/test_api_sqla_with_includes.py

+125
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
from collections import defaultdict
4+
from contextlib import suppress
45
from datetime import datetime, timezone
56
from itertools import chain, zip_longest
67
from json import dumps, loads
@@ -18,6 +19,7 @@
1819
from sqlalchemy.orm import InstrumentedAttribute
1920
from starlette.datastructures import QueryParams
2021

22+
from fastapi_jsonapi.api import RoutersJSONAPI
2123
from fastapi_jsonapi.views.view_base import ViewBase
2224
from tests.common import is_postgres_tests
2325
from tests.fixtures.app import build_alphabet_app, build_app_custom
@@ -31,6 +33,7 @@
3133
from tests.models import (
3234
Alpha,
3335
Beta,
36+
CascadeCase,
3437
Computer,
3538
ContainsTimestamp,
3639
CustomUUIDItem,
@@ -44,6 +47,7 @@
4447
Workplace,
4548
)
4649
from tests.schemas import (
50+
CascadeCaseSchema,
4751
CustomUserAttributesSchema,
4852
CustomUUIDItemAttributesSchema,
4953
PostAttributesBaseSchema,
@@ -1744,6 +1748,87 @@ async def test_select_custom_fields(
17441748
"meta": None,
17451749
}
17461750

1751+
@mark.parametrize("check_type", ["ok", "fail"])
1752+
async def test_update_to_many_relationships(self, async_session: AsyncSession, check_type: Literal["ok", "fail"]):
1753+
resource_type = "cascade_case"
1754+
with suppress(KeyError):
1755+
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)
1756+
1757+
app = build_app_custom(
1758+
model=CascadeCase,
1759+
schema=CascadeCaseSchema,
1760+
resource_type=resource_type,
1761+
)
1762+
1763+
top_item = CascadeCase()
1764+
new_top_item = CascadeCase()
1765+
sub_item_1 = CascadeCase(parent_item=top_item)
1766+
sub_item_2 = CascadeCase(parent_item=top_item)
1767+
async_session.add_all(
1768+
[
1769+
top_item,
1770+
new_top_item,
1771+
sub_item_1,
1772+
sub_item_2,
1773+
],
1774+
)
1775+
await async_session.commit()
1776+
1777+
assert sub_item_1.parent_item_id == top_item.id
1778+
assert sub_item_2.parent_item_id == top_item.id
1779+
1780+
async with AsyncClient(app=app, base_url="http://test") as client:
1781+
params = None
1782+
if check_type == "ok":
1783+
params = {"include": "sub_items"}
1784+
1785+
update_body = {
1786+
"type": resource_type,
1787+
"data": {
1788+
"id": new_top_item.id,
1789+
"attributes": {},
1790+
"relationships": {
1791+
"sub_items": {
1792+
"data": [
1793+
{
1794+
"type": resource_type,
1795+
"id": sub_item_1.id,
1796+
},
1797+
{
1798+
"type": resource_type,
1799+
"id": sub_item_2.id,
1800+
},
1801+
],
1802+
},
1803+
},
1804+
},
1805+
}
1806+
url = app.url_path_for(f"update_{resource_type}_detail", obj_id=new_top_item.id)
1807+
1808+
res = await client.patch(url, params=params, json=update_body)
1809+
1810+
if check_type == "ok":
1811+
assert res.status_code == status.HTTP_200_OK, res.text
1812+
1813+
await async_session.refresh(sub_item_1)
1814+
await async_session.refresh(sub_item_2)
1815+
await async_session.refresh(top_item)
1816+
assert sub_item_1.parent_item_id == new_top_item.id
1817+
assert sub_item_1.parent_item_id == new_top_item.id
1818+
else:
1819+
assert res.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR, res.text
1820+
assert res.json() == {
1821+
"errors": [
1822+
{
1823+
"detail": "Error of loading the 'sub_items' relationship. "
1824+
"Please add this relationship to include query parameter explicitly.",
1825+
"source": {"parameter": "include"},
1826+
"status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
1827+
"title": "Internal Server Error",
1828+
},
1829+
],
1830+
}
1831+
17471832

17481833
class TestPatchObjectRelationshipsToOne:
17491834
async def test_ok_when_foreign_key_of_related_object_is_nullable(
@@ -2247,6 +2332,46 @@ async def test_select_custom_fields(
22472332
"meta": {"count": 2, "totalPages": 1},
22482333
}
22492334

2335+
async def test_cascade_delete(self, async_session: AsyncSession):
2336+
resource_type = "cascade_case"
2337+
with suppress(KeyError):
2338+
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)
2339+
2340+
app = build_app_custom(
2341+
model=CascadeCase,
2342+
schema=CascadeCaseSchema,
2343+
resource_type=resource_type,
2344+
)
2345+
2346+
top_item = CascadeCase()
2347+
sub_item_1 = CascadeCase(parent_item=top_item)
2348+
sub_item_2 = CascadeCase(parent_item=top_item)
2349+
async_session.add_all(
2350+
[
2351+
top_item,
2352+
sub_item_1,
2353+
sub_item_2,
2354+
],
2355+
)
2356+
await async_session.commit()
2357+
2358+
assert sub_item_1.parent_item_id == top_item.id
2359+
assert sub_item_2.parent_item_id == top_item.id
2360+
2361+
async with AsyncClient(app=app, base_url="http://test") as client:
2362+
url = app.url_path_for(f"delete_{resource_type}_detail", obj_id=top_item.id)
2363+
2364+
res = await client.delete(url)
2365+
assert res.status_code == status.HTTP_204_NO_CONTENT, res.text
2366+
2367+
top_item_stmt = select(CascadeCase).where(CascadeCase.id == top_item.id)
2368+
top_item = (await async_session.execute(top_item_stmt)).one_or_none()
2369+
assert top_item is None
2370+
2371+
sub_items_stmt = select(CascadeCase).where(CascadeCase.id.in_([sub_item_1.id, sub_item_2.id]))
2372+
sub_items = (await async_session.execute(sub_items_stmt)).all()
2373+
assert sub_items == []
2374+
22502375

22512376
class TestOpenApi:
22522377
def test_openapi_method_ok(self, app: FastAPI):

0 commit comments

Comments
 (0)