Skip to content

Commit f6dc339

Browse files
committed
Add config option to load relationship fields.
Due to the recursive loading problem loading relations is not yet possible. This change introduces the config option 'include_relations' to also load specifically chosen relations. For the current use case nothing changes unless the user specifically sets fields to be included (and carefully considers the risks of circular includes). The option is very valuable if the table design contains many 1:n relations.
1 parent 75ce455 commit f6dc339

File tree

2 files changed

+289
-4
lines changed

2 files changed

+289
-4
lines changed

sqlmodel/main.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,16 @@ def _calculate_keys(
615615
if include is None and exclude is None and not exclude_unset:
616616
# Original in Pydantic:
617617
# return None
618-
# Updated to not return SQLAlchemy attributes
619-
# Do not include relationships as that would easily lead to infinite
620-
# recursion, or traversing the whole database
621-
return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
618+
# updated to only return SQLAlchemy attributes
619+
# if include_relations is set in the Config for a model
620+
# Otherwise do not include relationships as that would easily lead
621+
# to infinite recursion, or traversing the whole database
622+
model_keys = set(self.__fields__.keys())
623+
include_relations = getattr(self.Config(), "include_relations", {})
624+
for relation_key in self.__sqlmodel_relationships__.keys():
625+
if relation_key in include_relations:
626+
model_keys.add(relation_key)
627+
return model_keys
622628

623629
keys: AbstractSet[str]
624630
if exclude_unset:

tests/test_relation_resolution.py

+279
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
from typing import List, Optional
2+
3+
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
4+
5+
6+
def test_relation_resolution_if_include_relations_not_set(clear_sqlmodel):
7+
class Team(SQLModel, table=True):
8+
id: Optional[int] = Field(default=None, primary_key=True)
9+
name: str
10+
heroes: List["Hero"] = Relationship(back_populates="team") # noqa: F821
11+
12+
class Config:
13+
orm_mode = True
14+
15+
class Hero(SQLModel, table=True):
16+
id: Optional[int] = Field(default=None, primary_key=True)
17+
name: str
18+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
19+
team: Optional[Team] = Relationship(back_populates="heroes")
20+
21+
hero_1 = Hero(name="Deadpond")
22+
hero_2 = Hero(name="PhD Strange")
23+
team = Team(name="Marble", heroes=[hero_1, hero_2])
24+
25+
engine = create_engine("sqlite://")
26+
27+
SQLModel.metadata.create_all(engine)
28+
29+
with Session(engine) as session:
30+
session.add(team)
31+
session.commit()
32+
session.refresh(team)
33+
keys = team._calculate_keys(include=None, exclude=None, exclude_unset=False)
34+
35+
# expected not to include the relationship "heroes" since this
36+
# fields since the relationship field was not enabled in
37+
# Config.include_relations
38+
assert keys == {"id", "name"}
39+
40+
41+
def test_relation_resolution_if_include_relations_is_set(clear_sqlmodel):
42+
class Team(SQLModel, table=True):
43+
id: Optional[int] = Field(default=None, primary_key=True)
44+
name: str
45+
heroes: List["Hero"] = Relationship(back_populates="team") # noqa: F821
46+
47+
class Config:
48+
orm_mode = True
49+
include_relations = {"heroes"}
50+
51+
class Hero(SQLModel, table=True):
52+
id: Optional[int] = Field(default=None, primary_key=True)
53+
name: str
54+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
55+
team: Optional[Team] = Relationship(back_populates="heroes")
56+
57+
hero_1 = Hero(name="Deadpond")
58+
hero_2 = Hero(name="PhD Strange")
59+
team = Team(name="Marble", heroes=[hero_1, hero_2])
60+
61+
engine = create_engine("sqlite://")
62+
63+
SQLModel.metadata.create_all(engine)
64+
65+
with Session(engine) as session:
66+
session.add(team)
67+
session.commit()
68+
session.refresh(team)
69+
keys = team._calculate_keys(include=None, exclude=None, exclude_unset=False)
70+
71+
# expected to include the relationship "heroes" since this
72+
# fields was enabled in Config.include_relations
73+
assert keys == {"id", "name", "heroes"}
74+
75+
76+
def test_relation_resolution_if_include_relations_is_set_for_nested(clear_sqlmodel):
77+
class Team(SQLModel, table=True):
78+
id: Optional[int] = Field(default=None, primary_key=True)
79+
name: str
80+
heroes: List["Hero"] = Relationship(back_populates="team") # noqa: F821
81+
82+
class Config:
83+
orm_mode = True
84+
include_relations = {"heroes"}
85+
86+
class Hero(SQLModel, table=True):
87+
id: Optional[int] = Field(default=None, primary_key=True)
88+
name: str
89+
powers: List["Power"] = Relationship(back_populates="hero") # noqa: F821
90+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
91+
team: Optional[Team] = Relationship(back_populates="heroes")
92+
93+
class Config:
94+
orm_mode = True
95+
include_relations = {"powers"}
96+
97+
class Power(SQLModel, table=True):
98+
id: Optional[int] = Field(default=None, primary_key=True)
99+
description: str
100+
hero_id: Optional[int] = Field(default=None, foreign_key="hero.id")
101+
hero: Optional[Hero] = Relationship(back_populates="powers")
102+
103+
power_hero_1 = Power(description="Healing Power")
104+
power_hero_2 = Power(description="Levitating Cloak")
105+
hero_1 = Hero(name="Deadpond", powers=[power_hero_1])
106+
hero_2 = Hero(name="PhD Strange", powers=[power_hero_2])
107+
team = Team(name="Marble", heroes=[hero_1, hero_2])
108+
109+
engine = create_engine("sqlite://")
110+
111+
SQLModel.metadata.create_all(engine)
112+
113+
with Session(engine) as session:
114+
session.add(team)
115+
session.commit()
116+
session.refresh(team)
117+
session.refresh(hero_1)
118+
team_keys = team._calculate_keys(include=None, exclude=None, exclude_unset=False)
119+
hero_1_keys = hero_1._calculate_keys(
120+
include=None, exclude=None, exclude_unset=False
121+
)
122+
123+
assert team_keys == {"id", "name", "heroes"}
124+
assert hero_1_keys == {"id", "name", "powers", "team_id"}
125+
126+
127+
def test_relation_resolution_if_lazy_selectin_not_set_with_fastapi(clear_sqlmodel):
128+
class Team(SQLModel, table=True):
129+
id: Optional[int] = Field(default=None, primary_key=True)
130+
name: str
131+
heroes: List["Hero"] = Relationship(back_populates="team") # noqa: F821
132+
133+
class Config:
134+
orm_mode = True
135+
include_relations = {"heroes"}
136+
137+
class Hero(SQLModel, table=True):
138+
id: Optional[int] = Field(default=None, primary_key=True)
139+
name: str
140+
powers: List["Power"] = Relationship(back_populates="hero") # noqa: F821
141+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
142+
team: Optional[Team] = Relationship(back_populates="heroes")
143+
144+
class Config:
145+
orm_mode = True
146+
include_relations = {"powers"}
147+
148+
class Power(SQLModel, table=True):
149+
id: Optional[int] = Field(default=None, primary_key=True)
150+
description: str
151+
hero_id: Optional[int] = Field(default=None, foreign_key="hero.id")
152+
hero: Optional[Hero] = Relationship(back_populates="powers")
153+
154+
power_hero_1 = Power(description="Healing Power")
155+
power_hero_2 = Power(description="Levitating Cloak")
156+
hero_1 = Hero(name="Deadpond", powers=[power_hero_1])
157+
hero_2 = Hero(name="PhD Strange", powers=[power_hero_2])
158+
team = Team(name="Marble", heroes=[hero_1, hero_2])
159+
160+
engine = create_engine("sqlite://")
161+
162+
SQLModel.metadata.create_all(engine)
163+
164+
with Session(engine) as session:
165+
session.add(team)
166+
session.commit()
167+
session.refresh(team)
168+
169+
from fastapi import FastAPI
170+
from fastapi.testclient import TestClient
171+
172+
app = FastAPI()
173+
174+
@app.get("/")
175+
async def read_main(response_model=List[Team]):
176+
with Session(engine) as session:
177+
teams = session.execute(select(Team)).all()
178+
return teams
179+
180+
client = TestClient(app)
181+
teams = client.get("/")
182+
expected_json = [{"Team": {"name": "Marble", "id": 1}}]
183+
184+
# if sa_relationship_kwargs={"lazy": "selectin"}) not set in relation
185+
# there is no effect on the relations even though the Config was set
186+
# to load the relation fields.
187+
assert teams.json() == expected_json
188+
189+
190+
def test_relation_resolution_if_lazy_selectin_is_set_with_fastapi(clear_sqlmodel):
191+
class Team(SQLModel, table=True):
192+
id: Optional[int] = Field(default=None, primary_key=True)
193+
name: str
194+
heroes: List["Hero"] = Relationship( # noqa: F821
195+
back_populates="team", sa_relationship_kwargs={"lazy": "selectin"}
196+
)
197+
198+
class Config:
199+
orm_mode = True
200+
include_relations = {"heroes"}
201+
202+
class Hero(SQLModel, table=True):
203+
id: Optional[int] = Field(default=None, primary_key=True)
204+
name: str
205+
powers: List["Power"] = Relationship( # noqa: F821
206+
back_populates="hero", sa_relationship_kwargs={"lazy": "selectin"}
207+
)
208+
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
209+
team: Optional[Team] = Relationship(back_populates="heroes")
210+
211+
class Config:
212+
orm_mode = True
213+
include_relations = {"powers"}
214+
215+
class Power(SQLModel, table=True):
216+
id: Optional[int] = Field(default=None, primary_key=True)
217+
description: str
218+
hero_id: Optional[int] = Field(default=None, foreign_key="hero.id")
219+
hero: Optional[Hero] = Relationship(back_populates="powers")
220+
221+
power_hero_1 = Power(description="Healing Power")
222+
power_hero_2 = Power(description="Levitating Cloak")
223+
hero_1 = Hero(name="Deadpond", powers=[power_hero_1])
224+
hero_2 = Hero(name="PhD Strange", powers=[power_hero_2])
225+
team = Team(name="Marble", heroes=[hero_1, hero_2])
226+
227+
engine = create_engine("sqlite://")
228+
229+
SQLModel.metadata.create_all(engine)
230+
231+
with Session(engine) as session:
232+
session.add(team)
233+
session.commit()
234+
session.refresh(team)
235+
236+
from fastapi import FastAPI
237+
from fastapi.testclient import TestClient
238+
239+
app = FastAPI()
240+
241+
@app.get("/")
242+
async def read_main(response_model=List[Team]):
243+
with Session(engine) as session:
244+
teams = session.execute(select(Team)).all()
245+
return teams
246+
247+
client = TestClient(app)
248+
teams = client.get("/")
249+
expected_json = [
250+
{
251+
"Team": {
252+
"name": "Marble",
253+
"id": 1,
254+
"heroes": [
255+
{
256+
"id": 1,
257+
"team_id": 1,
258+
"name": "Deadpond",
259+
"powers": [
260+
{"id": 1, "hero_id": 1, "description": "Healing Power"}
261+
],
262+
},
263+
{
264+
"id": 2,
265+
"team_id": 1,
266+
"name": "PhD Strange",
267+
"powers": [
268+
{"id": 2, "hero_id": 2, "description": "Levitating Cloak"}
269+
],
270+
},
271+
],
272+
}
273+
}
274+
]
275+
276+
# if sa_relationship_kwargs={"lazy": "selectin"}) is set
277+
# the relations in the Config are considered and the relation fields are
278+
# included in the response.
279+
assert teams.json() == expected_json

0 commit comments

Comments
 (0)