Skip to content

Commit

Permalink
Merge pull request #16 from lightonai/tokenization_and_skiplist
Browse files (browse the repository at this point in the history)
Making skiplist configurable
  • Loading branch information
NohTow authored Jun 24, 2024
2 parents 8bb3b2f + cd02227 commit 740d0cf
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions giga_cherche/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def __init__(
query_length: Optional[int] = None,
document_length: Optional[int] = None,
attend_to_expansion_tokens: Optional[bool] = False,
skiplist_words: Optional[List[str]] = None,
):
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
Expand All @@ -214,6 +215,7 @@ def __init__(
self.query_length = query_length
self.document_length = document_length
self.attend_to_expansion_tokens = attend_to_expansion_tokens
self.skiplist_words = skiplist_words
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v3 of SentenceTransformers.",
Expand Down Expand Up @@ -444,21 +446,25 @@ def __init__(
self.query_prefix_id = self.tokenizer.convert_tokens_to_ids(self.query_prefix)
# We are using the mask token as the padding token for padding queries
self.tokenizer.pad_token_id = self.tokenizer.mask_token_id
self.skiplist = [
self.tokenizer.convert_tokens_to_ids(symbol)
for symbol in string.punctuation
]

# We override the config values with the ones provided by the user
if document_length is not None:
self.document_length = document_length
if query_length is not None:
self.query_length = query_length
if skiplist_words is not None:
self.skiplist_words = skiplist_words
# If no values are provided and there is no value in the config, use default values
if self.document_length is None:
self.document_length = 180
if self.query_length is None:
self.query_length = 32
if self.skiplist_words is None:
self.skiplist_words = [symbol for symbol in string.punctuation]
# Convert the skiplist words to token ids (the config stores the words rather than the ids because words are more readable)
self.skiplist = [
self.tokenizer.convert_tokens_to_ids(word) for word in self.skiplist_words
]

def encode(
self,
Expand Down Expand Up @@ -1343,6 +1349,7 @@ def save(
config["query_length"] = self.query_length
config["document_length"] = self.document_length
config["attend_to_expansion_tokens"] = self.attend_to_expansion_tokens
config["skiplist_words"] = self.skiplist_words
json.dump(config, fOut, indent=2)

# Save modules
Expand Down Expand Up @@ -1754,6 +1761,8 @@ def _load_sbert_model(
self.attend_to_expansion_tokens = self._model_config[
"attend_to_expansion_tokens"
]
if "skiplist_words" in self._model_config:
self.skiplist_words = self._model_config["skiplist_words"]

# Check if a readme exists
model_card_path = load_file_path(
Expand Down

0 comments on commit 740d0cf

Please sign in to comment.