-
Notifications
You must be signed in to change notification settings - Fork 432
/
Copy path predict.py
42 lines (37 loc) · 1.44 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
from PIL import Image
from cog import BasePredictor, Input, Path
from clip_interrogator import Interrogator, Config
class Predictor(BasePredictor):
    """Cog predictor that turns an input image into a text prompt via CLIP Interrogator.

    A single ``Interrogator`` is built in ``setup`` and reused for every
    prediction; its CLIP model is swapped only when a request asks for a
    different one than is currently loaded.
    """

    # CLIP model loaded by ``setup`` and offered as the request default. Sharing
    # one constant keeps them in sync, so a default request never forces a reload.
    _DEFAULT_CLIP_MODEL = "ViT-L-14/openai"

    def setup(self):
        """Build the shared Interrogator; Cog calls this once before predictions.

        Paths are relative to the working directory. The ``cache`` directory is
        presumably populated at image build time — TODO confirm. The device is
        hard-coded to the first CUDA GPU.
        """
        self.ci = Interrogator(Config(
            blip_model_url='cache/model_large_caption.pth',
            clip_model_name=self._DEFAULT_CLIP_MODEL,
            clip_model_path='cache',
            device='cuda:0',
        ))

    def predict(
        self,
        image: Path = Input(description="Input image"),
        clip_model_name: str = Input(
            # Evaluated in the class body, so the class constant is in scope here.
            default=_DEFAULT_CLIP_MODEL,
            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"],
            description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2",
        ),
        mode: str = Input(
            default="best",
            choices=["best", "fast"],
            description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
        ),
    ) -> str:
        """Run a single prediction on the model.

        Args:
            image: Path to the input image; it is decoded and converted to RGB.
            clip_model_name: CLIP model to interrogate with. It is loaded first
                if it differs from the currently loaded one.
            mode: ``"best"`` calls ``Interrogator.interrogate``; any other value
                (``"fast"`` given the declared choices) calls ``interrogate_fast``.

        Returns:
            The prompt string returned by the interrogator.
        """
        # Close the source file deterministically. convert() returns a new,
        # fully loaded image, so it remains usable after the file is closed.
        with Image.open(str(image)) as img:
            rgb_image = img.convert("RGB")
        self.switch_model(clip_model_name)
        if mode == "best":
            return self.ci.interrogate(rgb_image)
        return self.ci.interrogate_fast(rgb_image)

    def switch_model(self, clip_model_name: str):
        """Make ``clip_model_name`` the active CLIP model if it is not already.

        The name is written to the Interrogator's config before
        ``load_clip_model()`` runs, which is assumed to read it from there.

        Raises:
            Whatever ``load_clip_model()`` raises. The config name is cleared
            first, so the next call retries the load.
        """
        if clip_model_name == self.ci.config.clip_model_name:
            return
        self.ci.config.clip_model_name = clip_model_name
        try:
            self.ci.load_clip_model()
        except Exception:
            # Without this, the config would claim the requested model is loaded
            # even though loading failed. A retry with the same name would then
            # skip the load and run against a stale or partially replaced model.
            # None never equals a requested name, so any later call reloads.
            self.ci.config.clip_model_name = None
            raise