Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid "Expected enum but got int" serializer warnings for chats API #602

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backend/app/alembic/versions/2adc0b597dcd_int_enum_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""int_enum_type

Revision ID: 2adc0b597dcd
Revises: a54f966436ce
Create Date: 2025-01-24 17:58:08.339090

"""

from alembic import op
from sqlalchemy.dialects import mysql

from app.models.base import IntEnumType

# revision identifiers, used by Alembic.
revision = "2adc0b597dcd"
down_revision = "a54f966436ce"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chats",
"visibility",
existing_type=mysql.SMALLINT(),
type_=IntEnumType(),
existing_nullable=False,
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chats",
"visibility",
existing_type=IntEnumType(),
type_=mysql.SMALLINT(),
existing_nullable=False,
)
# ### end Alembic commands ###
35 changes: 33 additions & 2 deletions backend/app/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from uuid import UUID
from datetime import datetime
from typing import Optional

from sqlmodel import Field, DateTime, func, SQLModel
from sqlalchemy.types import TypeDecorator, LargeBinary
from sqlalchemy.types import TypeDecorator, LargeBinary, Integer

from app.utils.uuid6 import uuid7
from app.utils.aes import AESCipher
Expand Down Expand Up @@ -52,3 +51,35 @@ def process_result_value(self, value, dialect):
json_str = AESCipher(get_aes_key()).decrypt(value)
return json.loads(json_str)
return value


class IntEnumType(TypeDecorator):
"""
IntEnumType is a custom TypeDecorator that handles conversion between
integer values in the database and Enum types in Python.

This replaces the previous SmallInteger implementation to resolve Pydantic
serialization warnings. When using SmallInteger, SQLAlchemy would return raw
integers from the database (e.g., 0 or 1), causing Pydantic validation warnings
since it expects proper Enum types.
"""

impl = Integer

def __init__(self, enum_class, *args, **kwargs):
super().__init__(*args, **kwargs)
self.enum_class = enum_class

def process_bind_param(self, value, dialect):
# enum -> int
if isinstance(value, self.enum_class):
return value.value
elif value is None:
return None
raise ValueError(f"Invalid value for {self.enum_class}: {value}")

def process_result_value(self, value, dialect):
# int -> enum
if value is not None:
return self.enum_class(value)
return None
9 changes: 6 additions & 3 deletions backend/app/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
Column,
DateTime,
JSON,
SmallInteger,
Relationship as SQLRelationship,
)

from .base import UUIDBaseModel, UpdatableBaseModel
from .base import IntEnumType, UUIDBaseModel, UpdatableBaseModel


class ChatVisibility(int, enum.Enum):
Expand Down Expand Up @@ -43,7 +42,11 @@ class Chat(UUIDBaseModel, UpdatableBaseModel, table=True):
browser_id: str = Field(max_length=50, nullable=True)
origin: str = Field(max_length=256, default=None, nullable=True)
visibility: ChatVisibility = Field(
sa_column=Column(SmallInteger, default=ChatVisibility.PRIVATE, nullable=False)
sa_column=Column(
IntEnumType(ChatVisibility),
nullable=False,
default=ChatVisibility.PRIVATE,
)
)

__tablename__ = "chats"
Expand Down
Loading