forked from pamelafox/learnlive-rag-starter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate_embeddings.py
61 lines (51 loc) · 2.21 KB
/
update_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import asyncio
import json
import logging
import os
import openai
from dotenv import load_dotenv
from embeddings import compute_text_embedding
from postgres_models import Item
# Shared application logger; the __main__ block raises it to INFO for script runs.
logger: logging.Logger = logging.getLogger(name="ragapp")
async def update_embeddings(in_seed_data=False):
    """Compute embeddings for the seed data and save them to a new JSON file.

    When ``in_seed_data`` is true, every object in ``seed_data.json`` (located
    next to this script) is turned into an ``Item``, embedded with the
    ``text-embedding-3-small`` model at 256 dimensions through the
    ``https://models.inference.ai.azure.com`` endpoint (authenticated with the
    ``GITHUB_TOKEN`` environment variable), and written, embeddings included,
    to ``seed_data_two.json`` in the same directory.

    The output file is overwritten (not appended to) and is rewritten after
    every row, so it is always a single valid JSON array containing all rows
    embedded so far. An interrupted run therefore keeps its progress.

    When ``in_seed_data`` is false, this function currently only logs and
    returns.

    Args:
        in_seed_data: Whether to (re)compute embeddings for the seed-data file.
    """
    openai_embed_client = openai.AsyncOpenAI(
        base_url="https://models.inference.ai.azure.com", api_key=os.getenv("GITHUB_TOKEN")
    )
    embedding_column = "embedding"
    logger.info(f"Updating embeddings in column: {embedding_column}")
    if not in_seed_data:
        return

    current_dir = os.path.dirname(os.path.realpath(__file__))
    seed_data_path = os.path.join(current_dir, "seed_data.json")
    new_seed_data_file = os.path.join(current_dir, "seed_data_two.json")

    with open(seed_data_path, encoding="utf-8") as f:
        seed_data_objects = json.load(f)

    embedded_rows = []
    for ind, seed_data_object in enumerate(seed_data_objects):
        logger.info("Computing embedding for seed data index: %d", ind)
        # Each JSON key maps onto the same-named Item attribute.
        row = Item(
            id=seed_data_object["id"],
            description=seed_data_object["description"],
            type=seed_data_object["type"],
            brand=seed_data_object["brand"],
            price=seed_data_object["price"],
            name=seed_data_object["name"],
        )
        row.embedding = await compute_text_embedding(
            row.to_str_for_embedding(),
            openai_client=openai_embed_client,
            embed_model="text-embedding-3-small",
            embedding_dimensions=256,
        )
        embedded_rows.append(row.to_dict(include_embedding=True))
        # Rewrite the whole array every time instead of appending one array per row:
        # appending produced "[...][...]", which is not valid JSON.
        # Rewriting keeps progress on disk if the run is interrupted.
        with open(new_seed_data_file, "w", encoding="utf-8") as f:
            json.dump(embedded_rows, f, indent=4)
        # Throttle to stay under a 15 requests/minute limit. That limit needs 4 s
        # between requests; 5 s leaves a safety margin.
        await asyncio.sleep(5)
if __name__ == "__main__":
    # Pull settings (e.g. GITHUB_TOKEN, read via os.getenv in update_embeddings) from .env.
    # override=True lets .env values replace variables already set in the environment.
    load_dotenv(override=True)
    # Other libraries stay at WARNING; only this script's own logger is verbose.
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)
    asyncio.run(update_embeddings(in_seed_data=True))