Skip to content

Commit

Permalink
Revert "feat: track extracted ticket info"
Browse files Browse the repository at this point in the history
This reverts commit ef83ce0.
  • Loading branch information
m0wer committed Oct 22, 2024
1 parent ef83ce0 commit 8a53474
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 178 deletions.
126 changes: 37 additions & 89 deletions app/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Union, List, Any
from typing import Union, Any, List
from typing_extensions import Self

from sqlmodel import SQLModel, Field, Relationship
Expand Down Expand Up @@ -162,6 +162,7 @@ class PriceHistoryPublic(PriceHistoryBase):
product_id: str


# Other models (unchanged)
class ItemStats(BaseModel):
calories: float | None
proteins: float | None
Expand All @@ -175,6 +176,19 @@ class ItemStats(BaseModel):
kcal_per_euro: float | None


class TicketItem(BaseModel):
product: ProductPublic
original_name: str
quantity: int
unit_price: float
total_price: float
stats: ItemStats | None


class TicketStats(BaseModel):
items: List[TicketItem]


class ProductMatch(BaseModel):
score: float
product: ProductPublic
Expand All @@ -186,108 +200,42 @@ def default_quantity(v: Any) -> int:
return v


class TicketBase(SQLModel):
class TicketInfo(BaseModel):
ticket_number: int | None = None
date: str | None = None
time: str | None = None
total_price: float | None = None
# Convert datetime to ISO format string for JSON serialization
processed_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())


class TicketItemBase(SQLModel):
name: str
quantity: int
total_price: float
unit_price: float


class Ticket(TicketBase, table=True):
id: int = Field(default=None, primary_key=True)
items: List["TicketItem"] = Relationship(back_populates="ticket")
items: list[TicketItem]

@model_validator(mode="after")
def calculate_total_price(self) -> Self:
if self.total_price is None:
self.total_price = sum(item.total_price for item in self.items)
return self


class TicketItem(TicketItemBase, table=True):
id: int = Field(default=None, primary_key=True)
ticket_id: int = Field(foreign_key="ticket.id")
ticket: Ticket = Relationship(back_populates="items")
matched_product_id: str | None = Field(default=None, foreign_key="product.id")
matched_product: "Product" = Relationship()

@model_validator(mode="before")
def calculate_unit_price(cls, values: dict) -> dict:
if "quantity" in values and "total_price" in values:
values["unit_price"] = (
values["total_price"] / values["quantity"]
if values["quantity"] > 0
else 0.0
)
return values


class ExtractedTicketItem(TicketItemBase):
@model_validator(mode="before")
def calculate_unit_price(cls, values: dict) -> dict:
if "quantity" in values and "total_price" in values:
values["unit_price"] = (
values["total_price"] / values["quantity"]
if values["quantity"] > 0
else 0.0
)
return values


class ExtractedTicketInfo(TicketBase):
items: List[ExtractedTicketItem]

@model_validator(mode="after")
def calculate_total_price(self) -> Self:
def guess_total(self):
if self.total_price is None:
self.total_price = sum(
item.total_price for item in self.items if item.total_price is not None
)
return self

def to_db_models(self) -> tuple[Ticket, List[TicketItem]]:
"""Convert the extracted ticket info to database models"""
ticket = Ticket(
ticket_number=self.ticket_number,
date=self.date,
time=self.time,
total_price=self.total_price,
processed_at=datetime.utcnow().isoformat(), # Store as ISO format string
)

ticket_items = [
TicketItem(
name=item.name,
quantity=item.quantity,
total_price=item.total_price,
unit_price=item.unit_price,
)
for item in self.items
]

return ticket, ticket_items

class ExtractedTicketItem(BaseModel):
name: str
quantity: int
total_price: float
unit_price: float

# Public models
class TicketItemPublic(TicketItemBase):
id: int
ticket_id: int
matched_product: ProductPublic | None = None

@model_validator(mode="before")
def calculate_total_price(cls, values: Any) -> Any:
values["unit_price"] = (
values["total_price"] / values["quantity"]
if values["quantity"] > 0
else 0.0
)

class TicketPublic(TicketBase):
id: int
items: List[TicketItemPublic]
return values


class TicketStats(BaseModel):
items: List[TicketItem]
class ExtractedTicketInfo(BaseModel):
ticket_number: int | None
date: str | None
time: str | None
total_price: float | None
items: List[ExtractedTicketItem]
42 changes: 16 additions & 26 deletions app/routers/ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,6 @@ async def process_ticket(
temp_file_path, TICKET_PROMPT
)

# Save ticket information to database
ticket, ticket_items = ticket_info.to_db_models()
session.add(ticket)
session.flush() # Flush to get the ticket ID

for item in ticket_items:
item.ticket_id = ticket.id
session.add(item)

# Create a group of tasks for product matching
product_tasks = group(
[
Expand All @@ -139,46 +130,45 @@ async def process_ticket(
# Wait for all tasks to complete with timeout
with allow_join_result():
try:
results = group_result.get(timeout=20)
results = group_result.get(
timeout=20
) # 20 second timeout for entire group
except Exception as e:
logger.error(f"Error waiting for product matching tasks: {str(e)}")
raise HTTPException(
status_code=500, detail="Timeout or error while matching products"
)

# Process results and create ticket items
ticket_items_response = []
for ti, result, db_item in zip(ticket_info.items, results, ticket_items):
ticket_items = []
for item, result in zip(ticket_info.items, results):
if not result:
logger.warning(f"No match found for product '{ti.name}'")
logger.warning(f"No match found for product '{item.name}'")
continue

product_match = ProductMatch.model_validate(result[0])
product = product_match.product
logger.info(
f"Best match for '{ti.name}': {product.name} (Score: {product_match.score:.2f})"
f"Best match for '{item.name}': {product.name} (Score: {product_match.score:.2f})"
)

# Update the database item with the matched product
db_item.matched_product_id = product.id

item_stats = calculate_item_stats(product, ti.quantity, ti.total_price or 0)
item_stats = calculate_item_stats(
product, item.quantity, item.total_price or 0
)

ticket_item = TicketItem(
product=ProductPublic.model_validate(product),
original_name=ti.name,
quantity=ti.quantity,
unit_price=ti.unit_price or 0,
total_price=ti.total_price or 0,
original_name=item.name,
quantity=item.quantity,
unit_price=item.unit_price or 0,
total_price=item.total_price or 0,
stats=item_stats,
)
ticket_items_response.append(ticket_item)
ticket_items.append(ticket_item)

session.commit()
return TicketStats(items=ticket_items_response)
return TicketStats(items=ticket_items)

except Exception as e:
session.rollback()
logger.error(f"Error processing ticket: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error processing ticket: {str(e)}"
Expand Down

This file was deleted.

0 comments on commit 8a53474

Please sign in to comment.