From 7c0e67b4aabb3578d2547b06b20b9c4bffa9a942 Mon Sep 17 00:00:00 2001 From: "George A. McCarthy" Date: Thu, 5 Aug 2021 20:49:56 +0000 Subject: [PATCH 1/2] internal: ignore local models --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8a38644..d704caa 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,4 @@ dmypy.json data/* embeddings/* results/* +models/* From e930d84ac2937afb91d6d633fabbe65f0bc8a7a1 Mon Sep 17 00:00:00 2001 From: "George A. McCarthy" Date: Tue, 10 Aug 2021 13:05:30 +0000 Subject: [PATCH 2/2] feat: load or download model and tokenizor --- backend/my_executors.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/backend/my_executors.py b/backend/my_executors.py index ee15566..4bc0ca9 100644 --- a/backend/my_executors.py +++ b/backend/my_executors.py @@ -7,7 +7,7 @@ from transformers import BertModel, BertTokenizer from jina import Executor, requests, Document, DocumentArray -from backend_config import top_k, embeddings_path +from backend_config import top_k, embeddings_path from utils import partition from helpers import log @@ -19,12 +19,17 @@ def __init__(self, **kwargs): log("Initialising ProtBertExecutor.") super().__init__() + model_path = "../models/prot_bert" + if not os.path.exists(model_path): + log(f"Downloading model {model_path}.") + model_path = "Rostlab/prot_bert" + else: + log(f"Using local model: {model_path}") + log("Setting tokenizer.") - tokenizer = BertTokenizer.from_pretrained( - "Rostlab/prot_bert", do_lower_case=False - ) + tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False) log("Setting model.") - model = BertModel.from_pretrained("Rostlab/prot_bert") + model = BertModel.from_pretrained(model_path) self.tokenizer = tokenizer self.model = model