Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch

from semantic_inference.models.feature_visualizers import *
from semantic_inference.models.instance_segmenter import *
from semantic_inference.models.mask_functions import *
from semantic_inference.models.openset_segmenter import *
from semantic_inference.models.patch_extractor import *
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# BSD 3-Clause License
#
# Copyright (c) 2021-2024, Massachusetts Institute of Technology.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""Model to segment an image and encode segments with CLIP embeddings."""

import dataclasses
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
from spark_config import Config, config_field
from torch import nn


def _map_opt(values, f):
return {k: v if v is None else f(v) for k, v in values.items()}


@dataclass
class Results:
"""Openset Segmentation Results."""

masks: torch.Tensor
boxes: torch.Tensor # bounding boxes for the masks
categories: torch.Tensor
confidences: torch.Tensor

@property
def instances(self):
"""Get instance image (if it exists)."""
if self.masks.shape[0] == 0:
return None

np_masks = self.masks.numpy()
img = np.zeros(np_masks[0].shape, dtype=np.uint16)
for i in range(self.masks.shape[0]):
# instance ids are 1-indexed
img[np_masks[i, ...] > 0] = i + 1

# TODO: 16 + 16 int for instance id and category id

return img

def cpu(self):
"""Move results to CPU."""
values = dataclasses.asdict(self)
return Results(**_map_opt(values, lambda v: v.cpu()))

def to(self, *args, **kwargs):
"""Forward to to all tensors."""
values = dataclasses.asdict(self)
return Results(**_map_opt(values, lambda v: v.to(*args, **kwargs)))


@dataclass
class InstanceSegmenterConfig(Config):
"""Main config for instance segmenter."""

instance_model: Any = config_field("instance_model", default="yolov11")
# relevant configs (model path, model weights) for the model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor) not needed at the moment?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, looking at the yolov11 wrapper, I'd consider adding a minimum confidence score for the detected objects, I'd assume most instance segmenters return some sort of 0-1 confidence score per mask



class InstanceSegmenter(nn.Module):
"""Module to segment and encode an image."""

def __init__(self, config):
"""Construct an instance segmenter."""
super().__init__()
# for detecting model device
self._canary_param = nn.Parameter(torch.empty(0))

self.config = config
self.segmenter = self.config.instance_model.create()

def eval(self):
"""
Override eval to avoid issues with certain models
"""
self.segmenter.eval()

@classmethod
def construct(cls, **kwargs):
"""Load model from configuration dictionary."""
config = InstanceSegmenterConfig()
config.update(kwargs)
return cls(config)

@torch.no_grad()
def segment(self, rgb_img, is_rgb_order=True):
"""
Segment image and compute language embeddings for each mask.
Args:
img (np.ndarry): uint8 image of shape (R, C, 3) in rgb order
is_rgb_order (bool): whether the image is rgb order or not
Returns:
Encoded image
"""
img = rgb_img if is_rgb_order else rgb_img[:, :, ::-1].copy()
return self(img)

@property
def device(self):
"""Get current model device."""
return self._canary_param.device

def forward(self, rgb_img):
"""
Segment image and compute language embeddings for each mask.
Args:
img (np.ndarray): uint8 image of shape (R, C, 3) in rgb order
Returns:
Encoded image
"""
categories, masks, boxes, confidences = self.segmenter(rgb_img)

# img = torch.from_numpy(rgb_img).to(self.device)
# return self.encode(img, masks, boxes)
# TODO: return the results of the actual instance segmentation model here
return Results(
masks=masks, boxes=boxes, categories=categories, confidences=confidences
)
52 changes: 52 additions & 0 deletions semantic_inference/python/semantic_inference/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,55 @@ class OpenClipConfig(Config):
def load(cls, filepath):
"""Load config from file."""
return Config.load(cls, filepath)


class Yolov11InstanceSegmenterWrapper(nn.Module):
"""Yolov11 instance segmentation wrapper."""

def __init__(self, config):
"""Load Yolov11 model."""
super().__init__()
from ultralytics import YOLO

self.config = config
self.model = YOLO(config.model_name)

def eval(self):
"""
override eval to avoid issues with yolo model
"""
self.model.model.eval()

@classmethod
def construct(cls, **kwargs):
"""Load model from configuration dictionary."""
config = Yolov11InstanceSegmenterConfig()
config.update(kwargs)
return cls(config)

def forward(self, img):
"""Segment image."""
result = self.model(img)[0] # assume batch size 1
if result.masks is None:
return None, None, None, None
categories = result.boxes.cls # int8
masks = result.masks.data.to(torch.bool) #
boxes = result.boxes.xyxy # float32
confidences = result.boxes.conf # float32
# assume the instance id is the index in the result?
return categories, masks, boxes, confidences


@register_config(
"instance_model", name="yolov11", constructor=Yolov11InstanceSegmenterWrapper
)
@dataclasses.dataclass
class Yolov11InstanceSegmenterConfig(Config):
"""Configuration for Yolov11 instance segmenter."""

model_name: str = "yolo11n-seg.pt"

@classmethod
def load(cls, filepath):
"""Load config from file."""
return Config.load(cls, filepath)
2 changes: 1 addition & 1 deletion semantic_inference_ros/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ install(
LIBRARY DESTINATION lib
RUNTIME DESTINATION lib/${PROJECT_NAME}
)
install(PROGRAMS app/image_embedding_node app/open_set_node app/text_embedding_node
install(PROGRAMS app/image_embedding_node app/open_set_node app/text_embedding_node app/instance_segmentation_node
DESTINATION lib/${PROJECT_NAME}
)
install(DIRECTORY include/${PROJECT_NAME}/ DESTINATION include/${PROJECT_NAME}/)
Expand Down
Loading