diff --git a/tableqa/clauses.py b/tableqa/clauses.py index b32d2e5..aa61850 100755 --- a/tableqa/clauses.py +++ b/tableqa/clauses.py @@ -1,28 +1,15 @@ -from tensorflow.keras.models import load_model -from sentence_transformers import SentenceTransformer -from numpy import asarray - - -class Clause: - def __init__(self): - self.bert_model = SentenceTransformer('bert-base-nli-mean-tokens') - self.model=load_model("Question_Classifier.h5") - self.types={0:'SELECT {} FROM {}', 1:'SELECT MAX({}) FROM {}', 2:'SELECT MIN({}) FROM {}', 3:'SELECT COUNT({}) FROM {}', 4:'SELECT SUM({}) FROM {}', 5:'SELECT AVG({}) FROM {}'} - - def adapt(self,q,inttype=False,priority=False): - emb=asarray(self.bert_model.encode(q)) - self.clause=self.types[self.model.predict_classes(emb)[0]] +#from nlp import qa +# class Clause: +# def __init__(self): - if priority and inttype and "COUNT" in self.clause: - self.clause= '''SELECT SUM({}) FROM {}''' - return self.clause - - - - - - - +# self.base_q="what is {} here" +# self.types={"the entity":'SELECT {} FROM {}', "the maximum":'SELECT MAX({}) FROM {}', "the minimum":'SELECT MIN({}) FROM {}', "counted":'SELECT COUNT({}) FROM {}', "summed":'SELECT SUM({}) FROM {}', "averaged":'SELECT AVG({}) FROM {}'} + +# def adapt(self,q,inttype=False,priority=False): +# scores={} +# for k,v in self.types.items(): +# scores[k]=qa(q,self.base_q.format(k),return_score=True)[1] +# return self.types[max(scores, key=scores.get)] - + \ No newline at end of file diff --git a/tableqa/nlp.py b/tableqa/nlp.py index 19ab897..48c5bff 100644 --- a/tableqa/nlp.py +++ b/tableqa/nlp.py @@ -2,12 +2,10 @@ from data_utils import data_utils import os from transformers import TFBertForQuestionAnswering, BertTokenizer -from transformers import TFAutoModelForSequenceClassification, AutoTokenizer import tensorflow as tf from rake_nltk import Rake import column_types import json -from clauses import Clause from conditionmaps import conditions @@ -15,6 +13,7 @@ qa_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad') + import nltk from nltk.stem import WordNetLemmatizer from nltk.corpus import stopwords @@ -159,6 +158,22 @@ def _find(lst, sublst): def _window_overlap(s1, e1, s2, e2): return s2 <= e1 if s1