diff --git a/app/models.py b/app/models.py index 0f0a9e4..8aa8cde 100644 --- a/app/models.py +++ b/app/models.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Union, List, Any +from typing import Union, Any, List from typing_extensions import Self from sqlmodel import SQLModel, Field, Relationship @@ -162,6 +162,7 @@ class PriceHistoryPublic(PriceHistoryBase): product_id: str +# Other models (unchanged) class ItemStats(BaseModel): calories: float | None proteins: float | None @@ -175,6 +176,19 @@ class ItemStats(BaseModel): kcal_per_euro: float | None +class TicketItem(BaseModel): + product: ProductPublic + original_name: str + quantity: int + unit_price: float + total_price: float + stats: ItemStats | None + + +class TicketStats(BaseModel): + items: List[TicketItem] + + class ProductMatch(BaseModel): score: float product: ProductPublic @@ -186,108 +200,42 @@ def default_quantity(v: Any) -> int: return v -class TicketBase(SQLModel): +class TicketInfo(BaseModel): ticket_number: int | None = None date: str | None = None time: str | None = None total_price: float | None = None - # Convert datetime to ISO format string for JSON serialization - processed_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - - -class TicketItemBase(SQLModel): - name: str - quantity: int - total_price: float - unit_price: float - - -class Ticket(TicketBase, table=True): - id: int = Field(default=None, primary_key=True) - items: List["TicketItem"] = Relationship(back_populates="ticket") + items: list[TicketItem] @model_validator(mode="after") - def calculate_total_price(self) -> Self: - if self.total_price is None: - self.total_price = sum(item.total_price for item in self.items) - return self - - -class TicketItem(TicketItemBase, table=True): - id: int = Field(default=None, primary_key=True) - ticket_id: int = Field(foreign_key="ticket.id") - ticket: Ticket = Relationship(back_populates="items") - matched_product_id: str | None = Field(default=None, foreign_key="product.id") - matched_product: "Product" = Relationship() - - @model_validator(mode="before") - def calculate_unit_price(cls, values: dict) -> dict: - if "quantity" in values and "total_price" in values: - values["unit_price"] = ( - values["total_price"] / values["quantity"] - if values["quantity"] > 0 - else 0.0 - ) - return values - - -class ExtractedTicketItem(TicketItemBase): - @model_validator(mode="before") - def calculate_unit_price(cls, values: dict) -> dict: - if "quantity" in values and "total_price" in values: - values["unit_price"] = ( - values["total_price"] / values["quantity"] - if values["quantity"] > 0 - else 0.0 - ) - return values - - -class ExtractedTicketInfo(TicketBase): - items: List[ExtractedTicketItem] - - @model_validator(mode="after") - def calculate_total_price(self) -> Self: + def guess_total(self): if self.total_price is None: self.total_price = sum( item.total_price for item in self.items if item.total_price is not None ) return self - def to_db_models(self) -> tuple[Ticket, List[TicketItem]]: - """Convert the extracted ticket info to database models""" - ticket = Ticket( - ticket_number=self.ticket_number, - date=self.date, - time=self.time, - total_price=self.total_price, - processed_at=datetime.utcnow().isoformat(), # Store as ISO format string - ) - - ticket_items = [ - TicketItem( - name=item.name, - quantity=item.quantity, - total_price=item.total_price, - unit_price=item.unit_price, - ) - for item in self.items - ] - - return ticket, ticket_items +class ExtractedTicketItem(BaseModel): + name: str + quantity: int + total_price: float + unit_price: float -# Public models -class TicketItemPublic(TicketItemBase): - id: int - ticket_id: int - matched_product: ProductPublic | None = None - + @model_validator(mode="before") + def calculate_total_price(cls, values: Any) -> Any: + values["unit_price"] = ( + values["total_price"] / values["quantity"] + if values["quantity"] > 0 + else 0.0 + ) -class TicketPublic(TicketBase): - id: int - items: List[TicketItemPublic] + return values -class TicketStats(BaseModel): - items: List[TicketItem] +class ExtractedTicketInfo(BaseModel): + ticket_number: int | None + date: str | None + time: str | None + total_price: float | None + items: List[ExtractedTicketItem] diff --git a/app/routers/ticket.py b/app/routers/ticket.py index c38a7d4..b5eb385 100644 --- a/app/routers/ticket.py +++ b/app/routers/ticket.py @@ -115,15 +115,6 @@ async def process_ticket( temp_file_path, TICKET_PROMPT ) - # Save ticket information to database - ticket, ticket_items = ticket_info.to_db_models() - session.add(ticket) - session.flush() # Flush to get the ticket ID - - for item in ticket_items: - item.ticket_id = ticket.id - session.add(item) - # Create a group of tasks for product matching product_tasks = group( [ @@ -139,7 +130,9 @@ async def process_ticket( # Wait for all tasks to complete with timeout with allow_join_result(): try: - results = group_result.get(timeout=20) + results = group_result.get( + timeout=20 + ) # 20 second timeout for entire group except Exception as e: logger.error(f"Error waiting for product matching tasks: {str(e)}") raise HTTPException( @@ -147,38 +140,35 @@ async def process_ticket( ) # Process results and create ticket items - ticket_items_response = [] - for ti, result, db_item in zip(ticket_info.items, results, ticket_items): + ticket_items = [] + for item, result in zip(ticket_info.items, results): if not result: - logger.warning(f"No match found for product '{ti.name}'") + logger.warning(f"No match found for product '{item.name}'") continue product_match = ProductMatch.model_validate(result[0]) product = product_match.product logger.info( - f"Best match for '{ti.name}': {product.name} (Score: {product_match.score:.2f})" + f"Best match for '{item.name}': {product.name} (Score: {product_match.score:.2f})" ) - # Update the database item with the matched product - db_item.matched_product_id = product.id - - item_stats = calculate_item_stats(product, ti.quantity, ti.total_price or 0) + item_stats = calculate_item_stats( + product, item.quantity, item.total_price or 0 + ) ticket_item = TicketItem( product=ProductPublic.model_validate(product), - original_name=ti.name, - quantity=ti.quantity, - unit_price=ti.unit_price or 0, - total_price=ti.total_price or 0, + original_name=item.name, + quantity=item.quantity, + unit_price=item.unit_price or 0, + total_price=item.total_price or 0, stats=item_stats, ) - ticket_items_response.append(ticket_item) + ticket_items.append(ticket_item) - session.commit() - return TicketStats(items=ticket_items_response) + return TicketStats(items=ticket_items) except Exception as e: - session.rollback() logger.error(f"Error processing ticket: {str(e)}") raise HTTPException( status_code=500, detail=f"Error processing ticket: {str(e)}" diff --git a/migrations/versions/783f148d39eb_create_extracted_ticket_info_table.py b/migrations/versions/783f148d39eb_create_extracted_ticket_info_table.py deleted file mode 100644 index 7c24c8b..0000000 --- a/migrations/versions/783f148d39eb_create_extracted_ticket_info_table.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Create extracted ticket info table - -Revision ID: 783f148d39eb -Revises: feff058e9105 -Create Date: 2024-10-22 19:54:52.355993 - -""" - -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -import sqlmodel - - -# revision identifiers, used by Alembic. -revision: str = "783f148d39eb" -down_revision: Union[str, None] = "feff058e9105" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "ticket", - sa.Column("ticket_number", sa.Integer(), nullable=True), - sa.Column("date", sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column("time", sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column("total_price", sa.Float(), nullable=True), - sa.Column("processed_at", sa.DateTime(), nullable=False), - sa.Column("id", sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "ticketitem", - sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column("quantity", sa.Integer(), nullable=False), - sa.Column("total_price", sa.Float(), nullable=False), - sa.Column("unit_price", sa.Float(), nullable=False), - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("ticket_id", sa.Integer(), nullable=False), - sa.Column( - "matched_product_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True - ), - sa.ForeignKeyConstraint( - ["matched_product_id"], - ["product.id"], - ), - sa.ForeignKeyConstraint( - ["ticket_id"], - ["ticket.id"], - ), - sa.PrimaryKeyConstraint("id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("ticketitem") - op.drop_table("ticket") - # ### end Alembic commands ###