-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathIsFakePipeline.py
More file actions
38 lines (27 loc) · 1.21 KB
/
IsFakePipeline.py
File metadata and controls
38 lines (27 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import List, Tuple
import numpy
from transformers import Pipeline
class IsFakePipeline():
    """Base interface for classifiers that assign a "fake" score to texts.

    Subclasses provide ``predict``; ``getLeastFake`` is built on top of it and
    treats the lowest score as the least fake.
    """

    def predict(self, texts) -> List[float]:
        """Return one score per entry of *texts*.

        Raises:
            NotImplementedError: always; concrete subclasses must override
                this method.
        """
        # Fail loudly instead of silently returning None, which would only
        # surface later as a confusing numpy error inside getLeastFake.
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict()"
        )

    def getLeastFake(self, texts: List[str]) -> Tuple[str, float]:
        """Return the text with the lowest ``predict`` score, and that score.

        Args:
            texts: Candidate texts; indexed with the position of the minimum
                score, so it must support integer indexing.

        Returns:
            A ``(text, score)`` pair. The score is the minimum of the
            predicted scores and 1.

        Raises:
            ValueError: if ``predict`` yields no scores (e.g. *texts* is
                empty).
        """
        classifications = numpy.array(self.predict(texts))
        if classifications.size == 0:
            # argmin() on an empty array raises ValueError anyway; raise it
            # here with a message that names the actual problem.
            raise ValueError("getLeastFake() requires at least one text")
        # initial=1 seeds the reduction, so the reported score never exceeds 1.
        return texts[classifications.argmin()], classifications.min(initial=1)
class IsFakePipelineHF(Pipeline):
    """``transformers.Pipeline`` subclass exposing ``predict``/``getLeastFake``.

    NOTE(review): this class does not inherit from ``IsFakePipeline``; it
    declares its own ``getLeastFake`` with the same logic.
    """

    def predict(self, texts):
        """Return the softmax value at index 0 for each entry of *texts*.

        The result of the pipeline's ``__call__`` is treated as a batch of
        logits and normalised with a softmax over the last axis.
        Assumes the output has shape ``(n_texts, n_classes)`` -- TODO confirm
        against the pipeline's postprocess step.

        Returns:
            A numpy array holding, for each row of the softmax, its element
            at index 0.
        """
        outputs = numpy.asarray(super().__call__(texts))
        # Subtracting the per-row maximum leaves the softmax mathematically
        # unchanged but keeps exp() from overflowing on large logits; the
        # exponentials are also computed once instead of twice.
        shifted = numpy.exp(outputs - outputs.max(-1, keepdims=True))
        scores = shifted / shifted.sum(-1, keepdims=True)
        # NOTE(review): index 0 is used as the score here, whereas
        # IsFakePipelineSklearn.predict uses index 1 -- confirm label order.
        return numpy.array([item[0] for item in scores])

    def getLeastFake(self, texts: List[str]) -> Tuple[str, float]:
        """Return the text with the lowest ``predict`` score, and that score.

        Args:
            texts: Candidate texts; indexed with the position of the minimum
                score, so it must support integer indexing.

        Returns:
            A ``(text, score)`` pair. The score is the minimum of the
            predicted scores and 1.

        Raises:
            ValueError: if ``predict`` yields no scores (e.g. *texts* is
                empty).
        """
        classifications = numpy.array(self.predict(texts))
        if classifications.size == 0:
            # argmin() on an empty array raises ValueError anyway; raise it
            # here with a message that names the actual problem.
            raise ValueError("getLeastFake() requires at least one text")
        # initial=1 seeds the reduction, so the reported score never exceeds 1.
        return texts[classifications.argmin()], classifications.min(initial=1)
class IsFakePipelineSklearn(IsFakePipeline):
    """Pipeline backed by a ``predict_proba`` model and a ``transform`` vectorizer."""

    def __init__(self, model, vectorizer):
        """Store the collaborators used by ``predict``.

        Args:
            model: Object exposing ``predict_proba(features)``.
            vectorizer: Object exposing ``transform(texts)``.
        """
        super().__init__()
        self.vectorizer = vectorizer
        self.model = model

    def predict(self, texts):
        """Return the ``predict_proba`` value at index 1 for each text.

        *texts* is handed to the vectorizer as-is; a single string is not
        wrapped into a list here.

        Returns:
            A numpy array with one entry per row of the model's
            ``predict_proba`` output.
        """
        features = self.vectorizer.transform(texts)
        probabilities = self.model.predict_proba(features)
        # NOTE(review): index 1 is presumably the "fake" class -- confirm
        # against the model's label encoding.
        second_column = [row[1] for row in probabilities]
        return numpy.array(second_column)