diff --git a/colbert/modeling/base_colbert.py b/colbert/modeling/base_colbert.py index 19084172..12cf68c5 100644 --- a/colbert/modeling/base_colbert.py +++ b/colbert/modeling/base_colbert.py @@ -26,10 +26,12 @@ def __init__(self, name_or_path, colbert_config=None): try: HF_ColBERT = class_factory(self.name) except: - HF_ColBERT = class_factory('bert-base-uncased') + self.name = 'bert-base-uncased' # TODO: Double check that this is appropriate here in all cases + HF_ColBERT = class_factory(self.name) - assert self.name is not None - HF_ColBERT = class_factory(self.name) + # assert self.name is not None + # HF_ColBERT = class_factory(self.name) + self.model = HF_ColBERT.from_pretrained(name_or_path, colbert_config=self.colbert_config) self.raw_tokenizer = AutoTokenizer.from_pretrained(name_or_path)