Skip to content

Commit

Permalink
improve code structure
Browse files Browse the repository at this point in the history
  • Loading branch information
John Lyu committed Dec 3, 2024
1 parent a3044bb commit 015601c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
42 changes: 26 additions & 16 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool:
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)


def set_polymorphic_default_value(self_instance, values):
"""By defalut, when init a model, pydantic will set the polymorphic_on
value to field default value. But when inherit a model, the polymorphic_on
should be set to polymorphic_identity value by default."""
cls = type(self_instance)
mapper = inspect(cls)
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = get_model_fields(cls).get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)


@contextmanager
def partial_init() -> Generator[None, None, None]:
token = finish_init.set(False)
Expand Down Expand Up @@ -293,22 +316,7 @@ def sqlmodel_table_construct(
setattr(self_instance, key, value)
# End SQLModel override
# Override polymorphic_on default value
mapper = inspect(cls)
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = cls.model_fields.get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
set_polymorphic_default_value(self_instance, values)
return self_instance

def sqlmodel_validate(
Expand Down Expand Up @@ -592,3 +600,5 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
for key in non_pydantic_keys:
if key in self.__sqlmodel_relationships__:
setattr(self, key, data[key])
# Override polymorphic_on default value
set_polymorphic_default_value(self, values)
4 changes: 2 additions & 2 deletions tests/test_polymorphic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DarkHero(Hero):


@needs_pydanticv2
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None:
def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
Expand Down Expand Up @@ -123,7 +123,7 @@ class DarkHero(Hero):
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
dark_hero = DarkHero(dark_power="pokey")
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
Expand Down

0 comments on commit 015601c

Please sign in to comment.