diff --git a/README.md b/README.md
index 7e41019..6de7cb4 100644
--- a/README.md
+++ b/README.md
@@ -171,6 +171,7 @@ HDF5 caching makes it possible to use the metadata with almost no memory usage.
 * `--reorder_metadata_by_ivf_index True` option takes advantage of the data locality property of results of a knn ivf indices: it orders the metadata collection in order of the IVF clusters. That makes it possible to have much faster metadata retrieval as the reads are then accessing a few mostly sequential parts of the metadata instead of many non sequential parts. In practice that means being able to retrieve 1M items in 1s whereas only 1000 items can be retrieved in 1s without this method. This will order the metadata using the first image index.
 * `--provide_safety_model True` will automatically download and load a [safety model](https://github.com/LAION-AI/CLIP-based-NSFW-Detector). You need to `pip install autokeras` optional dependency for this to work.
 * `--provide_violence_detector True` will load a [violence detector](https://github.com/ml-research/OffImgDetectionCLIP), [paper](https://arxiv.org/abs/2202.06675.pdf)
+* `--provide_aesthetic_embeddings True` will load the [aesthetic embeddings](https://github.com/LAION-AI/aesthetic-predictor) and allow users to make the query move towards a nicer point of the clip space
 
 These options can also be provided in the config file to have different options for each index. Example:
 ```json
diff --git a/clip_retrieval/clip_back.py b/clip_retrieval/clip_back.py
index 0eb666f..8aed937 100644
--- a/clip_retrieval/clip_back.py
+++ b/clip_retrieval/clip_back.py
@@ -7,6 +7,7 @@
 from flask_cors import CORS
 import faiss
 from collections import defaultdict
+from multiprocessing.pool import ThreadPool
 import json
 from io import BytesIO
 from PIL import Image
@@ -22,6 +23,7 @@
 from functools import lru_cache
 from werkzeug.middleware.dispatcher import DispatcherMiddleware
 import pyarrow as pa
+import fsspec
 import h5py
 from tqdm import tqdm
 
@@ -178,7 +180,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.clip_resources = kwargs["clip_resources"]
 
-    def compute_query(self, clip_resource, text_input, image_input, image_url_input, use_mclip):
+    def compute_query(
+        self, clip_resource, text_input, image_input, image_url_input, use_mclip, aesthetic_score, aesthetic_weight
+    ):
         """compute the query embedding"""
         import torch  # pylint: disable=import-outside-toplevel
         import clip  # pylint: disable=import-outside-toplevel
@@ -210,6 +214,11 @@ def compute_query(self, clip_resource, text_input, image_input, image_url_input,
             image_features /= image_features.norm(dim=-1, keepdim=True)
             query = image_features.cpu().detach().numpy().astype("float32")
 
+        if clip_resource.aesthetic_embeddings is not None and aesthetic_score is not None:
+            aesthetic_embedding = clip_resource.aesthetic_embeddings[aesthetic_score]
+            query = query + aesthetic_embedding * aesthetic_weight
+            query = query / np.linalg.norm(query)
+
         return query
 
     def hash_based_dedup(self, embeddings):
@@ -394,6 +403,8 @@ def query(
         deduplicate=True,
         use_safety_model=False,
         use_violence_detector=False,
+        aesthetic_score=None,
+        aesthetic_weight=None,
     ):
         """implement the querying functionality of the knn service: from text and image to nearest neighbors"""
 
@@ -410,6 +421,8 @@ def query(
             image_input=image_input,
             image_url_input=image_url_input,
             use_mclip=use_mclip,
+            aesthetic_score=aesthetic_score,
+            aesthetic_weight=aesthetic_weight,
         )
         distances, indices = self.knn_search(
             query,
@@ -443,6 +456,10 @@ def post(self):
         deduplicate = json_data.get("deduplicate", False)
         use_safety_model = json_data.get("use_safety_model", False)
         use_violence_detector = json_data.get("use_violence_detector", False)
+        aesthetic_score = json_data.get("aesthetic_score", "")
+        aesthetic_score = int(aesthetic_score) if aesthetic_score != "" else None
+        aesthetic_weight = json_data.get("aesthetic_weight", "")
+        aesthetic_weight = float(aesthetic_weight) if aesthetic_weight != "" else None
         return self.query(
             text_input,
             image_input,
@@ -455,6 +472,8 @@ def post(self):
             deduplicate,
             use_safety_model,
             use_violence_detector,
+            aesthetic_score,
+            aesthetic_weight,
         )
 
 
@@ -618,6 +637,33 @@ def get_cache_folder(clip_model):
     return cache_folder
 
 
+# needs to do this at load time
+@lru_cache(maxsize=None)
+def get_aesthetic_embedding(model_type):
+    """get aesthetic embedding"""
+    if model_type == "ViT-B/32":
+        model_type = "vit_b_32"
+    elif model_type == "ViT-L/14":
+        model_type = "vit_l_14"
+
+    fs, _ = fsspec.core.url_to_fs(
+        f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/{model_type}_embeddings/rating0.npy?raw=true"
+    )
+    embs = {}
+    with ThreadPool(10) as pool:
+
+        def get(k):
+            with fs.open(
+                f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/{model_type}_embeddings/rating{k}.npy?raw=true",
+                "rb",
+            ) as f:
+                embs[k] = np.load(f)
+
+        for _ in pool.imap_unordered(get, range(10)):
+            pass
+    return embs
+
+
 @lru_cache(maxsize=None)
 def load_violence_detector(clip_model):
     """load violence detector for this clip model"""
@@ -701,6 +747,7 @@ class ClipResource:
     ivf_old_to_new_mapping: Any
     columns_to_return: List[str]
     metadata_is_ordered_by_ivf: bool
+    aesthetic_embeddings: Any
 
 
 @dataclass
@@ -718,6 +765,7 @@ class ClipOptions:
     use_arrow: bool
     provide_safety_model: bool
     provide_violence_detector: bool
+    provide_aesthetic_embeddings: bool
 
 
 def dict_to_clip_options(d, clip_options):
@@ -743,6 +791,9 @@ def dict_to_clip_options(d, clip_options):
         provide_violence_detector=d["provide_violence_detector"]
         if "provide_violence_detector" in d
         else clip_options.provide_violence_detector,
+        provide_aesthetic_embeddings=d["provide_aesthetic_embeddings"]
+        if "provide_aesthetic_embeddings" in d
+        else clip_options.provide_aesthetic_embeddings,
     )
 
 
@@ -780,6 +831,9 @@ def load_clip_index(clip_options):
     violence_detector = (
         load_violence_detector(clip_options.clip_model) if clip_options.provide_violence_detector else None
     )
+    aesthetic_embeddings = (
+        get_aesthetic_embedding(clip_options.clip_model) if clip_options.provide_aesthetic_embeddings else None
+    )
 
     image_present = os.path.exists(clip_options.indice_folder + "/image.index")
     text_present = os.path.exists(clip_options.indice_folder + "/text.index")
@@ -820,6 +874,7 @@ def load_clip_index(clip_options):
         ivf_old_to_new_mapping=ivf_old_to_new_mapping if clip_options.reorder_metadata_by_ivf_index else None,
         columns_to_return=clip_options.columns_to_return,
         metadata_is_ordered_by_ivf=clip_options.reorder_metadata_by_ivf_index,
+        aesthetic_embeddings=aesthetic_embeddings,
     )
 
 
@@ -864,6 +919,7 @@ def clip_back(
     use_arrow=False,
     provide_safety_model=False,
     provide_violence_detector=False,
+    provide_aesthetic_embeddings=True,
 ):
     """main entry point of clip back, start the endpoints"""
     print("starting boot of clip back")
@@ -883,6 +939,7 @@ def clip_back(
             use_arrow=use_arrow,
             provide_safety_model=provide_safety_model,
             provide_violence_detector=provide_violence_detector,
+            provide_aesthetic_embeddings=provide_aesthetic_embeddings,
         ),
     )
     print("indices loaded")
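To summarize the Python change for reviewers: the aesthetic option nudges the (already normalized) CLIP query towards the precomputed embedding of a chosen aesthetic rating bucket and re-normalizes the result. Below is a minimal standalone sketch paraphrasing the `compute_query` change above; `rating_embeddings` stands in for `clip_resource.aesthetic_embeddings` (the per-rating dict returned by `get_aesthetic_embedding`), and the function name and shapes here are illustrative only.

```python
import numpy as np


def nudge_query_towards_aesthetic(query, rating_embeddings, aesthetic_score, aesthetic_weight):
    """Blend a per-rating aesthetic embedding into a CLIP query and re-normalize.

    query: (1, dim) float32 CLIP embedding, already unit-normalized
    rating_embeddings: dict mapping rating bucket (0..9) to a CLIP embedding
    """
    if aesthetic_score is None or aesthetic_weight is None:
        return query  # feature disabled for this request
    aesthetic_embedding = rating_embeddings[aesthetic_score]
    # move the query towards the "nicer" region of CLIP space
    query = query + aesthetic_embedding * aesthetic_weight
    # keep unit norm so inner-product / cosine search stays comparable
    return query / np.linalg.norm(query)
```

A larger `aesthetic_weight` pulls results further towards the chosen rating bucket; `aesthetic_score=9` with `aesthetic_weight=0.5` matches the defaults the front end sends in the changes below.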
diff --git a/front/src/clip-front.js b/front/src/clip-front.js
index 898a2bf..afbee1e 100644
--- a/front/src/clip-front.js
+++ b/front/src/clip-front.js
@@ -59,6 +59,8 @@ class ClipFront extends LitElement {
     this.imageUrl = imageUrl === null ? undefined : imageUrl
     this.hideDuplicateUrls = true
     this.hideDuplicateImages = true
+    this.aestheticScore = '9'
+    this.aestheticWeight = '0.5'
     this.initIndices()
   }
 
@@ -99,7 +101,9 @@ class ClipFront extends LitElement {
       removeViolence: { type: Boolean },
       hideDuplicateUrls: { type: Boolean },
       hideDuplicateImages: { type: Boolean },
-      useMclip: { type: Boolean }
+      useMclip: { type: Boolean },
+      aestheticWeight: { type: String },
+      aestheticScore: { type: String }
     }
   }
 
@@ -149,7 +153,8 @@ class ClipFront extends LitElement {
       }
     }
     if (_changedProperties.has('useMclip') || _changedProperties.has('modality') || _changedProperties.has('currentIndex') ||
-    _changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') || _changedProperties.has('removeViolence')) {
+    _changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') ||
+    _changedProperties.has('removeViolence') || _changedProperties.has('aestheticScore') || _changedProperties.has('aestheticWeight')) {
       if (this.image !== undefined || this.text !== '' || this.imageUrl !== undefined) {
         this.redoSearch()
       }
@@ -233,7 +238,7 @@ class ClipFront extends LitElement {
     const imageUrl = this.imageUrl === undefined ? null : this.imageUrl
     const count = this.modality === 'image' && this.currentIndex === this.indices[0] ? 10000 : 100
     const results = await this.service.callClipService(text, image, imageUrl, this.modality, count,
-      this.currentIndex, count, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+      this.currentIndex, count, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
     downloadFile('clipsubset.json', JSON.stringify(results, null, 2))
   }
 
@@ -244,7 +249,7 @@ class ClipFront extends LitElement {
     this.image = undefined
     this.imageUrl = undefined
     const results = await this.service.callClipService(this.text, null, null, this.modality, this.numImages,
-      this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+      this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
     console.log(results)
     this.images = results
     this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -257,7 +262,7 @@ class ClipFront extends LitElement {
     this.text = ''
     this.imageUrl = undefined
     const results = await this.service.callClipService(null, this.image, null, this.modality, this.numImages,
-      this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+      this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
     console.log(results)
     this.images = results
     this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -270,7 +275,7 @@ class ClipFront extends LitElement {
     this.text = ''
     this.image = undefined
     const results = await this.service.callClipService(null, null, this.imageUrl, this.modality, this.numImages,
-      this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+      this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
     console.log(results)
     this.images = results
     this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -547,6 +552,9 @@ class ClipFront extends LitElement {


+
+
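End to end, a client opts into the new behavior by adding the two new fields to the JSON it posts to the backend, mirroring what the front end now passes through `callClipService`. A hedged usage sketch follows: only `aesthetic_score` and `aesthetic_weight` come from this diff; the endpoint path, port, and the other payload fields are assumptions based on the existing clip-back API, and the index name is a placeholder.

```python
import requests

# Hypothetical request against a locally running clip back instance.
payload = {
    "text": "a photo of a mountain lake",
    "modality": "image",
    "indice_name": "laion5B",   # placeholder: use an index actually served by your backend
    "num_images": 20,
    "num_result_ids": 20,
    "aesthetic_score": "9",     # parsed with int(...) in post(); "" or omitting it disables the nudge
    "aesthetic_weight": "0.5",  # parsed with float(...); the front end defaults to '0.5'
}
response = requests.post("http://localhost:1234/knn-service", json=payload)  # assumed endpoint and port
print(len(response.json()), "results")
```

Passing the values as strings mirrors the front end, which initializes `aestheticScore = '9'` and `aestheticWeight = '0.5'` in the constructor above.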