
Commit

Merge pull request #21 from lightonai/update-template
Update-modules
raphaelsty authored Jul 22, 2024
2 parents e57ea3a + 32f713a commit 9c37ac1
Showing 26 changed files with 437 additions and 393 deletions.
148 changes: 78 additions & 70 deletions README.md
The following parameters can be passed to the constructor to set different properties […]

Given that giga-cherche ColBERT models are sentence-transformers models, we can benefit from all the bells and whistles of the latest update, including multi-GPU and BF16 training.
For now, you can train ColBERT models using triplet datasets (datasets containing a positive and a negative for each query). The syntax is the same as sentence-transformers, using the elements adapted to ColBERT from giga-cherche:

```python
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from giga_cherche import losses, models, data_collator, evaluation

model_name = "bert-base-uncased"
batch_size = 32
num_train_epochs = 1
output_dir = "colbert_base"

model = models.ColBERT(model_name_or_path=model_name)

dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
splits = dataset.train_test_split(test_size=0.1)
train_dataset = splits["train"]
eval_dataset = splits["test"]

train_loss = losses.ColBERT(model=model)

dev_evaluator = evaluation.ColBERTTripletEvaluator(
    anchors=eval_dataset["query"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
)

# ... (unchanged lines collapsed in the diff)

trainer = SentenceTransformerTrainer(
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=data_collator.ColBERT(model.tokenize),
)

trainer.train()
```
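
The collapsed part of the diff is where the training arguments are built. As a rough sketch of what that block could contain, reusing the variables defined above (the exact argument values used in the repository may differ), the BF16 and multi-GPU support mentioned earlier comes from the underlying Hugging Face trainer:

```python
# Hypothetical training arguments, shown only to illustrate the collapsed block;
# bf16=True assumes BF16-capable hardware (e.g. Ampere GPUs or newer).
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    bf16=True,
)
```

With such arguments, multi-GPU training typically only requires launching the script with `torchrun` or `accelerate launch`; the training code itself does not change.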

## Inference
Once trained, the model can be loaded to perform inference. You can also load models directly from Hugging Face, for example the provided ColBERTv2 checkpoint [NohTow/colbertv2_sentence_transformer](https://huggingface.co/NohTow/colbertv2_sentence_transformer):

```python
from giga_cherche import models

model = models.ColBERT(
    "NohTow/colbertv2_sentence_transformer",
)
```

You can then call the ```encode``` function to get the embeddings corresponding to your queries:

```python
queries_embeddings = model.encode(
    ["Who is the president of the USA?", "When was the last president of the USA elected?"],
)
```

When encoding documents, simply set the ```is_query``` parameter to ```False```:

```python
documents_embeddings = model.encode(
    [
        "Joseph Robinette Biden Jr. is an American politician who is the 46th and current president of the United States since 2021. A member of the Democratic Party, he previously served as the 47th vice president from 2009 to 2017 under President Barack Obama and represented Delaware in the United States Senate from 1973 to 2009.",
        "Donald John Trump (born June 14, 1946) is an American politician, media personality, and businessman who served as the 45th president of the United States from 2017 to 2021.",
    ],
    is_query=False,
)
```

By default, this returns a list of numpy arrays, one per sequence in the batch, each containing the token embeddings of that sequence. You can pass the argument ```convert_to_tensor=True``` to get a list of tensors instead.

We also provide the option to pool the document embeddings using hierarchical clustering. Our recent study showed that pooling the document embeddings by a factor of 2 halves the memory consumption of the embeddings without degrading performance. This is done by passing ```pool_factor=2``` to the encode function. Larger pooling factors can be used to explore different memory/performance trade-offs.
Note that query embeddings cannot be pooled.
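
As a small illustration combining the two options above (the printed shapes depend on your documents and the pooling factor, so the exact numbers will vary):

```python
# Return torch tensors instead of numpy arrays and pool the document token
# embeddings by a factor of 2 (query embeddings are never pooled).
documents_embeddings = model.encode(
    ["Document text 1", "Document text 2"],
    is_query=False,
    convert_to_tensor=True,
    pool_factor=2,
)

# One 2D tensor per document: (number_of_pooled_tokens, embedding_dimension).
print([embedding.shape for embedding in documents_embeddings])
```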

You can then compute the ColBERT max-sim scores like this:

```python
from giga_cherche import scores

similarity_scores = scores.colbert_score(queries_embeddings, documents_embeddings)
```
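
For reference, the max-sim score sums, over the query tokens, the highest similarity with any document token. A conceptual NumPy sketch of the idea (not the library's implementation, which operates on batches of embeddings):

```python
import numpy as np


def maxsim(query_embedding: np.ndarray, document_embedding: np.ndarray) -> float:
    """Conceptual max-sim between one query (query_tokens, dim) and one document (doc_tokens, dim)."""
    token_similarities = query_embedding @ document_embedding.T
    # For each query token, keep its best-matching document token, then sum.
    return float(token_similarities.max(axis=1).sum())
```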

## Indexing

We provide a ColBERT index based on the [Weaviate vectordb](https://weaviate.io/). To speed up processing, the latest async client is used and document candidates are generated using an HNSW index, which replaces the IVF index from the original ColBERT.

Before you can create and use an index, you need to launch the Weaviate server using Docker (```docker compose up```).
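
Once the container is running, you can optionally check that the server is ready before indexing. A minimal sketch using Weaviate's readiness endpoint (the port assumes the provided ```compose.yml```):

```python
import requests

# Weaviate exposes a readiness probe on /v1/.well-known/ready.
response = requests.get("http://localhost:8080/v1/.well-known/ready", timeout=5)
print("Weaviate ready:", response.status_code == 200)
```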

To populate an index, simply create it and then add the computed embeddings with their corresponding ids:

```python
from giga_cherche import indexes

index = indexes.Weaviate(name="test_index")

documents_embeddings = model.encode(
    ["Document text 1", "Document text 2"],
    is_query=False,
)

index.add_documents(
    doc_ids=["1", "2"],
    doc_embeddings=documents_embeddings,
)
```
We can also remove documents from the index using their ids:

```python
index.remove_documents(["1"])
```

To retrieve documents from the index, you can use the following code snippet:

```python
from giga_cherche import retrieve

retriever = retrieve.ColBERT(index=index)

queries_embeddings = model.encode(
    ["A query related to the documents", "Another query"],
)

retrieved_chunks = retriever.retrieve(queries_embeddings, k=10)
```

You can also simply rerank a list of ids produced by an upstream retrieval module (such as BM25):

```python
from giga_cherche import rerank

reranker = rerank.ColBERT(index=index)

reranked_chunks = reranker.rerank(
    queries_embeddings, batch_doc_ids=[["7912", "4983"], ["8726", "7891"]]
)
```

## Evaluation

We can evaluate the performance of the model using the BEIR evaluation framework. The following code snippet shows how to evaluate the model on the SciFact dataset:

```python
from giga_cherche import evaluation, indexes, models, retrieve, utils

model = models.ColBERT(
    model_name_or_path="NohTow/colbertv2_sentence_transformer",
)

index = indexes.Weaviate(recreate=True, max_doc_length=model.document_length)

retriever = retrieve.ColBERT(index=index)

# Input dataset for evaluation
documents, queries, qrels = evaluation.load_beir(
    dataset_name="scifact",
    split="test",
)

for batch in utils.iter_batch(documents, batch_size=500):
    documents_embeddings = model.encode(
        sentences=[document["text"] for document in batch],
        convert_to_numpy=True,
        is_query=False,
    )

    index.add_documents(
        doc_ids=[document["id"] for document in batch],
        doc_embeddings=documents_embeddings,
    )

scores = []
for batch in utils.iter_batch(queries, batch_size=5):
    queries_embeddings = model.encode(
        sentences=[query["text"] for query in batch],
        convert_to_numpy=True,
        is_query=True,
    )

    scores.extend(retriever.retrieve(queries=queries_embeddings, k=10))

print(
    evaluation.evaluate(
        scores=scores,
        qrels=qrels,
        queries=queries,
        # ... (remaining lines collapsed in the diff)
```
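
The ```utils.iter_batch``` helper replaces the manual batching loops of the previous version of this snippet. A hypothetical equivalent is sketched below to clarify what the loops iterate over; the library's actual helper may differ (for example, it may display a progress bar):

```python
from typing import Iterator


def iter_batch(items: list, batch_size: int) -> Iterator[list]:
    """Yield successive slices of `items` containing at most `batch_size` elements."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]
```
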
25 changes: 11 additions & 14 deletions compose.yml

```yaml
services:
  weaviate:
    command:
      - --host
      - 0.0.0.0
      - --port
      - '8080'
      - --scheme
      - http
    image: cr.weaviate.io/semitechnologies/weaviate:1.24.1
    ports:
      - 8080:8080
      - 50051:50051
    volumes:
      - ${HOME}/weaviate:/var/lib/weaviate
    # It's a pain when I forget to down the container before ending the remote instance
    # restart: on-failure:0
    environment:
      QUERY_DEFAULTS_LIMIT: 25
      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
      # ... (lines collapsed in the diff)
      DEFAULT_VECTORIZER_MODULE: 'none'
      ENABLE_MODULES: ''
      CLUSTER_HOSTNAME: 'node1'
```
58 changes: 23 additions & 35 deletions evaluation_beir.py

```python
"""Evaluation script for the SciFact dataset using the Beir library."""

from giga_cherche import evaluation, indexes, models, retrieve, utils

model = models.ColBERT(
    model_name_or_path="NohTow/colbertv2_sentence_transformer",
)

index = indexes.Weaviate(recreate=True, max_doc_length=model.document_length)

retriever = retrieve.ColBERT(index=index)

# Input dataset for evaluation
documents, queries, qrels = evaluation.load_beir(
    dataset_name="scifact",
    split="test",
)

for batch in utils.iter_batch(documents, batch_size=500):
    documents_embeddings = model.encode(
        sentences=[document["text"] for document in batch],
        convert_to_numpy=True,
        is_query=False,
    )

    index.add_documents(
        doc_ids=[document["id"] for document in batch],
        doc_embeddings=documents_embeddings,
    )

scores = []
for batch in utils.iter_batch(queries, batch_size=5):
    queries_embeddings = model.encode(
        sentences=[query["text"] for query in batch],
        convert_to_numpy=True,
        is_query=True,
    )

    scores.extend(retriever.retrieve(queries=queries_embeddings, k=10))

print(
    evaluation.evaluate(
        scores=scores,
        qrels=qrels,
        queries=queries,
        # ... (remaining lines collapsed in the diff)
```
4 changes: 2 additions & 2 deletions giga_cherche/data_collator/__init__.py

```python
from .colbert_data_collator import ColBERT

__all__ = ["ColBERT"]
```