Skip to content

Commit

Permalink
Merge pull request #22 from lightonai/update-template
Browse files Browse the repository at this point in the history
update-code-quality
  • Loading branch information
raphaelsty authored Jul 22, 2024
2 parents 9c37ac1 + dffb017 commit 123e9aa
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 35 deletions.
4 changes: 1 addition & 3 deletions evaluation_beir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
split="test",
)


for batch in utils.iter_batch(documents, batch_size=500):
documents_embeddings = model.encode(
sentences=[document["text"] for document in batch],
Expand All @@ -28,8 +27,8 @@
doc_embeddings=documents_embeddings,
)


scores = []

for batch in utils.iter_batch(queries, batch_size=5):
queries_embeddings = model.encode(
sentences=[query["text"] for query in batch],
Expand All @@ -39,7 +38,6 @@

scores.extend(retriever.retrieve(queries=queries_embeddings, k=10))


print(
evaluation.evaluate(
scores=scores,
Expand Down
2 changes: 1 addition & 1 deletion giga_cherche/evaluation/colbert_triplet_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
show_progress_bar: bool = False,
write_csv: bool = True,
truncate_dim: int | None = None,
):
) -> None:
"""
Initializes a TripletEvaluator object.
Expand Down
55 changes: 26 additions & 29 deletions giga_cherche/indexes/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,42 @@
class Weaviate(Base):
def __init__(
self,
host: str | None = "localhost",
port: str | None = "8080",
name: str | None = "colbert_collection",
recreate: bool = False,
max_doc_length: int | None = 180,
host: str = "localhost",
port: str = "8080",
name: str = "colbert_collection",
override_collection: bool = False,
max_doc_length: int = 180,
connect_attempt: int = 5,
connect_retry_delay: float = 5.0,
) -> None:
self.host = host
self.port = port
self.name = name
self.max_doc_length = max_doc_length
fail_counter = 0
attempt_number = 5
retry_delay = 5.0
while fail_counter < attempt_number:
self.connect_attempt = connect_attempt
self.connect_retry_delay = connect_retry_delay

for _ in range(connect_attempt):
try:
with weaviate.connect_to_local(
host=self.host, port=self.port
) as client:
print("Successful connection to the Weaviate container.")
if not client.collections.exists(self.name):
print(f"Collection {self.name} does not exist, creating it.")
self.create_collection(self.name)
elif recreate:
print(f"Collection {self.name} exists, recreating it.")
self.create_collection(name=self.name)
elif override_collection:
client.collections.delete(self.name)
self.create_collection(self.name)
self.create_collection(name=self.name)
else:
print(
f"Loaded collection with {client.collections.get(self.name).aggregate.over_all(total_count=True).total_count} vectors",
n_vectors = (
client.collections.get(self.name)
.aggregate.over_all(total_count=True)
.total_count
)

break
except Exception as e:
print(
f"Could not connect to the Weaviate container, retrying in {retry_delay} secs: {str(e)}"
)
fail_counter += 1
time.sleep(retry_delay)

if fail_counter >= attempt_number:
raise ConnectionError("Could not connect to the Weaviate container")
print(f"Loaded collection with {n_vectors} vectors")
break
except Exception:
print("Could not connect to the Weaviate container, retrying...")
time.sleep(connect_retry_delay)

def create_collection(self, name: str) -> None:
with weaviate.connect_to_local(host=self.host, port=self.port) as client:
Expand Down Expand Up @@ -156,7 +151,9 @@ async def query_all_embeddings(
return res

def query(self, queries_embeddings: list[list[int | float]], k: int = 5):
return asyncio.run(self.query_all_embeddings(queries_embeddings, k))
return asyncio.run(
self.query_all_embeddings(queries_embeddings=queries_embeddings, k=k)
)

async def get_doc_embeddings(self, vector_index, doc_id: str):
return await vector_index.query.fetch_objects(
Expand Down
2 changes: 1 addition & 1 deletion giga_cherche/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
device: str | None = None,
prompts: dict[str, str] | None = None,
default_prompt_name: str | None = None,
similarity_fn_name: Optional[Union[str, SimilarityFunction]] = None,
similarity_fn_name: Optional[str | SimilarityFunction] = None,
cache_folder: str | None = None,
trust_remote_code: bool = False,
revision: str | None = None,
Expand Down
23 changes: 22 additions & 1 deletion giga_cherche/utils/iter_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,28 @@
def iter_batch(
X: list[str], batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> list:
"""Iterate over a list of elements by batch."""
"""Iterate over a list of elements by batch.
Example
-------
>>> from giga_cherche import utils
>>> X = [
... "element 0",
... "element 1",
... "element 2",
... "element 3",
... "element 4",
... ]
>>> n_samples = 0
>>> for batch in utils.iter_batch(X, batch_size=2):
... n_samples += len(batch)
>>> n_samples
5
"""
batchs = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]

if tqdm_bar:
Expand Down

0 comments on commit 123e9aa

Please sign in to comment.