-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvector_embedding.py
More file actions
93 lines (80 loc) · 3.69 KB
/
vector_embedding.py
File metadata and controls
93 lines (80 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging
import hashlib
from typing import List, Dict, Tuple
from sentence_transformers import SentenceTransformer
from unicode_utils import clean_unicode_text, clean_list_strings
# Configure the root logger once at import time; every function in this module
# logs through the module-level logging.* helpers and inherits this format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def initialize_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> SentenceTransformer:
    """
    Initializes and returns a SentenceTransformer embedding model.

    Args:
        model_name: Hugging Face model identifier to load.

    Returns:
        A ready-to-use SentenceTransformer instance.

    Raises:
        RuntimeError: If the model cannot be loaded for any reason (missing
            dependency, unknown model name, download failure, ...). The
            original exception is chained as the cause.
    """
    try:
        model = SentenceTransformer(model_name)
    except Exception as e:
        logging.error(f"Failed to initialize embedding model ({model_name}): {e}")
        # Fix: the old message claimed a missing dependency for *every* failure
        # mode; report what actually went wrong and chain the real cause.
        raise RuntimeError(f"Failed to load embedding model '{model_name}'") from e
    logging.info(f"Initialized embedding model: {model_name}")
    return model
def generate_embeddings(texts: List[str], embedding_model: SentenceTransformer) -> List[List[float]]:
    """
    Encodes each input text into a dense vector with the given model.

    Args:
        texts: Strings to embed; an empty list short-circuits to an empty result.
        embedding_model: A pre-initialized SentenceTransformer instance.

    Returns:
        One embedding (list of floats) per input text, in input order.

    Raises:
        ValueError: If the model yields a different number of vectors than texts.
    """
    if not texts:
        logging.warning("No texts provided for embedding generation")
        return []

    logging.info(f"Generating embeddings for {len(texts)} texts")
    try:
        # encode() returns an ndarray when convert_to_tensor=False;
        # flatten it to plain Python lists for downstream storage.
        vectors = embedding_model.encode(texts, convert_to_tensor=False).tolist()
        if len(vectors) != len(texts):
            raise ValueError("Number of generated embeddings does not match number of texts.")
        logging.debug(f"Generated embeddings shape: {len(vectors)} x {len(vectors[0])}")
        return vectors
    except Exception as e:
        logging.error(f"Failed to generate embeddings: {e}")
        raise
def generate_chunk_id(file_path: str, chunk_index: int) -> str:
    """
    Builds a stable, unique identifier for one chunk of a file.

    The ID joins the file's base name (final extension stripped), a short MD5
    digest of the full path (disambiguates same-named files in different
    directories), and the chunk's position: "<stem>-<hash10>-<index>".
    """
    # Last path component, tolerating both '/' and '\\' separators.
    base_name = file_path.split('/')[-1].split('\\')[-1]
    # Strip only the final extension, if one exists.
    stem, dot, _ext = base_name.rpartition('.')
    if not dot:
        stem = base_name
    digest = hashlib.md5(file_path.encode('utf-8')).hexdigest()[:10]
    return f"{stem}-{digest}-{chunk_index}"
def embed_chunks_for_pipeline(
    chunks: List[Dict],
    embedding_model: SentenceTransformer
) -> Tuple[List[str], List[List[float]], List[Dict], List[str]]:
    """
    Takes a list of chunk dictionaries, generates embeddings, and prepares data
    for the next pipeline step.

    Args:
        chunks (List[Dict]): Each dict must have 'text', 'file_path',
            'chunk_index'. An optional 'source' key is carried into metadata.
        embedding_model (SentenceTransformer): Pre-initialized model.

    Returns:
        Tuple[List[str], List[List[float]], List[Dict], List[str]]:
            - A list of the (Unicode-cleaned) chunk texts.
            - A list of the generated embeddings.
            - A list of metadata dictionaries.
            - A list of unique IDs for each chunk.
    """
    if not chunks:
        logging.warning("No chunks provided for embedding")
        return [], [], [], []
    texts = [clean_unicode_text(chunk["text"]) for chunk in chunks]
    embeddings = generate_embeddings(texts, embedding_model)
    metadatas, ids = [], []
    for chunk, cleaned_text in zip(chunks, texts):
        ids.append(generate_chunk_id(chunk["file_path"], chunk["chunk_index"]))
        metadata = {
            "source": clean_unicode_text(chunk.get("source", "")),  # keep original source if available
            "file_path": clean_unicode_text(chunk["file_path"]),
            "chunk_index": chunk["chunk_index"],
            # Fix: measure the cleaned text that is actually embedded and
            # returned; the raw chunk["text"] may differ in length after
            # Unicode cleaning, which made this field inconsistent.
            "chunk_text_length": len(cleaned_text)
        }
        metadatas.append(metadata)
    logging.info(f"Prepared {len(chunks)} chunks for vector database storage")
    return texts, embeddings, metadatas, ids