Skip to content

Commit

Permalink
feat: better fuzzy search
Browse files Browse the repository at this point in the history
  • Loading branch information
d3vv3 committed Oct 26, 2024
1 parent 7ca8980 commit a1d09e8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 48 deletions.
71 changes: 44 additions & 27 deletions app/shared/product_matcher.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,70 @@
import unidecode
from fuzzywuzzy import fuzz
from thefuzz import process
from loguru import logger
from app.models import ProductMatch, ProductPublic
from typing import List, Optional


def find_closest_products_task(
products: List[ProductPublic] = [],
products: List[ProductPublic],
item_name: Optional[str] = None,
item_price: Optional[float] = None,
threshold: float = 60.0,
max_matches: int = 10,
price_tolerance: float = 0.3,
) -> List[ProductMatch]:
"""
Task to find closest products based on name and price similarity.
Find closest products based on name and price similarity using thefuzz process.
"""
logger.info(
f"Processing product matching task for '{item_name}' with price {item_price}"
)
if not products:
logger.warning("No products provided for matching.")
return []
matches = []

for product in products:
name_score: float = 0.0
price_score: float = 0.0

if item_name is not None:
name_score = fuzz.token_set_ratio(
unidecode.unidecode(item_name.lower()),
unidecode.unidecode(product.name.lower()),
)

if item_price:
price_diff = abs(product.price - item_price)
price_score = max(0, 100 - (price_diff / item_price) * 100)

combined_score = (name_score * 0.7) + (price_score * 0.3)

if combined_score >= threshold:
matches.append(
ProductMatch(
score=combined_score, product=ProductPublic.model_validate(product)
matches = []
if item_name is not None:
product_names = [unidecode.unidecode(p.name.lower()) for p in products]
name_matches = process.extract(
unidecode.unidecode(item_name.lower()), product_names, limit=len(products)
)
for name, name_score in name_matches:
index = product_names.index(name)
product = products[index]
# Calculate price score if price is provided
price_score = 100.0
if item_price is not None:
price_diff = abs(product.price - item_price)
if price_diff / item_price > price_tolerance:
price_score = 0.0
else:
price_score = max(0, 100 - (price_diff / item_price) * 100)
# Calculate combined score
combined_score = (name_score * 0.7) + (price_score * 0.3)
if combined_score >= threshold:
matches.append(
ProductMatch(
score=combined_score,
product=ProductPublic.model_validate(product),
)
)
)

# Sort matches by score in descending order
matches.sort(key=lambda x: x.score, reverse=True)
# Remove duplicate products, keeping the highest scoring match for each product
unique_matches = []
seen_products = set()
for match in matches:
if match.product.id not in seen_products:
unique_matches.append(match)
seen_products.add(match.product.id)
# Log debug information
logger.debug(
f"Found {len(matches)} matches for item '{item_name}' with price {item_price}"
f"Found {len(unique_matches)} unique matches for item '{item_name}' with price {item_price}"
)
return matches[:max_matches]
for match in unique_matches[:5]:
logger.debug(
f" Match: {match.product.name}, {match.product.price:.2f} € (Score: {match.score:.2f})"
)
return unique_matches[:max_matches]
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ celery
click
fastapi[standard]
flower
fuzzywuzzy
thefuzz
google-generativeai
loguru
ocrmypdf
Expand Down
38 changes: 18 additions & 20 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.11
# This file is autogenerated by pip-compile with Python 3.13
# by the following command:
#
# pip-compile requirements.in
Expand All @@ -25,8 +25,6 @@ anyio==4.6.2.post1
# httpx
# starlette
# watchfiles
async-timeout==4.0.3
# via redis
attrs==24.2.0
# via aiohttp
billiard==4.2.1
Expand Down Expand Up @@ -73,18 +71,16 @@ dnspython==2.7.0
# via email-validator
email-validator==2.2.0
# via fastapi
fastapi[standard]==0.115.2
fastapi[standard]==0.115.3
# via -r requirements.in
fastapi-cli[standard]==0.0.5
# via fastapi
flower==2.0.1
# via -r requirements.in
frozenlist==1.4.1
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fuzzywuzzy==0.18.0
# via -r requirements.in
google-ai-generativelanguage==0.6.10
# via google-generativeai
google-api-core[grpc]==2.21.0
Expand All @@ -109,8 +105,6 @@ googleapis-common-protos==1.65.0
# via
# google-api-core
# grpcio-status
greenlet==3.1.1
# via sqlalchemy
grpcio==1.67.0
# via
# google-api-core
Expand Down Expand Up @@ -154,7 +148,7 @@ loguru==0.7.2
# via -r requirements.in
lxml==5.3.0
# via pikepdf
mako==1.3.5
mako==1.3.6
# via alembic
markdown-it-py==3.0.0
# via rich
Expand Down Expand Up @@ -200,11 +194,11 @@ prompt-toolkit==3.0.48
# via click-repl
propcache==0.2.0
# via yarl
proto-plus==1.24.0
proto-plus==1.25.0
# via
# google-ai-generativelanguage
# google-api-core
protobuf==5.28.2
protobuf==5.28.3
# via
# google-ai-generativelanguage
# google-api-core
Expand All @@ -230,7 +224,7 @@ pydantic-core==2.23.4
# via pydantic
pygments==2.18.0
# via rich
pymupdf==1.24.11
pymupdf==1.24.12
# via -r requirements.in
pyparsing==3.2.0
# via httplib2
Expand All @@ -242,23 +236,25 @@ python-dotenv==1.0.1
# via uvicorn
python-levenshtein==0.26.0
# via -r requirements.in
python-multipart==0.0.13
python-multipart==0.0.12
# via
# -r requirements.in
# fastapi
pytz==2024.2
# via flower
pyyaml==6.0.2
# via uvicorn
rapidfuzz==3.10.0
# via levenshtein
redis==5.1.1
rapidfuzz==3.10.1
# via
# levenshtein
# thefuzz
redis==5.2.0
# via -r requirements.in
requests==2.32.3
# via
# -r requirements.in
# google-api-core
rich==13.9.2
rich==13.9.3
# via
# ocrmypdf
# typer
Expand All @@ -280,10 +276,12 @@ sqlalchemy==2.0.36
# sqlmodel
sqlmodel==0.0.22
# via -r requirements.in
starlette==0.40.0
starlette==0.41.0
# via fastapi
tenacity==9.0.0
# via -r requirements.in
thefuzz==0.22.1
# via -r requirements.in
tornado==6.4.1
# via flower
tqdm==4.66.5
Expand Down Expand Up @@ -328,5 +326,5 @@ websockets==13.1
# via uvicorn
wrapt==1.16.0
# via deprecated
yarl==1.15.5
yarl==1.16.0
# via aiohttp

0 comments on commit a1d09e8

Please sign in to comment.