diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 686595d7..1a2fbcd2 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -12,6 +12,7 @@ ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker + supported_colbert_models = [ { "model": "colbert-ir/colbertv2.0", @@ -41,7 +42,7 @@ class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]): QUERY_MARKER_TOKEN_ID = 1 DOCUMENT_MARKER_TOKEN_ID = 2 - MIN_QUERY_LENGTH = 32 + MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning MASK_TOKEN = "[MASK]" def _post_process_onnx_output( @@ -69,10 +70,9 @@ def _post_process_onnx_output( def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True ) -> Dict[str, np.ndarray]: - if is_doc: - onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID - else: - onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID + marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID + onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1) + onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1) return onnx_input def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: @@ -83,9 +83,6 @@ def tokenize(self, documents: List[str], is_doc: bool = True) -> List[Encoding]: ) def _tokenize_query(self, query: str) -> List[Encoding]: - # "@ " is added to a query to be replaced with a special query token - # make sure that "@ " is considered as a single token - query = f"@ {query}" encoded = self.tokenizer.encode_batch([query]) # colbert authors recommend to pad queries with [MASK] tokens for query augmentation to improve performance if len(encoded[0].ids) < self.MIN_QUERY_LENGTH: @@ -105,9 +102,6 @@ def _tokenize_query(self, query: str) -> List[Encoding]: return encoded def _tokenize_documents(self, 
documents: List[str]) -> List[Encoding]: - # "@ " is added to a document to be replaced with a special document token - # make sure that "@ " is considered as a single token - documents = ["@ " + doc for doc in documents] encoded = self.tokenizer.encode_batch(documents) return encoded diff --git a/fastembed/late_interaction/jina_colbert.py b/fastembed/late_interaction/jina_colbert.py index 42232aeb..9f7c4b32 100644 --- a/fastembed/late_interaction/jina_colbert.py +++ b/fastembed/late_interaction/jina_colbert.py @@ -5,6 +5,7 @@ from fastembed.late_interaction.colbert import Colbert from fastembed.text.onnx_text_model import TextEmbeddingWorker + supported_jina_colbert_models = [ { "model": "jinaai/jina-colbert-v2", @@ -24,7 +25,7 @@ class JinaColbert(Colbert): QUERY_MARKER_TOKEN_ID = 250002 DOCUMENT_MARKER_TOKEN_ID = 250003 - MIN_QUERY_LENGTH = 32 + MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning MASK_TOKEN = "<mask>" @classmethod @@ -43,11 +44,10 @@ def list_supported_models(cls) -> List[Dict[str, Any]]: def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], is_doc: bool = True ) -> Dict[str, np.ndarray]: - if is_doc: - onnx_input["input_ids"][:, 1] = self.DOCUMENT_MARKER_TOKEN_ID - else: - onnx_input["input_ids"][:, 1] = self.QUERY_MARKER_TOKEN_ID - # the attention mask for jina-colbert-v2 is always 1 in queries + onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc) + + # the attention mask for jina-colbert-v2 is always 1 in queries + if not is_doc: onnx_input["attention_mask"][:] = 1 return onnx_input