2 changes: 1 addition & 1 deletion lightly_studio/.python-version
@@ -1 +1 @@
-3.8
+3.12
4 changes: 4 additions & 0 deletions lightly_studio/pyproject.toml
@@ -54,6 +54,10 @@ dependencies = [
     "av>=10.0.0",
     "opencv-python>=4.11.0.86",
     "requests>=2.32.3",
+    "einops>=0.8.1",
+    "onnxruntime>=1.20.1",
+    "onnx>=1.17.0",
+    "onnxscript>=0.2.7",
 ]

 [project.optional-dependencies]
11 changes: 11 additions & 0 deletions lightly_studio/src/lightly_studio/dataset/embedding_manager.py
@@ -223,6 +223,17 @@ def _load_embedding_generator_from_env() -> EmbeddingGenerator | None:
         except ImportError:
             print("Embedding functionality is disabled.")
             return None
+    elif env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "PE":
+        try:
+            from lightly_studio.dataset.pe_embedding_generator_onnx import (
+                PEEmbeddingGenerator,
+            )
+
+            print("Using PE embedding generator.")
+            return PEEmbeddingGenerator()
+        except ImportError:
+            print("Embedding functionality is disabled.")
+            return None

     print(
         f"Unsupported model type: '{env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE}'",
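For reviewers trying this out: the new branch is selected through the `LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE` environment variable, like the existing generators. A minimal smoke test might look like this (a sketch; it assumes the variable is read when `_load_embedding_generator_from_env` is called, and it imports a private helper purely for illustration):

```python
import os

# Select the new PE branch before the env is read.
os.environ["LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE"] = "PE"

from lightly_studio.dataset.embedding_manager import (
    _load_embedding_generator_from_env,
)

# Prints "PEEmbeddingGenerator" if onnxruntime and the vendored PE code
# import cleanly; otherwise the function prints "Embedding functionality
# is disabled." and returns None.
generator = _load_embedding_generator_from_env()
print(type(generator).__name__ if generator else "disabled")
```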
139 changes: 139 additions & 0 deletions lightly_studio/src/lightly_studio/dataset/pe_embedding_generator.py
@@ -0,0 +1,139 @@
"""PE embedding generator."""

from __future__ import annotations

from typing import Callable
from uuid import UUID

import fsspec
import numpy as np
import torch
from numpy.typing import NDArray
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from lightly_studio.models.embedding_model import EmbeddingModelCreate
from lightly_studio.vendor.pe.vision_encoder import pe, transforms

from .embedding_generator import EmbeddingGenerator

MODEL_NAME = "PE-Core-T16-384"
MAX_BATCH_SIZE: int = 16
EMBEDDING_DIMENSION: int = 512

# Dataset for efficient batched image loading and preprocessing
class _ImageFileDataset(Dataset[torch.Tensor]):
"""Dataset wrapping image file paths and a preprocess function."""

def __init__(
self,
filepaths: list[str],
preprocess: Callable[[Image.Image], torch.Tensor],
) -> None:
self.filepaths = filepaths
self.preprocess = preprocess

def __len__(self) -> int:
return len(self.filepaths)

def __getitem__(self, idx: int) -> torch.Tensor:
with fsspec.open(self.filepaths[idx], "rb") as file:
image = Image.open(file).convert("RGB")
return self.preprocess(image)


class PEEmbeddingGenerator(EmbeddingGenerator):
"""PE embedding model."""

def __init__(self) -> None:
"""Initialize the PE embedding model.

This method loads the PE model and its tokenizer. The model
checkpoint is downloaded and cached locally for future use.
"""
self._model = pe.CLIP.from_config(MODEL_NAME, pretrained=True)
self._preprocess = transforms.get_image_transform(self._model.image_size)
self._tokenizer = transforms.get_text_tokenizer(self._model.context_length)

# Auto select device: CUDA > MPS (Apple Silicon) > CPU
self._device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self._model = self._model.to(self._device)
        self._model_hash = "abc"  # NOTE: placeholder, not derived from the checkpoint.

def get_embedding_model_input(self, dataset_id: UUID) -> EmbeddingModelCreate:
"""Generate an EmbeddingModelCreate instance.

Args:
dataset_id: The ID of the dataset.

Returns:
An EmbeddingModelCreate instance with the model details.
"""
return EmbeddingModelCreate(
name=MODEL_NAME,
embedding_model_hash=self._model_hash,
embedding_dimension=EMBEDDING_DIMENSION,
dataset_id=dataset_id,
)

def embed_text(self, text: str) -> list[float]:
"""Embed a text with PE.

Args:
text: The text to embed.

Returns:
A list of floats representing the generated embedding.
"""
tokenized = self._tokenizer([text]).to(self._device)
with torch.no_grad():
embedding = self._model.encode_text(tokenized, normalize=True)[0]
# Convert embedding to list of floats.
embedding_list: list[float] = embedding.cpu().numpy().flatten().tolist()
return embedding_list

def embed_images(self, filepaths: list[str]) -> NDArray[np.float32]:
"""Embed images with PE.

Args:
filepaths: A list of file paths to the images to embed.

Returns:
A numpy array representing the generated embeddings
in the same order as the input file paths.
"""
dataset = _ImageFileDataset(filepaths, self._preprocess)

# To avoid issues with db locking and multiprocessing we set the
# number of workers to 0 (no multiprocessing). The DataLoader is still
# very useful for batching and async prefetching of images.
loader = DataLoader(
dataset,
batch_size=MAX_BATCH_SIZE,
num_workers=0, # must be 0 to avoid multiprocessing issues
)
total_images = len(filepaths)
if not total_images:
return np.empty((0, EMBEDDING_DIMENSION), dtype=np.float32)

embeddings = np.empty((total_images, EMBEDDING_DIMENSION), dtype=np.float32)
position = 0
with tqdm(
total=total_images, desc="Generating embeddings", unit=" images"
) as progress_bar, torch.no_grad():
for images_tensor in loader:
imgs = images_tensor.to(self._device, non_blocking=True)
                batch_embeddings = (
                    self._model.encode_image(imgs, normalize=True).cpu().numpy()
                )
batch_size = imgs.size(0)
embeddings[position : position + batch_size] = batch_embeddings
position += batch_size
progress_bar.update(batch_size)

return embeddings
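If it helps review, this is roughly how the generator is meant to be exercised end to end (a sketch with made-up file paths; per the constants above, `embed_images` returns an `(N, 512)` float32 array and `embed_text` a 512-element list):

```python
from lightly_studio.dataset.pe_embedding_generator import PEEmbeddingGenerator

# Downloads and caches the PE-Core-T16-384 checkpoint on first use.
generator = PEEmbeddingGenerator()

# Hypothetical inputs for illustration.
image_embeddings = generator.embed_images(["img/cat.jpg", "img/dog.jpg"])
text_embedding = generator.embed_text("a photo of a cat")

assert image_embeddings.shape == (2, 512)
assert len(text_embedding) == 512

# Both towers return L2-normalized vectors, so cosine similarity
# reduces to a plain dot product.
scores = image_embeddings @ text_embedding
```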
180 changes: 180 additions & 0 deletions lightly_studio/src/lightly_studio/dataset/pe_embedding_generator_onnx.py
@@ -0,0 +1,180 @@
"""PE embedding generator."""

from __future__ import annotations

import copy
from typing import Callable
from uuid import UUID

import fsspec
import numpy as np
import onnxruntime as ort
import torch
from numpy.typing import NDArray
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from lightly_studio.models.embedding_model import EmbeddingModelCreate
from lightly_studio.vendor.pe.vision_encoder import pe, transforms

from .embedding_generator import EmbeddingGenerator

MODEL_NAME = "PE-Core-T16-384"
MAX_BATCH_SIZE: int = 16
EMBEDDING_DIMENSION: int = 512

class PEVision(torch.nn.Module):
    """Vision tower of the PE CLIP model, wrapped for ONNX export."""

    def __init__(self, CLIP_model: pe.CLIP) -> None:
        super().__init__()
        self.model = CLIP_model.visual
        self.image_size = CLIP_model.image_size
        self.logit_scale = CLIP_model.logit_scale

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # L2-normalize so cosine similarity reduces to a dot product.
        return F.normalize(self.model(image), dim=-1)

class PEText(torch.nn.Module):
    """Text tower of the PE CLIP model; the vision tower is dropped before export."""

    def __init__(self, CLIP_model: pe.CLIP) -> None:
        super().__init__()
        # Deep-copy first so removing the vision tower leaves the original intact.
        self.model = copy.deepcopy(CLIP_model)
        delattr(self.model, "visual")

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        return self.model.encode_text(text, normalize=True)

# Dataset for efficient batched image loading and preprocessing
class _ImageFileDataset(Dataset[torch.Tensor]):
"""Dataset wrapping image file paths and a preprocess function."""

def __init__(
self,
filepaths: list[str],
preprocess: Callable[[Image.Image], torch.Tensor],
) -> None:
self.filepaths = filepaths
self.preprocess = preprocess

def __len__(self) -> int:
return len(self.filepaths)

def __getitem__(self, idx: int) -> torch.Tensor:
with fsspec.open(self.filepaths[idx], "rb") as file:
image = Image.open(file).convert("RGB")
return self.preprocess(image).half()


class PEEmbeddingGenerator(EmbeddingGenerator):
"""PE embedding model."""

def __init__(self) -> None:
"""Initialize the PE embedding model.

This method loads the PE model and its tokenizer. The model
checkpoint is downloaded and cached locally for future use.
"""
        CLIP_model = pe.CLIP.from_config(MODEL_NAME, pretrained=True).half()
        model_vision = PEVision(CLIP_model=CLIP_model)
        model_text = PEText(CLIP_model=CLIP_model)

        # Export the vision tower with a dynamic batch dimension.
        example_inputs = (
            torch.randn(
                MAX_BATCH_SIZE, 3, CLIP_model.image_size, CLIP_model.image_size
            ).half(),
        )
        torch.onnx.export(
            model=model_vision,
            args=example_inputs,
            f="visual_PE.onnx",
            input_names=["image"],
            output_names=["embedding"],
            dynamic_axes={"image": {0: "batch_size"}, "embedding": {0: "batch_size"}},
        )
        self._model_vision = ort.InferenceSession("visual_PE.onnx")

        # Export the text tower. randint produces valid dummy token ids,
        # unlike rand().long(), which rounds everything down to zero.
        example_inputs = (
            torch.randint(0, 1000, (1, CLIP_model.context_length), dtype=torch.long),
        )
        torch.onnx.export(
            model=model_text,
            args=example_inputs,
            f="textual_PE.onnx",
            input_names=["text"],
            output_names=["embedding_text"],
            dynamic_axes={
                "text": {0: "batch_size"},
                "embedding_text": {0: "batch_size"},
            },
        )
        self._model_text = ort.InferenceSession("textual_PE.onnx")

        self._preprocess = transforms.get_image_transform(CLIP_model.image_size)
        self._tokenizer = transforms.get_text_tokenizer(CLIP_model.context_length)

        self._model_hash = "abc"  # NOTE: placeholder, not derived from the checkpoint.

def get_embedding_model_input(self, dataset_id: UUID) -> EmbeddingModelCreate:
"""Generate an EmbeddingModelCreate instance.

Args:
dataset_id: The ID of the dataset.

Returns:
An EmbeddingModelCreate instance with the model details.
"""
return EmbeddingModelCreate(
name=MODEL_NAME,
embedding_model_hash=self._model_hash,
embedding_dimension=EMBEDDING_DIMENSION,
dataset_id=dataset_id,
)

def embed_text(self, text: str) -> list[float]:
"""Embed a text with PE.

Args:
text: The text to embed.

Returns:
A list of floats representing the generated embedding.
"""
        tokenized = self._tokenizer([text]).long().cpu().numpy()
        # ONNX Runtime runs outside autograd, so no torch.no_grad() is required.
        embedding = self._model_text.run(None, {"text": tokenized})[0]
        # Convert embedding to list of floats.
        embedding_list: list[float] = embedding.flatten().tolist()
        return embedding_list

def embed_images(self, filepaths: list[str]) -> NDArray[np.float32]:
"""Embed images with PE.

Args:
filepaths: A list of file paths to the images to embed.

Returns:
A numpy array representing the generated embeddings
in the same order as the input file paths.
"""
dataset = _ImageFileDataset(filepaths, self._preprocess)

# To avoid issues with db locking and multiprocessing we set the
# number of workers to 0 (no multiprocessing). The DataLoader is still
# very useful for batching and async prefetching of images.
loader = DataLoader(
dataset,
batch_size=MAX_BATCH_SIZE,
num_workers=0, # must be 0 to avoid multiprocessing issues
)
total_images = len(filepaths)
if not total_images:
return np.empty((0, EMBEDDING_DIMENSION), dtype=np.float32)

        embeddings = np.empty((total_images, EMBEDDING_DIMENSION), dtype=np.float32)
        position = 0
        with tqdm(
            total=total_images, desc="Generating embeddings", unit=" images"
        ) as progress_bar:
            for images_tensor in loader:
                imgs = images_tensor.cpu().numpy()
                batch_embeddings = self._model_vision.run(None, {"image": imgs})[0]
                # Use the actual batch size: the final batch may be smaller than
                # MAX_BATCH_SIZE, which would otherwise write past the valid rows
                # and overcount the progress bar.
                batch_size = imgs.shape[0]
                embeddings[position : position + batch_size] = batch_embeddings
                position += batch_size
                progress_bar.update(batch_size)

return embeddings
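One design note: `__init__` re-exports both ONNX files into the current working directory on every construction. If that turns out to be slow or surprising, a guard like the following could reuse a previous export (a sketch; the skip condition is an assumption, `_export_vision_onnx` is a hypothetical helper wrapping the `torch.onnx.export` call above, and cache invalidation on a MODEL_NAME change is not handled):

```python
from pathlib import Path

import onnxruntime as ort

VISION_ONNX = Path("visual_PE.onnx")  # same filename the PR writes today


def _vision_session() -> ort.InferenceSession:
    # Export only when no cached model exists; otherwise reuse the file.
    if not VISION_ONNX.exists():
        _export_vision_onnx(VISION_ONNX)  # hypothetical helper, see lead-in
    return ort.InferenceSession(str(VISION_ONNX))
```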
4 changes: 4 additions & 0 deletions lightly_studio/src/lightly_studio/examples/example.py
@@ -4,7 +4,11 @@

 import lightly_studio as ls
 from lightly_studio import db_manager
+from lightly_studio.dataset.pe_embedding_generator import (
+    PEEmbeddingGenerator,
+)

+PEEmbeddingGenerator()
 # Read environment variables
 env = Env()
 env.read_env()