diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 10742d80d5..8a7e6fd75d 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool: finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) +def set_polymorphic_default_value(self_instance, values): + """By defalut, when init a model, pydantic will set the polymorphic_on + value to field default value. But when inherit a model, the polymorphic_on + should be set to polymorphic_identity value by default.""" + cls = type(self_instance) + mapper = inspect(cls) + if isinstance(mapper, Mapper): + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is not None: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = get_model_fields(cls).get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) + + @contextmanager def partial_init() -> Generator[None, None, None]: token = finish_init.set(False) @@ -293,22 +316,7 @@ def sqlmodel_table_construct( setattr(self_instance, key, value) # End SQLModel override # Override polymorphic_on default value - mapper = inspect(cls) - if isinstance(mapper, Mapper): - polymorphic_on = mapper.polymorphic_on - if polymorphic_on is not None: - polymorphic_property = mapper.get_property_by_column(polymorphic_on) - field_info = cls.model_fields.get(polymorphic_property.key) - if field_info: - v = values.get(polymorphic_property.key) - # if model is inherited or polymorphic_on is not explicitly set - # set the polymorphic_on by default - if mapper.inherits or v is None: - setattr( - self_instance, - polymorphic_property.key, - mapper.polymorphic_identity, - ) + set_polymorphic_default_value(self_instance, values) return self_instance def sqlmodel_validate( @@ -592,3 +600,5 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: for key in non_pydantic_keys: if key in self.__sqlmodel_relationships__: setattr(self, key, data[key]) + # Override polymorphic_on default value + set_polymorphic_default_value(self, values) diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py index 8cded1bac5..f17e030a86 100644 --- a/tests/test_polymorphic_model.py +++ b/tests/test_polymorphic_model.py @@ -51,7 +51,7 @@ class DarkHero(Hero): @needs_pydanticv2 -def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None: +def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None: class Hero(SQLModel, table=True): __tablename__ = "hero" id: Optional[int] = Field(default=None, primary_key=True) @@ -123,7 +123,7 @@ class DarkHero(Hero): with Session(engine) as db: hero = Hero() db.add(hero) - dark_hero = DarkHero() + dark_hero = DarkHero(dark_power="pokey") db.add(dark_hero) db.commit() statement = select(DarkHero)