diff --git a/README.md b/README.md
index 7e41019..6de7cb4 100644
--- a/README.md
+++ b/README.md
@@ -171,6 +171,7 @@ HDF5 caching makes it possible to use the metadata with almost no memory usage.
* `--reorder_metadata_by_ivf_index True` option takes advantage of the data locality property of the results of a knn IVF index: it orders the metadata collection following the IVF clusters. That makes metadata retrieval much faster, as the reads then access a few mostly sequential parts of the metadata instead of many non-sequential parts. In practice that means being able to retrieve 1M items in 1s, whereas only 1000 items can be retrieved in 1s without this method. This will order the metadata using the first image index.
* `--provide_safety_model True` will automatically download and load a [safety model](https://github.com/LAION-AI/CLIP-based-NSFW-Detector). You need to install the optional `autokeras` dependency (`pip install autokeras`) for this to work.
* `--provide_violence_detector True` will load a [violence detector](https://github.com/ml-research/OffImgDetectionCLIP), see the [paper](https://arxiv.org/abs/2202.06675.pdf)
+* `--provide_aesthetic_embeddings True` will load the [aesthetic embeddings](https://github.com/LAION-AI/aesthetic-predictor) and let users move the query towards a more aesthetic region of the CLIP space
These options can also be provided in the config file to have different options for each index. Example:
```json
diff --git a/clip_retrieval/clip_back.py b/clip_retrieval/clip_back.py
index 0eb666f..8aed937 100644
--- a/clip_retrieval/clip_back.py
+++ b/clip_retrieval/clip_back.py
@@ -7,6 +7,7 @@
from flask_cors import CORS
import faiss
from collections import defaultdict
+from multiprocessing.pool import ThreadPool
import json
from io import BytesIO
from PIL import Image
@@ -22,6 +23,7 @@
from functools import lru_cache
from werkzeug.middleware.dispatcher import DispatcherMiddleware
import pyarrow as pa
+import fsspec
import h5py
from tqdm import tqdm
@@ -178,7 +180,9 @@ def __init__(self, **kwargs):
super().__init__()
self.clip_resources = kwargs["clip_resources"]
- def compute_query(self, clip_resource, text_input, image_input, image_url_input, use_mclip):
+ def compute_query(
+ self, clip_resource, text_input, image_input, image_url_input, use_mclip, aesthetic_score, aesthetic_weight
+ ):
"""compute the query embedding"""
import torch # pylint: disable=import-outside-toplevel
import clip # pylint: disable=import-outside-toplevel
@@ -210,6 +214,11 @@ def compute_query(self, clip_resource, text_input, image_input, image_url_input,
image_features /= image_features.norm(dim=-1, keepdim=True)
query = image_features.cpu().detach().numpy().astype("float32")
+ if clip_resource.aesthetic_embeddings is not None and aesthetic_score is not None:
+ aesthetic_embedding = clip_resource.aesthetic_embeddings[aesthetic_score]
+ query = query + aesthetic_embedding * aesthetic_weight
+ query = query / np.linalg.norm(query)
+
return query
def hash_based_dedup(self, embeddings):
@@ -394,6 +403,8 @@ def query(
deduplicate=True,
use_safety_model=False,
use_violence_detector=False,
+ aesthetic_score=None,
+ aesthetic_weight=None,
):
"""implement the querying functionality of the knn service: from text and image to nearest neighbors"""
@@ -410,6 +421,8 @@ def query(
image_input=image_input,
image_url_input=image_url_input,
use_mclip=use_mclip,
+ aesthetic_score=aesthetic_score,
+ aesthetic_weight=aesthetic_weight,
)
distances, indices = self.knn_search(
query,
@@ -443,6 +456,10 @@ def post(self):
deduplicate = json_data.get("deduplicate", False)
use_safety_model = json_data.get("use_safety_model", False)
use_violence_detector = json_data.get("use_violence_detector", False)
+ aesthetic_score = json_data.get("aesthetic_score", "")
+ aesthetic_score = int(aesthetic_score) if aesthetic_score != "" else None
+ aesthetic_weight = json_data.get("aesthetic_weight", "")
+ aesthetic_weight = float(aesthetic_weight) if aesthetic_weight != "" else None
return self.query(
text_input,
image_input,
@@ -455,6 +472,8 @@ def post(self):
deduplicate,
use_safety_model,
use_violence_detector,
+ aesthetic_score,
+ aesthetic_weight,
)
@@ -618,6 +637,33 @@ def get_cache_folder(clip_model):
return cache_folder
+# the embeddings need to be downloaded at load time, not per query
+@lru_cache(maxsize=None)
+def get_aesthetic_embedding(model_type):
+ """get aesthetic embedding"""
+ if model_type == "ViT-B/32":
+ model_type = "vit_b_32"
+ elif model_type == "ViT-L/14":
+ model_type = "vit_l_14"
+
+ fs, _ = fsspec.core.url_to_fs(
+ f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/{model_type}_embeddings/rating0.npy?raw=true"
+ )
+ embs = {}
+ with ThreadPool(10) as pool:
+
+ def get(k):
+ with fs.open(
+ f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/{model_type}_embeddings/rating{k}.npy?raw=true",
+ "rb",
+ ) as f:
+ embs[k] = np.load(f)
+
+ for _ in pool.imap_unordered(get, range(10)):
+ pass
+ return embs
+
+
@lru_cache(maxsize=None)
def load_violence_detector(clip_model):
"""load violence detector for this clip model"""
@@ -701,6 +747,7 @@ class ClipResource:
ivf_old_to_new_mapping: Any
columns_to_return: List[str]
metadata_is_ordered_by_ivf: bool
+ aesthetic_embeddings: Any
@dataclass
@@ -718,6 +765,7 @@ class ClipOptions:
use_arrow: bool
provide_safety_model: bool
provide_violence_detector: bool
+ provide_aesthetic_embeddings: bool
def dict_to_clip_options(d, clip_options):
@@ -743,6 +791,9 @@ def dict_to_clip_options(d, clip_options):
provide_violence_detector=d["provide_violence_detector"]
if "provide_violence_detector" in d
else clip_options.provide_violence_detector,
+ provide_aesthetic_embeddings=d["provide_aesthetic_embeddings"]
+ if "provide_aesthetic_embeddings" in d
+ else clip_options.provide_aesthetic_embeddings,
)
@@ -780,6 +831,9 @@ def load_clip_index(clip_options):
violence_detector = (
load_violence_detector(clip_options.clip_model) if clip_options.provide_violence_detector else None
)
+ aesthetic_embeddings = (
+ get_aesthetic_embedding(clip_options.clip_model) if clip_options.provide_aesthetic_embeddings else None
+ )
image_present = os.path.exists(clip_options.indice_folder + "/image.index")
text_present = os.path.exists(clip_options.indice_folder + "/text.index")
@@ -820,6 +874,7 @@ def load_clip_index(clip_options):
ivf_old_to_new_mapping=ivf_old_to_new_mapping if clip_options.reorder_metadata_by_ivf_index else None,
columns_to_return=clip_options.columns_to_return,
metadata_is_ordered_by_ivf=clip_options.reorder_metadata_by_ivf_index,
+ aesthetic_embeddings=aesthetic_embeddings,
)
@@ -864,6 +919,7 @@ def clip_back(
use_arrow=False,
provide_safety_model=False,
provide_violence_detector=False,
+ provide_aesthetic_embeddings=True,
):
"""main entry point of clip back, start the endpoints"""
print("starting boot of clip back")
@@ -883,6 +939,7 @@ def clip_back(
use_arrow=use_arrow,
provide_safety_model=provide_safety_model,
provide_violence_detector=provide_violence_detector,
+ provide_aesthetic_embeddings=provide_aesthetic_embeddings,
),
)
print("indices loaded")
diff --git a/front/src/clip-front.js b/front/src/clip-front.js
index 898a2bf..afbee1e 100644
--- a/front/src/clip-front.js
+++ b/front/src/clip-front.js
@@ -59,6 +59,8 @@ class ClipFront extends LitElement {
this.imageUrl = imageUrl === null ? undefined : imageUrl
this.hideDuplicateUrls = true
this.hideDuplicateImages = true
+ this.aestheticScore = '9'
+ this.aestheticWeight = '0.5'
this.initIndices()
}
@@ -99,7 +101,9 @@ class ClipFront extends LitElement {
removeViolence: { type: Boolean },
hideDuplicateUrls: { type: Boolean },
hideDuplicateImages: { type: Boolean },
- useMclip: { type: Boolean }
+ useMclip: { type: Boolean },
+ aestheticWeight: { type: String },
+ aestheticScore: { type: String }
}
}
@@ -149,7 +153,8 @@ class ClipFront extends LitElement {
}
}
if (_changedProperties.has('useMclip') || _changedProperties.has('modality') || _changedProperties.has('currentIndex') ||
- _changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') || _changedProperties.has('removeViolence')) {
+ _changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') ||
+ _changedProperties.has('removeViolence') || _changedProperties.has('aestheticScore') || _changedProperties.has('aestheticWeight')) {
if (this.image !== undefined || this.text !== '' || this.imageUrl !== undefined) {
this.redoSearch()
}
@@ -233,7 +238,7 @@ class ClipFront extends LitElement {
const imageUrl = this.imageUrl === undefined ? null : this.imageUrl
const count = this.modality === 'image' && this.currentIndex === this.indices[0] ? 10000 : 100
const results = await this.service.callClipService(text, image, imageUrl, this.modality, count,
- this.currentIndex, count, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+ this.currentIndex, count, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
downloadFile('clipsubset.json', JSON.stringify(results, null, 2))
}
@@ -244,7 +249,7 @@ class ClipFront extends LitElement {
this.image = undefined
this.imageUrl = undefined
const results = await this.service.callClipService(this.text, null, null, this.modality, this.numImages,
- this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+ this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
console.log(results)
this.images = results
this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -257,7 +262,7 @@ class ClipFront extends LitElement {
this.text = ''
this.imageUrl = undefined
const results = await this.service.callClipService(null, this.image, null, this.modality, this.numImages,
- this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+ this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
console.log(results)
this.images = results
this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -270,7 +275,7 @@ class ClipFront extends LitElement {
this.text = ''
this.image = undefined
const results = await this.service.callClipService(null, null, this.imageUrl, this.modality, this.numImages,
- this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence)
+ this.currentIndex, this.numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight)
console.log(results)
this.images = results
this.lastMetadataId = Math.min(this.numImages, results.length) - 1
@@ -547,6 +552,9 @@ class ClipFront extends LitElement {
+
+
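
Finally, a hedged sketch of how a client can exercise the new parameters end to end. It assumes a clip back instance listening on `localhost:1234` with the usual `/knn-service` route; the host, port, and the `laion5B` index name are placeholders, and the query keys other than `aesthetic_score`/`aesthetic_weight` follow the pre-existing API rather than this diff.

```python
import requests

# Placeholder endpoint; adjust host/port to your clip back deployment.
url = "http://localhost:1234/knn-service"

payload = {
    "text": "a beautiful landscape",
    "modality": "image",
    "num_images": 20,
    "num_result_ids": 20,
    "indice_name": "laion5B",            # placeholder index name
    "use_mclip": False,
    "deduplicate": True,
    "use_safety_model": False,
    "use_violence_detector": False,
    # New fields: sent as strings (like the front-end defaults) and parsed server-side to int/float.
    "aesthetic_score": "9",
    "aesthetic_weight": "0.5",
}

results = requests.post(url, json=payload).json()
print(len(results), "results")
```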