Adding document pooling option to the encode function #12

Merged · 1 commit · Jun 17, 2024
68 changes: 68 additions & 0 deletions giga_cherche/models/colbert.py
@@ -30,6 +30,7 @@
import transformers
from huggingface_hub import HfApi
from numpy import ndarray
from scipy.cluster.hierarchy import fcluster, linkage
from sentence_transformers.evaluation import SentenceEvaluator
from sentence_transformers.fit_mixin import FitMixin
from sentence_transformers.model_card import (
@@ -139,6 +140,8 @@ class ColBERT(nn.Sequential, FitMixin):
Truncate the inputs to the encoder max lengths or use sliding window encoding.
query_length
The length to which the query is padded with mask tokens.
attend_to_expansion_tokens
Whether the model's attention layers attend to the expansion tokens. If False, the original tokens do not attend to the expansion tokens; only the expansion tokens attend to the original tokens. (The default was False in the original ColBERT codebase.)

Example:
::
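The asymmetric masking that `attend_to_expansion_tokens` describes can be illustrated with a toy attention mask. This is an editor's sketch with made-up shapes (3 original tokens, 2 [MASK] expansion tokens), not code from this diff:

```python
import torch

# Hypothetical example: 3 original query tokens followed by 2 [MASK] expansion tokens.
seq_len, n_original = 5, 3
is_expansion = torch.zeros(seq_len, dtype=torch.bool)
is_expansion[n_original:] = True

# allowed[i, j] == True means token i may attend to token j.
allowed = torch.ones(seq_len, seq_len, dtype=torch.bool)

# attend_to_expansion_tokens=False: original tokens do not attend to the
# expansion tokens, while the expansion tokens still attend to every token.
allowed[(~is_expansion).unsqueeze(1) & is_expansion.unsqueeze(0)] = False

print(allowed.int())  # rows 0-2 have zeros in columns 3-4; rows 3-4 are all ones
```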
@@ -451,6 +454,8 @@ def encode(
device: str = None,
normalize_embeddings: bool = True, # Normalize the embedding to compute cosine similarity
is_query: bool = True,
pool_factor: int = 1,
protected_tokens: int = 1,
) -> Union[List[Tensor], ndarray, Tensor]:
"""
Computes sentence embeddings.
@@ -482,6 +487,8 @@
normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case, the faster dot-product (util.dot_score) can be used instead of cosine similarity. Defaults to True.
is_query (bool, optional): Whether the input sentences are queries. If True, the query prefix is added to the input sentences and the sequence is padded; otherwise, the document prefix is added and the sequence is not padded. Defaults to True.
pool_factor (int, optional): The factor by which to pool the document embeddings, keeping roughly 1/pool_factor of the original tokens. If set to 1, no pooling is done; 2 keeps ~50% of the tokens; 3 keeps ~33%; and so on. Defaults to 1.
protected_tokens (int, optional): The number of tokens at the beginning of the sequence that should not be pooled. Defaults to 1 (CLS token).

Returns:
Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
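A hypothetical usage sketch of the two new parameters; the checkpoint name and constructor call are assumptions (following the sentence-transformers convention) and are not part of this diff:

```python
from giga_cherche.models.colbert import ColBERT

# Assumed construction from a checkpoint name, as with SentenceTransformer models.
model = ColBERT("colbert-ir/colbertv2.0")

documents = [
    "Paris is the capital of France.",
    "Hierarchical pooling reduces the number of stored token embeddings.",
]

# Documents: no query padding; token embeddings are pooled by a factor of 2.
doc_embeddings = model.encode(
    documents,
    is_query=False,
    pool_factor=2,       # keep roughly 50% of the document tokens
    protected_tokens=1,  # never pool the first (CLS) embedding
)

# Queries are padded with mask tokens and are never pooled.
query_embeddings = model.encode(["capital of France"], is_query=True)
```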
@@ -645,7 +652,13 @@ def encode(
token_emb = torch.nn.functional.normalize(
token_emb[0 : last_mask_id + 1], p=2, dim=1
)

embeddings.append(token_emb[0 : last_mask_id + 1])
# If we are encoding documents and the pool factor is greater than 1, we pool the embeddings
if pool_factor > 1 and not is_query:
embeddings = self.pool_embeddings_hierarchical(
embeddings, pool_factor, protected_tokens
)
# elif output_value is None: # Return all outputs
# embeddings = []
# for sent_idx in range(len(out_features["sentence_embedding"])):
@@ -696,6 +709,61 @@ def encode(

return all_embeddings

def pool_embeddings_hierarchical(
self,
documents_embeddings: List[Tensor],
pool_factor: int = 1,
protected_tokens: int = 1,
) -> List[Tensor]:
"""
Pools the token embeddings of each document using hierarchical clustering.

Args:
documents_embeddings (List[Tensor]): The embeddings of the documents to pool.
pool_factor (int, optional): The factor by which to pool the document embeddings, keeping roughly 1/pool_factor of the original tokens. If set to 1, no pooling is done; 2 keeps ~50% of the tokens; 3 keeps ~33%; and so on. Defaults to 1.
protected_tokens (int, optional): The number of tokens at the beginning of the sequence that should not be pooled. Defaults to 1 (the CLS token).

Returns:
List[Tensor]: The list of pooled embeddings, one tensor per document.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pooled_embeddings = []
for document_embeddings in documents_embeddings:
document_pooled_embeddings = []
document_embeddings = document_embeddings.to(device)
# Set aside the first protected_tokens embeddings so they are not pooled
protected_embeddings = document_embeddings[:protected_tokens]
document_embeddings = document_embeddings[protected_tokens:]

# Cosine similarity computation (vectors are already normalized)
similarities = torch.mm(document_embeddings, document_embeddings.t())

# Convert the similarities to distances, as expected by Ward clustering
similarities = 1 - similarities.cpu().numpy()

# Create hierarchical clusters using Ward's method
Z = linkage(similarities, metric="euclidean", method="ward")
length = len(document_embeddings)
# Determine the number of clusters we want in the end based on the pool factor
max_clusters = max(length // pool_factor, 1)
cluster_labels = fcluster(Z, t=max_clusters, criterion="maxclust")
# Pool embeddings within each cluster
for cluster_id in range(1, max_clusters + 1):
cluster_indices = torch.where(
torch.tensor(cluster_labels == cluster_id, device=device)
)[0]
if cluster_indices.numel() > 0:
pooled_embedding = document_embeddings[cluster_indices].mean(dim=0)
document_pooled_embeddings.append(pooled_embedding)

# Re-add the protected embeddings (appended at the end; token order does not affect late-interaction scoring)
document_pooled_embeddings.extend(protected_embeddings)
pooled_embeddings.append(torch.stack(document_pooled_embeddings))

return pooled_embeddings
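A small standalone check of the clustering steps above, using random normalized embeddings so it can run without a model; the shapes and the pool factor are made up for illustration:

```python
import torch
from scipy.cluster.hierarchy import fcluster, linkage

# Toy data: 1 protected (CLS) embedding followed by 128 normalized token embeddings of dim 16.
protected_tokens, pool_factor = 1, 4
embeddings = torch.nn.functional.normalize(torch.randn(129, 16), p=2, dim=1)
protected, tokens = embeddings[:protected_tokens], embeddings[protected_tokens:]

# Same steps as pool_embeddings_hierarchical: cosine similarity -> distance -> Ward clustering.
distances = 1 - torch.mm(tokens, tokens.t()).numpy()
Z = linkage(distances, metric="euclidean", method="ward")
max_clusters = max(len(tokens) // pool_factor, 1)
labels = fcluster(Z, t=max_clusters, criterion="maxclust")

# Mean-pool each non-empty cluster, then re-append the protected embeddings.
pooled = [
    tokens[torch.tensor(labels == cid)].mean(dim=0)
    for cid in range(1, max_clusters + 1)
    if (labels == cid).any()
]
pooled = torch.stack(pooled + list(protected))
print(tokens.shape, "->", pooled.shape)  # e.g. torch.Size([128, 16]) -> torch.Size([33, 16])
```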

@property
def similarity_fn_name(self) -> Optional[str]:
"""Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.