diff --git a/app/shared/product_matcher.py b/app/shared/product_matcher.py index 8352a6b..22af047 100644 --- a/app/shared/product_matcher.py +++ b/app/shared/product_matcher.py @@ -1,19 +1,20 @@ import unidecode -from fuzzywuzzy import fuzz +from thefuzz import process from loguru import logger from app.models import ProductMatch, ProductPublic from typing import List, Optional def find_closest_products_task( - products: List[ProductPublic] = [], + products: List[ProductPublic], item_name: Optional[str] = None, item_price: Optional[float] = None, threshold: float = 60.0, max_matches: int = 10, + price_tolerance: float = 0.2, ) -> List[ProductMatch]: """ - Task to find closest products based on name and price similarity. + Find closest products based on name and price similarity using thefuzz process. """ logger.info( f"Processing product matching task for '{item_name}' with price {item_price}" @@ -21,33 +22,49 @@ def find_closest_products_task( if not products: logger.warning("No products provided for matching.") return [] - matches = [] - - for product in products: - name_score: float = 0.0 - price_score: float = 0.0 - - if item_name is not None: - name_score = fuzz.token_set_ratio( - unidecode.unidecode(item_name.lower()), - unidecode.unidecode(product.name.lower()), - ) - if item_price: - price_diff = abs(product.price - item_price) - price_score = max(0, 100 - (price_diff / item_price) * 100) - - combined_score = (name_score * 0.7) + (price_score * 0.3) - - if combined_score >= threshold: - matches.append( - ProductMatch( - score=combined_score, product=ProductPublic.model_validate(product) + matches = [] + if item_name is not None: + product_names = [unidecode.unidecode(p.name.lower()) for p in products] + name_matches = process.extract( + unidecode.unidecode(item_name.lower()), product_names, limit=len(products) + ) + for name, name_score in name_matches: + index = product_names.index(name) + product = products[index] + # Calculate price score if price is provided + price_score = 100.0 + if item_price is not None: + price_diff = abs(product.price - item_price) + if price_diff / item_price > price_tolerance: + price_score = 0.0 + else: + price_score = max(0, 100 - (price_diff / item_price) * 100) + # Calculate combined score + combined_score = (name_score * 0.7) + (price_score * 0.3) + if combined_score >= threshold: + matches.append( + ProductMatch( + score=combined_score, + product=ProductPublic.model_validate(product), + ) ) - ) + # Sort matches by score in descending order matches.sort(key=lambda x: x.score, reverse=True) + # Remove duplicate products, keeping the highest scoring match for each product + unique_matches = [] + seen_products = set() + for match in matches: + if match.product.id not in seen_products: + unique_matches.append(match) + seen_products.add(match.product.id) + # Log debug information logger.debug( - f"Found {len(matches)} matches for item '{item_name}' with price {item_price}" + f"Found {len(unique_matches)} unique matches for item '{item_name}' with price {item_price}" ) - return matches[:max_matches] + for match in unique_matches[:5]: + logger.debug( + f" Match: {match.product.name}, {match.product.price:.2f} € (Score: {match.score:.2f})" + ) + return unique_matches[:max_matches] diff --git a/requirements.in b/requirements.in index ee83973..573896e 100644 --- a/requirements.in +++ b/requirements.in @@ -6,7 +6,7 @@ celery click fastapi[standard] flower -fuzzywuzzy +thefuzz google-generativeai loguru ocrmypdf diff --git a/requirements.txt b/requirements.txt index 04c017d..79e8c8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.13 # by the following command: # # pip-compile requirements.in @@ -25,8 +25,6 @@ anyio==4.6.2.post1 # httpx # starlette # watchfiles -async-timeout==4.0.3 - # via redis attrs==24.2.0 # via aiohttp billiard==4.2.1 @@ -73,18 +71,16 @@ dnspython==2.7.0 # via email-validator email-validator==2.2.0 # via fastapi -fastapi[standard]==0.115.2 +fastapi[standard]==0.115.3 # via -r requirements.in fastapi-cli[standard]==0.0.5 # via fastapi flower==2.0.1 # via -r requirements.in -frozenlist==1.4.1 +frozenlist==1.5.0 # via # aiohttp # aiosignal -fuzzywuzzy==0.18.0 - # via -r requirements.in google-ai-generativelanguage==0.6.10 # via google-generativeai google-api-core[grpc]==2.21.0 @@ -109,8 +105,6 @@ googleapis-common-protos==1.65.0 # via # google-api-core # grpcio-status -greenlet==3.1.1 - # via sqlalchemy grpcio==1.67.0 # via # google-api-core @@ -154,7 +148,7 @@ loguru==0.7.2 # via -r requirements.in lxml==5.3.0 # via pikepdf -mako==1.3.5 +mako==1.3.6 # via alembic markdown-it-py==3.0.0 # via rich @@ -200,11 +194,11 @@ prompt-toolkit==3.0.48 # via click-repl propcache==0.2.0 # via yarl -proto-plus==1.24.0 +proto-plus==1.25.0 # via # google-ai-generativelanguage # google-api-core -protobuf==5.28.2 +protobuf==5.28.3 # via # google-ai-generativelanguage # google-api-core @@ -230,7 +224,7 @@ pydantic-core==2.23.4 # via pydantic pygments==2.18.0 # via rich -pymupdf==1.24.11 +pymupdf==1.24.12 # via -r requirements.in pyparsing==3.2.0 # via httplib2 @@ -242,7 +236,7 @@ python-dotenv==1.0.1 # via uvicorn python-levenshtein==0.26.0 # via -r requirements.in -python-multipart==0.0.13 +python-multipart==0.0.12 # via # -r requirements.in # fastapi @@ -250,15 +244,17 @@ pytz==2024.2 # via flower pyyaml==6.0.2 # via uvicorn -rapidfuzz==3.10.0 - # via levenshtein -redis==5.1.1 +rapidfuzz==3.10.1 + # via + # levenshtein + # thefuzz +redis==5.2.0 # via -r requirements.in requests==2.32.3 # via # -r requirements.in # google-api-core -rich==13.9.2 +rich==13.9.3 # via # ocrmypdf # typer @@ -280,10 +276,12 @@ sqlalchemy==2.0.36 # sqlmodel sqlmodel==0.0.22 # via -r requirements.in -starlette==0.40.0 +starlette==0.41.0 # via fastapi tenacity==9.0.0 # via -r requirements.in +thefuzz==0.22.1 + # via -r requirements.in tornado==6.4.1 # via flower tqdm==4.66.5 @@ -328,5 +326,5 @@ websockets==13.1 # via uvicorn wrapt==1.16.0 # via deprecated -yarl==1.15.5 +yarl==1.16.0 # via aiohttp