forked from pamelafox/learnlive-rag-starter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpostgres_searcher.py
126 lines (114 loc) · 4.84 KB
/
postgres_searcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from openai import AsyncAzureOpenAI, AsyncOpenAI
from pgvector.utils import to_db
from sqlalchemy import Float, Integer, column, select, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from embeddings import compute_text_embedding
from postgres_models import Item
class PostgresSearcher:
    """Hybrid (vector + full-text) searcher over the Item table in PostgreSQL.

    Combines pgvector cosine-distance search with PostgreSQL full-text search,
    fusing the two rankings with Reciprocal Rank Fusion (RRF) when both a
    query text and a query vector are supplied.
    """

    def __init__(
        self,
        postgres_host: str,
        postgres_username: str,
        postgres_database: str,
        postgres_password: str,
        openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI,
        embed_model: str,
        embed_dimensions: int | None,
    ):
        """Store embedding configuration and create the async engine.

        embed_dimensions may be None for embedding models with a fixed
        output size.
        """
        self.openai_embed_client = openai_embed_client
        self.embed_model = embed_model
        self.embed_dimensions = embed_dimensions
        self.engine = create_async_engine(
            f"postgresql+asyncpg://{postgres_username}:{postgres_password}@{postgres_host}/{postgres_database}",
            echo=False,
        )

    def build_filter_clause(self, filters) -> tuple[str, str]:
        """Build SQL filter fragments from a list of filter dicts.

        Each filter is a dict with "column", "comparison_operator", and
        "value" keys. Returns (WHERE-prefixed clause, AND-prefixed clause)
        so callers can splice the filters into queries that do or do not
        already have a WHERE clause; both are empty strings when there are
        no filters.

        NOTE(security): "column" and "comparison_operator" are interpolated
        directly into SQL and MUST come from trusted code, never from end
        users. String values are single-quote-escaped as a mitigation, but
        parameterized queries would be safer if the filter shape allows it.
        """
        if filters is None:
            return "", ""
        filter_clauses = []
        for filter_spec in filters:
            value = filter_spec["value"]
            if isinstance(value, str):
                # Escape embedded single quotes (SQL-standard doubling) and
                # quote the literal. The original interpolated the raw value,
                # which broke on quotes and allowed SQL injection.
                value = "'" + value.replace("'", "''") + "'"
            # Use a local variable rather than mutating the caller's dict
            # (the original wrote back into filter["value"], so reusing the
            # same filters list double-quoted values on the second call).
            filter_clauses.append(f"{filter_spec['column']} {filter_spec['comparison_operator']} {value}")
        filter_clause = " AND ".join(filter_clauses)
        if len(filter_clause) > 0:
            return f"WHERE {filter_clause}", f"AND {filter_clause}"
        return "", ""

    async def search(
        self, query_text: str | None, query_vector: list[float] | list, top: int = 5, filters: list[dict] | None = None
    ):
        """Search Item rows by vector, full text, or both (hybrid RRF).

        Picks the query to run based on which inputs are non-empty:
        both -> hybrid RRF fusion, vector only -> cosine-distance search,
        text only -> full-text search. Returns up to `top` Item models,
        best matches first.

        Raises ValueError when both query_text and query_vector are empty.
        """
        filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
        table_name = Item.__tablename__
        # <=> is pgvector's cosine-distance operator; each branch is capped
        # at 20 candidates before fusion/truncation.
        vector_query = f"""
            SELECT id, RANK () OVER (ORDER BY embedding <=> :embedding) AS rank
                FROM {table_name}
                {filter_clause_where}
                ORDER BY embedding <=> :embedding
                LIMIT 20
            """

        fulltext_query = f"""
            SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', description), query) DESC)
                FROM {table_name}, plainto_tsquery('english', :query) query
                WHERE to_tsvector('english', description) @@ query {filter_clause_and}
                ORDER BY ts_rank_cd(to_tsvector('english', description), query) DESC
                LIMIT 20
            """

        # Reciprocal Rank Fusion: score = sum over both rankings of
        # 1 / (k + rank), with missing ranks contributing 0.
        hybrid_query = f"""
        WITH vector_search AS (
            {vector_query}
        ),
        fulltext_search AS (
            {fulltext_query}
        )
        SELECT
            COALESCE(vector_search.id, fulltext_search.id) AS id,
            COALESCE(1.0 / (:k + vector_search.rank), 0.0) +
            COALESCE(1.0 / (:k + fulltext_search.rank), 0.0) AS score
        FROM vector_search
        FULL OUTER JOIN fulltext_search ON vector_search.id = fulltext_search.id
        ORDER BY score DESC
        LIMIT 20
        """

        if query_text is not None and len(query_vector) > 0:
            sql = text(hybrid_query).columns(column("id", Integer), column("score", Float))
        elif len(query_vector) > 0:
            sql = text(vector_query).columns(column("id", Integer), column("rank", Integer))
        elif query_text is not None:
            sql = text(fulltext_query).columns(column("id", Integer), column("rank", Integer))
        else:
            raise ValueError("Both query text and query vector are empty")

        async with async_sessionmaker(self.engine, expire_on_commit=False, autoflush=False)() as session:
            results = (
                await session.execute(
                    sql,
                    {"embedding": to_db(query_vector), "query": query_text, "k": 60},
                )
            ).fetchall()

            # Fetch the matching rows in a single IN query (the original
            # issued one SELECT per id — an N+1 pattern), then restore the
            # score order from `results`.
            row_ids = [row_id for row_id, _ in results[:top]]
            if not row_ids:
                return []
            items = (await session.execute(select(Item).where(Item.id.in_(row_ids)))).scalars()
            items_by_id = {item.id: item for item in items}
            # Skip any id that vanished between the two queries instead of
            # returning a None entry (the original could append None).
            return [items_by_id[row_id] for row_id in row_ids if row_id in items_by_id]

    async def search_and_embed(
        self,
        query_text: str | None = None,
        top: int = 5,
        enable_vector_search: bool = True,
        enable_text_search: bool = True,
        filters: list[dict] | None = None,
    ) -> list[Item]:
        """Search rows by query text, optionally embedding it first.

        When enable_vector_search is True the query text is converted to a
        vector via the OpenAI embeddings client; when enable_text_search is
        False the text itself is withheld from the search so only the
        vector branch runs. Delegates to search(), so disabling both (or
        passing no query_text with text search disabled) raises ValueError.
        """
        vector: list[float] = []
        if enable_vector_search and query_text is not None:
            vector = await compute_text_embedding(
                query_text,
                self.openai_embed_client,
                self.embed_model,
                self.embed_dimensions,
            )
        if not enable_text_search:
            query_text = None
        return await self.search(query_text, vector, top, filters)