forked from vara-prasad-07/Group_AC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchroma.py
More file actions
79 lines (64 loc) · 2.8 KB
/
Copy pathchroma.py
File metadata and controls
79 lines (64 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from typing import List, Dict, Any
import chromadb
from dotenv import load_dotenv
# Load env vars from a .env file at import time so that the CHROMA_* settings
# read via os.getenv() in ChromaClientWrapper can be supplied from that file.
load_dotenv()
class ChromaClientWrapper:
"""Wrapper around a Chroma CloudClient for simple upsert and search operations.
Expects environment variables or explicit args:
- CHROMA_API_KEY
- CHROMA_TENANT
- CHROMA_DB
"""
def __init__(self, collection_name: str = "tickets", api_key: str | None = None, tenant: str | None = None, database: str | None = None):
api_key = api_key or os.getenv("CHROMA_API_KEY")
tenant = tenant or os.getenv("CHROMA_TENANT")
database = database or os.getenv("CHROMA_DB")
if not api_key:
raise ValueError("CHROMA_API_KEY is required in environment or constructor")
# Create CloudClient
self.client = chromadb.CloudClient(api_key=api_key, tenant=tenant, database=database)
# get or create collection
try:
self.collection = self.client.get_collection(collection_name)
except Exception:
self.collection = self.client.create_collection(collection_name)
def upsert(self, ids: List[str], embeddings: List[List[float]], metadatas: List[dict], documents: List[str]):
"""Add or update documents in the collection."""
return self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
def search_by_embedding(self, query_embedding: List[float], n_results: int = 3) -> Dict[str, Any]:
"""Query the collection using a vector and return Chroma result format.
Returns a dict with keys: 'ids', 'documents', 'metadatas', 'distances'
Compatible with format expected by rag_chain.py
"""
result = self.collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
include=["metadatas", "documents", "distances"]
)
# Chroma returns nested lists; unpack and structure for downstream use
if not result or not result.get("documents"):
return {
"ids": [],
"documents": [],
"metadatas": [],
"distances": []
}
# Extract first result set (query returns list of lists)
docs_list = result.get("documents", [[]])[0] if result.get("documents") else []
metas_list = result.get("metadatas", [[]])[0] if result.get("metadatas") else []
dists_list = result.get("distances", [[]])[0] if result.get("distances") else []
# Build IDs list from metadata ticket_id
ids_list = []
for i, meta in enumerate(metas_list):
tid = meta.get("ticket_id", f"unknown_{i}") if meta else f"unknown_{i}"
ids_list.append(tid)
return {
"ids": ids_list,
"documents": docs_list,
"metadatas": metas_list,
"distances": dists_list
}
if __name__ == "__main__":
    # Running this module directly only prints a usage hint; no Chroma client
    # is created here, so no credentials are needed for this path.
    print("Chroma client wrapper loaded. Set CHROMA_API_KEY env var to use.")