diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py
index 2ae1783..955df29 100644
--- a/pylate/models/Dense.py
+++ b/pylate/models/Dense.py
@@ -2,6 +2,7 @@ import os
 
 import torch
+from safetensors import safe_open
 from safetensors.torch import load_model as load_safetensors_model
 from sentence_transformers.models import Dense as DenseSentenceTransformer
 from sentence_transformers.util import import_from_string
 
@@ -110,26 +111,49 @@ def from_stanford_weights(
         """
         # Check if the model is locally available
         if not (os.path.exists(os.path.join(model_name_or_path))):
-            # Else download the model/use the cached version
-            model_name_or_path = cached_file(
-                model_name_or_path,
-                filename="pytorch_model.bin",
-                cache_dir=cache_folder,
-                revision=revision,
-                local_files_only=local_files_only,
-                token=token,
-                use_auth_token=use_auth_token,
-            )
-        # If the model a local folder, load the PyTorch model
+            # Else download the model/use the cached version. We first try the safetensors version and fall back to bin if it does not exist. All recent stanford-nlp models are safetensors, but we keep bin for compatibility.
+            try:
+                model_name_or_path = cached_file(
+                    model_name_or_path,
+                    filename="model.safetensors",
+                    cache_dir=cache_folder,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                    token=token,
+                    use_auth_token=use_auth_token,
+                )
+            except EnvironmentError:
+                print("No safetensors model found, falling back to bin.")
+                model_name_or_path = cached_file(
+                    model_name_or_path,
+                    filename="pytorch_model.bin",
+                    cache_dir=cache_folder,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                    token=token,
+                    use_auth_token=use_auth_token,
+                )
+        # If the model is a local folder, load the safetensors file.
+        # Again, we first try the safetensors version and fall back to bin if it does not exist.
+        else:
+            if os.path.exists(os.path.join(model_name_or_path, "model.safetensors")):
+                model_name_or_path = os.path.join(
+                    model_name_or_path, "model.safetensors"
+                )
+            else:
+                print("No safetensors model found, falling back to bin.")
+                model_name_or_path = os.path.join(
+                    model_name_or_path, "pytorch_model.bin"
+                )
+        if model_name_or_path.endswith("safetensors"):
+            with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
+                state_dict = {"linear.weight": f.get_tensor("linear.weight")}
         else:
-            model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin")
-
-        # Load the state dict using torch.load instead of safe_open
-        state_dict = {
-            "linear.weight": torch.load(model_name_or_path, map_location="cpu")[
-                "linear.weight"
-            ]
-        }
+            state_dict = {
+                "linear.weight": torch.load(model_name_or_path, map_location="cpu")[
+                    "linear.weight"
+                ]
+            }
 
         # Determine input and output dimensions
         in_features = state_dict["linear.weight"].shape[1]
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..f93cdf1
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,90 @@
+import math
+
+import torch
+
+from pylate import models, rank
+
+
+def test_model_creation(**kwargs) -> None:
+    """Test the creation of different models."""
+    query = ["fruits are healthy."]
+    documents = [["fruits are healthy.", "fruits are good for health."]]
+    torch.manual_seed(42)
+    # Creation from a base encoder
+    model = models.ColBERT(model_name_or_path="bert-base-uncased")
+    # We don't test the embeddings of newly initialized models for now, as we first need to make the initialization deterministic.
+    # queries_embeddings = model.encode(sentences=query, is_query=True)
+    # documents_embeddings = model.encode(sentences=documents, is_query=False)
+    # reranked_documents = rank.rerank(
+    #     documents_ids=[["1", "2"]],
+    #     queries_embeddings=queries_embeddings,
+    #     documents_embeddings=documents_embeddings,
+    # )
+    # assert math.isclose(
+    #     reranked_documents[0][0]["score"], 25.92, rel_tol=0.01, abs_tol=0.01
+    # )
+    # assert math.isclose(reranked_documents[0][1]["score"], 23.7, rel_tol=0.01, abs_tol=0.01)
+
+    # Creation from a base sentence-transformer
+    model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
+    # We don't test the embeddings of newly initialized models for now, as we first need to make the initialization deterministic.
+    # queries_embeddings = model.encode(sentences=query, is_query=True)
+    # documents_embeddings = model.encode(sentences=documents, is_query=False)
+    # reranked_documents = rank.rerank(
+    #     documents_ids=[["1", "2"]],
+    #     queries_embeddings=queries_embeddings,
+    #     documents_embeddings=documents_embeddings,
+    # )
+    # assert math.isclose(
+    #     reranked_documents[0][0]["score"], 18.77, rel_tol=0.01, abs_tol=0.01
+    # )
+    # assert math.isclose(
+    #     reranked_documents[0][1]["score"], 18.63, rel_tol=0.01, abs_tol=0.01
+
+    # Creation from stanford-nlp (safetensors)
+    model = models.ColBERT(model_name_or_path="answerdotai/answerai-colbert-small-v1")
+    queries_embeddings = model.encode(sentences=query, is_query=True)
+    documents_embeddings = model.encode(sentences=documents, is_query=False)
+    reranked_documents = rank.rerank(
+        documents_ids=[["1", "2"]],
+        queries_embeddings=queries_embeddings,
+        documents_embeddings=documents_embeddings,
+    )
+    assert math.isclose(
+        reranked_documents[0][0]["score"], 31.71, rel_tol=0.01, abs_tol=0.01
+    )
+    assert math.isclose(
+        reranked_documents[0][1]["score"], 31.64, rel_tol=0.01, abs_tol=0.01
+    )
+
+    # Creation from stanford-nlp (bin)
+    model = models.ColBERT(model_name_or_path="Crystalcareai/Colbertv2")
+    queries_embeddings = model.encode(sentences=query, is_query=True)
+    documents_embeddings = model.encode(sentences=documents, is_query=False)
+    reranked_documents = rank.rerank(
+        documents_ids=[["1", "2"]],
+        queries_embeddings=queries_embeddings,
+        documents_embeddings=documents_embeddings,
+    )
+    assert math.isclose(
+        reranked_documents[0][0]["score"], 31.15, rel_tol=0.01, abs_tol=0.01
+    )
+    assert math.isclose(
+        reranked_documents[0][1]["score"], 30.61, rel_tol=0.01, abs_tol=0.01
+    )
+
+    # Creation from PyLate
+    model = models.ColBERT(model_name_or_path="lightonai/colbertv2.0")
+    queries_embeddings = model.encode(sentences=query, is_query=True)
+    documents_embeddings = model.encode(sentences=documents, is_query=False)
+    reranked_documents = rank.rerank(
+        documents_ids=[["1", "2"]],
+        queries_embeddings=queries_embeddings,
+        documents_embeddings=documents_embeddings,
+    )
+    assert math.isclose(
+        reranked_documents[0][0]["score"], 30.01, rel_tol=0.01, abs_tol=0.01
+    )
+    assert math.isclose(
+        reranked_documents[0][1]["score"], 26.98, rel_tol=0.01, abs_tol=0.01
+    )
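
Note on the loading path added in Dense.py: resolution is model.safetensors first, then pytorch_model.bin, both for Hub repositories and for local folders, and only the linear.weight tensor is read from the checkpoint. The snippet below is a minimal sketch of that same fallback outside of Dense.py, not PyLate API: it uses huggingface_hub.hf_hub_download as a stand-in for the cached_file() call in the patch, and the load_linear_weight helper name is purely illustrative.

# Minimal sketch, not part of the patch: mirrors the safetensors-first /
# bin-fallback resolution implemented in Dense.from_stanford_weights.
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors import safe_open


def load_linear_weight(repo_id: str) -> torch.Tensor:
    """Return the linear.weight tensor of a stanford-nlp style checkpoint."""
    try:
        # Prefer the safetensors checkpoint when the repo provides one.
        path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    except EntryNotFoundError:
        # Older checkpoints only ship a pickled PyTorch file.
        path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")

    if path.endswith("safetensors"):
        # safe_open maps the file lazily, so only the projection weight is read.
        with safe_open(path, framework="pt", device="cpu") as f:
            return f.get_tensor("linear.weight")
    return torch.load(path, map_location="cpu")["linear.weight"]


# e.g. load_linear_weight("answerdotai/answerai-colbert-small-v1") exercises the
# safetensors path, while "Crystalcareai/Colbertv2" exercises the bin fallback.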