-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb.py
30 lines (23 loc) · 1017 Bytes
/
db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from psycopg2 import connect
# Connection settings, read from the PG* environment variables.
# Every variable is required: os.environ[...] raises KeyError for the first
# missing one (checked in the order listed) when this module is imported.
PGHOST, PGUSER, PGPASSWORD, PGDATABASE = (
    os.environ[_var] for _var in ("PGHOST", "PGUSER", "PGPASSWORD", "PGDATABASE")
)
class VectorDatabase:
    """Store and search text embeddings in the ``embeddings`` table.

    Each instance owns one psycopg2 connection, opened in ``__init__`` from the
    module-level ``PGHOST``/``PGUSER``/``PGPASSWORD``/``PGDATABASE`` settings.
    Release it either by using the instance as a context manager
    (``with VectorDatabase() as db: ...``) or by calling :meth:`close`.
    """

    def __init__(self, port: int = 5432):
        """Open the database connection.

        Args:
            port: Server port. Defaults to 5432, the value that was previously
                hard-coded here.
        """
        self.conn = connect(
            user=PGUSER,
            password=PGPASSWORD,
            host=PGHOST,
            port=port,
            dbname=PGDATABASE,
        )

    def __enter__(self) -> "VectorDatabase":
        # __exit__ alone is not enough for the ``with`` protocol; without this
        # method ``with VectorDatabase() as db`` fails before the body runs.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so any exception raised in the ``with`` body propagates.
        self.close()

    def close(self) -> None:
        """Close the underlying connection."""
        self.conn.close()

    def save_embedding(self, _id: int, data: str, embedding: list[float]):
        """Insert one row into ``embeddings`` and commit it.

        Args:
            _id: Value for the ``id`` column.
            data: Value for the ``data`` column.
            embedding: Value for the ``embedding`` column.
        """
        # ``with self.conn`` commits on success and rolls back on error, so a
        # failed INSERT does not leave the connection in an aborted transaction.
        with self.conn, self.conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO embeddings (id, data, embedding) VALUES (%s, %s, %s)",
                (_id, data, embedding),
            )

    def search_documents(self, question_embedding: list[float], limit: int = 1) -> list:
        """Return the ``data`` of the rows closest to ``question_embedding``.

        Rows are ordered ascending by pgvector's ``<#>`` operator. Per the
        pgvector docs, ``<#>`` yields the *negative* inner product, so the row
        with the largest inner product comes first.

        Args:
            question_embedding: Query vector, cast to ``vector`` in SQL.
            limit: Maximum number of rows to return. Defaults to 1, matching
                the previous hard-coded ``LIMIT 1``.

        Returns:
            The ``data`` values of the matching rows, best match first.
        """
        # The cursor is closed on exit, and the read transaction is ended
        # (or rolled back on error) so the connection stays usable.
        with self.conn, self.conn.cursor() as cursor:
            cursor.execute(
                "SELECT data FROM embeddings v "
                "ORDER BY v.embedding <#> (%s)::vector LIMIT %s",
                (question_embedding, limit),
            )
            rows = cursor.fetchall()
        return [row[0] for row in rows]