diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb index 440fc2cb..e56fdefe 100755 --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -7,7 +7,7 @@ "id": "3jm8RYrLqvzz" }, "source": [ - "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", + "# CLIP Interrogator 2.4 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", "\n", "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", "\n", @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "cellView": "form", "id": "aP9FjmWxtLKJ" @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "cellView": "form", "id": "xpPKQR40qvz2" @@ -54,8 +54,7 @@ "\n", "def setup():\n", " install_cmds = [\n", - " ['pip', 'install', 'transformers==4.15.0'],\n", - " ['pip', 'install', 'gradio'],\n", + " ['pip', 'install', 'gradio'],\n", " ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', 'clip-interrogator'],\n", " ]\n", @@ -65,16 +64,15 @@ "setup()\n", "\n", "\n", + "caption_model_name = 'blip-large' #@param [\"blip-base\", \"blip-large\", \"git-large-coco\"]\n", "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n", "\n", - "\n", "import gradio as gr\n", "from clip_interrogator import Config, Interrogator\n", "\n", "config = Config()\n", - "config.blip_num_beams = 64\n", - "config.blip_offload = False\n", "config.clip_model_name = clip_model_name\n", + "config.caption_model_name = caption_model_name\n", "ci = Interrogator(config)\n", "\n", "def image_analysis(image):\n", @@ -112,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "cellView": "form", "colab": { @@ -122,40 +120,7 @@ "id": "Pf6qkFG6MPRj", "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n", - "\n", - "Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n", - "Note: opening the browser inspector may crash Embedded Colab Mode.\n", - "\n", - "To create a public link, set `share=True` in `launch()`.\n" - ] - }, - { - "data": { - "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n
\n Running on \n https://localhost:${port}${path}\n \n
\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)", - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(, 'http://127.0.0.1:7860/', None)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "#@title Image to prompt! 🖼️ -> 📝\n", " \n", @@ -291,7 +256,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]" + "version": "3.9.5" }, "orig_nbformat": 4, "vscode": { diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py index 9a2936ad..d560ce3a 100644 --- a/clip_interrogator/__init__.py +++ b/clip_interrogator/__init__.py @@ -1,4 +1,4 @@ -from .clip_interrogator import Config, Interrogator, LabelTable, load_list +from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list -__version__ = '0.5.5' +__version__ = '0.6.0' __author__ = 'pharmapsychotic' \ No newline at end of file diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index 5d936fe1..e7fcb5a3 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -1,5 +1,4 @@ import hashlib -import inspect import math import numpy as np import open_clip @@ -9,18 +8,19 @@ import torch from dataclasses import dataclass -from blip.models.blip import blip_decoder, BLIP_Decoder from PIL import Image -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode +from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration from tqdm import tqdm from typing import List, Optional from safetensors.numpy import load_file, save_file -BLIP_MODELS = { - 'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth', - 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' +CAPTION_MODELS = { + 'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB + 'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB + 'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB + 'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB + 'git-large-coco': 'microsoft/git-large-coco', # 1.58GB } CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/' @@ -29,16 +29,15 @@ @dataclass class Config: # models can optionally be passed in directly - blip_model: Optional[BLIP_Decoder] = None + caption_model = None + caption_processor = None clip_model = None clip_preprocess = None # blip settings - blip_image_eval_size: int = 384 - blip_max_length: int = 32 - blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None - blip_num_beams: int = 8 - blip_offload: bool = False + caption_max_length: int = 32 + caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None + caption_offload: bool = False # clip settings clip_model_name: str = 'ViT-L-14/openai' @@ -55,8 +54,8 @@ class Config: quiet: bool = False # when quiet progress bars are not shown def apply_low_vram_defaults(self): - self.blip_model_type = 'base' - self.blip_offload = True + self.caption_model_name = 'blip-base' + self.caption_offload = True self.clip_offload = True self.chunk_size = 1024 self.flavor_intermediate_count = 1024 @@ -65,29 +64,33 @@ class Interrogator(): def __init__(self, config: Config): self.config = config self.device = config.device - self.blip_offloaded = True + self.dtype = torch.float16 if self.device == 'cuda' else torch.float32 + self.caption_offloaded = True self.clip_offloaded = True + self.load_caption_model() + self.load_clip_model() - if config.blip_model is None and config.blip_model_type: - if not config.quiet: - print("Loading BLIP model...") - blip_path = os.path.dirname(inspect.getfile(blip_decoder)) - configs_path = os.path.join(os.path.dirname(blip_path), 'configs') - med_config = os.path.join(configs_path, 'med_config.json') - blip_model = blip_decoder( - pretrained=BLIP_MODELS[config.blip_model_type], - image_size=config.blip_image_eval_size, - vit=config.blip_model_type, - med_config=med_config - ) - blip_model.eval() - if not self.config.blip_offload: - blip_model = blip_model.to(config.device) - self.blip_model = blip_model + def load_caption_model(self): + if self.config.caption_model is None and self.config.caption_model_name: + if not self.config.quiet: + print(f"Loading caption model {self.config.caption_model_name}...") + + model_path = CAPTION_MODELS[self.config.caption_model_name] + if self.config.caption_model_name.startswith('git-'): + caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32) + elif self.config.caption_model_name.startswith('blip2-'): + caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype) + else: + caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype) + self.caption_processor = AutoProcessor.from_pretrained(model_path) + + caption_model.eval() + if not self.config.caption_offload: + caption_model = caption_model.to(self.config.device) + self.caption_model = caption_model else: - self.blip_model = config.blip_model - - self.load_clip_model() + self.caption_model = self.config.caption_model + self.caption_processor = self.config.caption_processor def load_clip_model(self): start_time = time.time() @@ -97,7 +100,7 @@ def load_clip_model(self): if config.clip_model is None: if not config.quiet: - print("Loading CLIP model...") + print(f"Loading CLIP model {config.clip_model_name}...") self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( clip_model_name, @@ -183,26 +186,13 @@ def check(addition: str, idx: int) -> bool: return best_prompt def generate_caption(self, pil_image: Image) -> str: - assert self.blip_model is not None, "No BLIP model loaded." - self._prepare_blip() - - size = self.config.blip_image_eval_size - gpu_image = transforms.Compose([ - transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).to(self.device) - - with torch.no_grad(): - caption = self.blip_model.generate( - gpu_image, - sample=False, - num_beams=self.config.blip_num_beams, - max_length=self.config.blip_max_length, - min_length=5 - ) - - return caption[0] + assert self.caption_model is not None, "No caption model loaded." + self._prepare_caption() + inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device) + if not self.config.caption_model_name.startswith('git-'): + inputs = inputs.to(self.dtype) + tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length) + return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip() def image_to_features(self, image: Image) -> torch.Tensor: self._prepare_clip() @@ -237,7 +227,7 @@ def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[ are less readable.""" caption = caption or self.generate_caption(image) image_features = self.image_to_features(image) - merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) + merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self) tops = merged.rank(image_features, max_flavors) return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) @@ -254,7 +244,7 @@ def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, cap caption = caption or self.generate_caption(image) image_features = self.image_to_features(image) - merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) + merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self) flaves = merged.rank(image_features, self.config.flavor_intermediate_count) best_prompt, best_sim = caption, self.similarity(image_features, caption) best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain") @@ -293,18 +283,18 @@ def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> L similarity = text_features @ image_features.T return similarity.T[0].tolist() - def _prepare_blip(self): + def _prepare_caption(self): if self.config.clip_offload and not self.clip_offloaded: self.clip_model = self.clip_model.to('cpu') self.clip_offloaded = True - if self.blip_offloaded: - self.blip_model = self.blip_model.to(self.device) - self.blip_offloaded = False + if self.caption_offloaded: + self.caption_model = self.caption_model.to(self.device) + self.caption_offloaded = False def _prepare_clip(self): - if self.config.blip_offload and not self.blip_offloaded: - self.blip_model = self.blip_model.to('cpu') - self.blip_offloaded = True + if self.config.caption_offload and not self.caption_offloaded: + self.caption_model = self.caption_model.to('cpu') + self.caption_offloaded = True if self.clip_offloaded: self.clip_model = self.clip_model.to(self.device) self.clip_offloaded = False @@ -425,8 +415,8 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet progress.update(len(chunk)) progress.close() -def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: - m = LabelTable([], None, None, None, config) +def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable: + m = LabelTable([], None, ci) for table in tables: m.labels.extend(table.labels) m.embeds.extend(table.embeds) @@ -445,6 +435,12 @@ def _truncate_to_fit(text: str, tokenize) -> str: new_text += ', ' + part return new_text +def list_caption_models() -> List[str]: + return list(CAPTION_MODELS.keys()) + +def list_clip_models() -> List[str]: + return ['/'.join(x) for x in open_clip.list_pretrained()] + def load_list(data_path: str, filename: Optional[str] = None) -> List[str]: """Load a list of strings from a file.""" if filename is not None: diff --git a/requirements.txt b/requirements.txt index 1c73285b..d6ff0905 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,5 @@ requests safetensors tqdm open_clip_torch -blip-ci -transformers>=4.15.0,<=4.26.1 +accelerate +transformers>=4.27.1 \ No newline at end of file diff --git a/run_cli.py b/run_cli.py index 59d563d8..b1a5ef7c 100755 --- a/run_cli.py +++ b/run_cli.py @@ -1,12 +1,11 @@ #!/usr/bin/env python3 import argparse import csv -import open_clip import os import requests import torch from PIL import Image -from clip_interrogator import Interrogator, Config +from clip_interrogator import Interrogator, Config, list_clip_models def inference(ci, image, mode): image = image.convert('RGB') @@ -36,7 +35,7 @@ def main(): exit(1) # validate clip model name - models = ['/'.join(x) for x in open_clip.list_pretrained()] + models = list_clip_models() if args.clip not in models: print(f"Could not find CLIP model {args.clip}!") print(f" available models: {models}") diff --git a/run_gradio.py b/run_gradio.py index 938171ae..0a178bde 100755 --- a/run_gradio.py +++ b/run_gradio.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 import argparse -import open_clip import torch -from clip_interrogator import Config, Interrogator +from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models try: import gradio as gr @@ -45,7 +44,11 @@ def image_analysis(image, clip_model_name): return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks -def image_to_prompt(image, mode, clip_model_name): +def image_to_prompt(image, mode, clip_model_name, blip_model_name): + if blip_model_name != ci.config.caption_model_name: + ci.config.caption_model_name = blip_model_name + ci.load_caption_model() + if clip_model_name != ci.config.clip_model_name: ci.config.clip_model_name = clip_model_name ci.load_clip_model() @@ -60,25 +63,23 @@ def image_to_prompt(image, mode, clip_model_name): elif mode == 'negative': return ci.interrogate_negative(image) - -models = ['/'.join(x) for x in open_clip.list_pretrained()] - def prompt_tab(): with gr.Column(): with gr.Row(): image = gr.Image(type='pil', label="Image") with gr.Column(): mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best') - model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model') + clip_model = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model') + blip_model = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model') prompt = gr.Textbox(label="Prompt") button = gr.Button("Generate prompt") - button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt) + button.click(image_to_prompt, inputs=[image, mode, clip_model, blip_model], outputs=prompt) def analyze_tab(): with gr.Column(): with gr.Row(): image = gr.Image(type='pil', label="Image") - model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model') + model = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model') with gr.Row(): medium = gr.Label(label="Medium", num_top_classes=5) artist = gr.Label(label="Artist", num_top_classes=5) diff --git a/setup.py b/setup.py index f2806e0b..a6db97a0 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="clip-interrogator", - version="0.5.5", + version="0.6.0", license='MIT', author='pharmapsychotic', author_email='me@pharmapsychotic.com',