diff --git a/label_studio_ml/examples/segment_anything_2_video/README.md b/label_studio_ml/examples/segment_anything_2_video/README.md
index d4134e8cd..ed35d984a 100644
--- a/label_studio_ml/examples/segment_anything_2_video/README.md
+++ b/label_studio_ml/examples/segment_anything_2_video/README.md
@@ -44,10 +44,14 @@ pip install -r requirements.txt
 ```
 2. Download [`segment-anything-2` repo](https://github.com/facebookresearch/segment-anything-2) into the root directory. Install SegmentAnything model and download checkpoints using [the official Meta documentation](https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#installation). Make sure that you complete the steps for downloading the checkpoint files!
+   If you install the segment-anything-2 repo in a different directory, you must set the SEGMENT_ANYTHING_2_REPO_PATH environment variable to the path of that directory.
 3. Export the following environment variables (fill them in with your credentials!):
 - LABEL_STUDIO_URL: the http:// or https:// link to your Label Studio instance (include the prefix!)
 - LABEL_STUDIO_API_KEY: your API key for Label Studio, available in your user profile.
+- MAX_FRAMES_TO_TRACK: the maximum number of frames to track in the video each time the model is called.
+- PROMPT_TYPE: the type of prompt to use, either "box" or "point". Both modes start from the rectangle you draw: "box" passes the rectangle to SAM2 directly, while "point" automatically samples five key points inside the rectangle to identify the object.
+- SEGMENT_ANYTHING_2_REPO_PATH: the path to the segment-anything-2 repo. The default is `segment-anything-2`, which means the repo is expected inside this example's directory.
 4. Then you can start the ML backend on the default port `9090`:
@@ -77,9 +81,9 @@ For your project, you can use any labeling config with video properties. Here's
 ## Known limitations
 
 - As of 8/11/2024, SAM2 only runs on GPU servers.
-- Currently, we only support the tracking of one object in video, although SAM2 can support multiple.
 - Currently, we do not support video segmentation.
-- No Docker support
+- Multi-object tracking is enabled, but due to a bug (https://github.com/HumanSignal/label-studio-ml-backend/issues/664), the UI shows the same label for all objects even though the predictions carry different labels.
+- Due to a UI bug (https://github.com/HumanSignal/label-studio/issues/6593), frames may not be displayed properly, so labels can appear misaligned with frames. Under the hood the labels are still applied to the correct frames; only the display is affected.
 
 If you want to contribute to this repository to help with some of these limitations, you can submit a PR.
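For reference, the backend resolves these variables in `model.py` with the defaults sketched below (condensed from the `os.getenv` calls added in this PR; only the defaults are shown, actual values depend on your setup):

```python
import os

# Defaults used by model.py when the variables are not exported (condensed sketch).
LABEL_STUDIO_URL = os.getenv('LABEL_STUDIO_URL')              # no default - must be set
LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY', '')  # needed to read media and update annotations
MAX_FRAMES_TO_TRACK = int(os.getenv('MAX_FRAMES_TO_TRACK', 10))
PROMPT_TYPE = os.getenv('PROMPT_TYPE', 'box')                 # 'box' or 'point'
SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2')
```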
diff --git a/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml b/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
index f73413a27..1f85f4064 100644
--- a/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
+++ b/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
@@ -8,6 +8,13 @@ services:
     build:
       context: .
       args:
         TEST_ENV: ${TEST_ENV}
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [ gpu ]
     environment:
       # specify these parameters if you want to use basic auth for the model server
       - BASIC_AUTH_USER=
@@ -24,9 +31,9 @@ services:
       # specify device
       - DEVICE=cuda # or 'cpu' (coming soon)
       # SAM2 model config
-      - MODEL_CONFIG=sam2_hiera_l.yaml
+      - MODEL_CONFIG=configs/sam2.1/sam2.1_hiera_t.yaml
       # SAM2 checkpoint
-      - MODEL_CHECKPOINT=sam2_hiera_large.pt
+      - MODEL_CHECKPOINT=sam2.1_hiera_tiny.pt
 
       # Specify the Label Studio URL and API key to access
       # uploaded, local storage and cloud storage files.
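The `deploy` block above only reserves an NVIDIA device for the container; to confirm that the backend can actually use it with `DEVICE=cuda`, a quick check inside the container might look like this (a minimal sketch using PyTorch, which the backend already depends on):

```python
import torch

# Verify that the GPU reserved in docker-compose is visible to PyTorch.
if torch.cuda.is_available():
    print(f"CUDA is available: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not visible - check the NVIDIA runtime and the device reservation")
```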
diff --git a/label_studio_ml/examples/segment_anything_2_video/model.py b/label_studio_ml/examples/segment_anything_2_video/model.py
index 2af5b4046..25e268190 100644
--- a/label_studio_ml/examples/segment_anything_2_video/model.py
+++ b/label_studio_ml/examples/segment_anything_2_video/model.py
@@ -1,28 +1,38 @@
-import torch
-import numpy as np
 import os
 import pathlib
-import cv2
 import tempfile
 import logging
+import json
+from typing import List, Dict, Optional, Literal, cast
+import sys
-from typing import List, Dict, Optional
-from uuid import uuid4
+import cv2
+import torch
+import numpy as np
+import requests
 from label_studio_ml.model import LabelStudioMLBase
 from label_studio_ml.response import ModelResponse
 from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path
 from label_studio_sdk.label_interface.objects import PredictionValue
-from PIL import Image
+# from PIL import Image
+from collections import defaultdict
+from label_studio_sdk.client import LabelStudio
+
+# read the environment variables and set the paths just before importing the sam2 module
+SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2')
+sys.path.append(SEGMENT_ANYTHING_2_REPO_PATH)
 from sam2.build_sam import build_sam2, build_sam2_video_predictor
 
 logger = logging.getLogger(__name__)
-
 DEVICE = os.getenv('DEVICE', 'cuda')
-SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2')
-MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2_hiera_l.yaml')
-MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2_hiera_large.pt')
+MODEL_CONFIG = os.getenv('MODEL_CONFIG', './configs/sam2.1/sam2.1_hiera_t.yaml')
+MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_tiny.pt')
 MAX_FRAMES_TO_TRACK = int(os.getenv('MAX_FRAMES_TO_TRACK', 10))
+PROMPT_TYPE = cast(Literal["box", "point"], os.getenv('PROMPT_TYPE', 'box'))
+# environment values are strings, so interpret common truthy spellings instead of relying on truthiness
+ANNOTATION_WORKAROUND = os.getenv('ANNOTATION_WORKAROUND', 'false').lower() in ('1', 'true', 'yes')
+DEBUG = os.getenv('DEBUG', 'false').lower() in ('1', 'true', 'yes')
+LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY', '')
 
 if DEVICE == 'cuda':
     # use bfloat16 for the entire notebook
@@ -36,6 +46,8 @@
 # build path to the model checkpoint
 sam2_checkpoint = str(pathlib.Path(__file__).parent / SEGMENT_ANYTHING_2_REPO_PATH / "checkpoints" / MODEL_CHECKPOINT)
+logger.debug(f'Model checkpoint: {sam2_checkpoint}')
+logger.debug(f'Model config: {MODEL_CONFIG}')
 
 predictor = build_sam2_video_predictor(MODEL_CONFIG, sam2_checkpoint)
 
@@ -51,51 +63,50 @@ def get_inference_state(video_dir):
         _inference_state = predictor.init_state(video_path=video_dir)
     return _inference_state
 
-
 class NewModel(LabelStudioMLBase):
     """Custom ML Backend model
     """
 
     def split_frames(self, video_path, temp_dir, start_frame=0, end_frame=100):
-        # Open the video file
         logger.debug(f'Opening video file: {video_path}')
         video = cv2.VideoCapture(video_path)
+        fps = video.get(cv2.CAP_PROP_FPS)
+        frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
+        logger.debug(f'fps: {fps}, frame_count: {frame_count}')
+        duration = frame_count / fps
+        logger.debug(f'duration: {duration}')
 
-        # check if loaded correctly
         if not video.isOpened():
             raise ValueError(f"Could not open video file: {video_path}")
-        else:
-            # display number of frames
-            logger.debug(f'Number of frames: {int(video.get(cv2.CAP_PROP_FRAME_COUNT))}')
+
+        logger.debug(f'Number of frames: {int(video.get(cv2.CAP_PROP_FRAME_COUNT))}')
 
         frame_count = 0
         while True:
-            # Read a frame from the video
             success, frame = video.read()
-            if frame_count < start_frame:
-                continue
-            if frame_count + start_frame >= end_frame:
-                break
 
-            # If frame is read correctly, success is True
             if not success:
                 logger.error(f'Failed to read frame {frame_count}')
+                # TODO: handle this case properly (observed at frame 57 of the actual test video),
+                # then fix the issue of the same label being shown for objects with different labels
+                break
+
+            if frame_count < start_frame:
+                frame_count += 1
+                continue
+
+            if frame_count >= end_frame:
                 break
 
-            # Generate a filename for the frame using the pattern with frame number: '%05d.jpg'
             frame_filename = os.path.join(temp_dir, f'{frame_count:05d}.jpg')
-            if os.path.exists(frame_filename):
-                logger.debug(f'Frame {frame_count}: {frame_filename} already exists')
-                yield frame_filename, frame
-            else:
-                # Save the frame as an image file
+
+            if not os.path.exists(frame_filename):
                 cv2.imwrite(frame_filename, frame)
-                logger.debug(f'Frame {frame_count}: {frame_filename}')
-                yield frame_filename, frame
+            logger.debug(f'Frame {frame_count}: {frame_filename}')
+            yield frame_filename, frame
 
             frame_count += 1
 
-        # Release the video object
         video.release()
 
     def get_prompts(self, context) -> List[Dict]:
@@ -111,24 +122,30 @@ def get_prompts(self, context) -> List[Dict]:
             box_height = obj['height'] / 100
             frame_idx = obj['frame'] - 1
 
-            # SAM2 video works with keypoints - convert the rectangle to the set of keypoints within the rectangle
-
-            # bbox (x, y) is top-left corner
-            kps = [
-                # center of the bbox
-                [x + box_width / 2, y + box_height / 2],
-                # half of the bbox width to the left
-                [x + box_width / 4, y + box_height / 2],
-                # half of the bbox width to the right
-                [x + 3 * box_width / 4, y + box_height / 2],
-                # half of the bbox height to the top
-                [x + box_width / 2, y + box_height / 4],
-                # half of the bbox height to the bottom
-                [x + box_width / 2, y + 3 * box_height / 4]
-            ]
+            if PROMPT_TYPE == 'point':
+                # SAM2 video works with keypoints - convert the rectangle to a set of keypoints within the rectangle
+                # bbox (x, y) is the top-left corner
+                kps = [
+                    # center of the bbox
+                    [x + box_width / 2, y + box_height / 2],
+                    # half of the bbox width to the left
+                    [x + box_width / 4, y + box_height / 2],
+                    # half of the bbox width to the right
+                    [x + 3 * box_width / 4, y + box_height / 2],
+                    # half of the bbox height to the top
+                    [x + box_width / 2, y + box_height / 4],
+                    # half of the bbox height to the bottom
+                    [x + box_width / 2, y + 3 * box_height / 4]
+                ]
+            elif PROMPT_TYPE == 'box':
+                # SAM2 video also accepts boxes - use the rectangle in xyxy format
+                kps = [x, y, x + box_width, y + box_height]
+            else:
+                raise ValueError(f'Invalid prompt type: {PROMPT_TYPE}')
 
             points = np.array(kps, dtype=np.float32)
-            labels = np.array([1] * len(kps), dtype=np.int32)
+            # labels are not used for box prompts
+            labels = np.array([1] * len(kps), dtype=np.int32) if PROMPT_TYPE == 'point' else None
             prompts.append({
                 'points': points,
                 'labels': labels,
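To make the two prompt modes concrete, here is a small worked example of the conversion above (a sketch with illustrative numbers; at this point the coordinates are still normalized to 0..1 and are only scaled to pixels later in `predict`):

```python
# A Label Studio rectangle: x/y/width/height arrive as percentages of the frame size.
obj = {'x': 10.0, 'y': 20.0, 'width': 40.0, 'height': 20.0, 'frame': 3}

x, y = obj['x'] / 100, obj['y'] / 100            # 0.10, 0.20 (top-left corner)
w, h = obj['width'] / 100, obj['height'] / 100   # 0.40, 0.20

# PROMPT_TYPE == 'box': a single box in xyxy format
box_prompt = [x, y, x + w, y + h]                # [0.10, 0.20, 0.50, 0.40]

# PROMPT_TYPE == 'point': five positive keypoints inside the rectangle
point_prompt = [
    [x + w / 2, y + h / 2],      # center          -> [0.30, 0.30]
    [x + w / 4, y + h / 2],      # left of center  -> [0.20, 0.30]
    [x + 3 * w / 4, y + h / 2],  # right of center -> [0.40, 0.30]
    [x + w / 2, y + h / 4],      # above center    -> [0.30, 0.25]
    [x + w / 2, y + 3 * h / 4],  # below center    -> [0.30, 0.35]
]
```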
@@ -214,10 +231,44 @@ def dump_image_with_mask(self, frame, mask, output_file, obj_id=None, random_col
 
     def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
-        """ Returns the predicted mask for a smart keypoint that has been placed."""
-
-        from_name, to_name, value = self.get_first_tag_occurence('VideoRectangle', 'Video')
+        """
+        Returns the predicted mask for a smart keypoint that has been placed.
+
+        This method processes video annotation tasks and predicts the mask of an object for a given video frame. It uses the Label Studio context and draft data to determine the bounding boxes or keypoints to predict from. The prediction is performed with a video tracking model that processes multiple frames to build a coherent annotation for the target object across a sequence of video frames.
+
+        For multi-object tracking, the drafts must be used instead of the context, because the context only contains the data of the box that was most recently modified.
+
+        The logic is as follows: each time the model is called, the prediction starts from the frame containing the last label of the object that appears earliest in the video. By calling the model repeatedly, the prediction always moves forward.
+
+        Steps involved in the process:
+        1. Extract the relevant data from `tasks` and `context` to determine the prompts for the model.
+        2. Cache the video locally and extract the relevant frames using `split_frames`.
+        3. Use the prompts to guide the model in identifying and tracking the object of interest.
+        4. Generate a mask for each frame where the object is detected and track the object through subsequent frames.
+        5. Propagate the detected objects through the video sequence to refine annotations and maintain consistency.
+        6. Create or update the annotation in Label Studio to provide feedback to the user.
+
+        Args:
+            tasks (List[Dict]): List of tasks that need annotation.
+            context (Optional[Dict]): Additional information about the current annotation context.
+            kwargs: Optional additional arguments.
+
+        Returns:
+            ModelResponse: Response containing predicted annotations for the video frames.
+        """
+        from_name, to_name, value = self.get_first_tag_occurence('VideoRectangle', 'Video')
+        try:
+            drafts = tasks[0]['drafts'][0]
+        except IndexError:
+            logger.error('Drafts not found, using annotations')
+            try:
+                drafts = tasks[0]['annotations'][0]
+            except IndexError:
+                logger.error('Annotations not found, using context')
+                drafts = context
+        if not len(drafts):
+            logger.info('Draft empty, using context')
+            drafts = context
 
         task = tasks[0]
         task_id = task['id']
         # Get the video URL from the task
@@ -228,13 +279,30 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
         logger.debug(f'Video path: {video_path}')
 
         # get prompts from context
-        prompts = self.get_prompts(context)
-        all_obj_ids = set(p['obj_id'] for p in prompts)
+        # prompts = self.get_prompts(context)
+        prompts = self.get_prompts(drafts)
+
+        context_ids = set([ctx['id'] for ctx in context['result']])
+        all_obj_ids = set([p['id'] for p in drafts['result']] +
+                          ([p['id'] for p in tasks[0]['annotations'][0]['result']] if len(tasks[0]['annotations']) else []))
+        if not context_ids.issubset(all_obj_ids):
+            # Raising here because the case where object ids in the context do not match the ids found in the annotations is not supported.
+            # This remains an open issue but is not considered a substantial problem.
+            raise NotImplementedError(f'Context ids {context_ids} not found in drafts result: {all_obj_ids}. '
+                                      f'TODO: merge context and drafts')
+
         # create a map from obj_id to integer
         obj_ids = {obj_id: i for i, obj_id in enumerate(all_obj_ids)}
 
         # find the last frame index
-        first_frame_idx = min(p['frame_idx'] for p in prompts) if prompts else 0
-        last_frame_idx = max(p['frame_idx'] for p in prompts) if prompts else 0
+        # if there is only one object, use its last prompted frame: continue tracking from the last tracked frame
+        # if there are multiple objects, start from the earliest of the per-object latest prompt frames
+        if len(all_obj_ids) == 1:
+            first_frame_idx = min(p['frame_idx'] for p in prompts) if prompts else 0
+            last_frame_idx = max(p['frame_idx'] for p in prompts) if prompts else 0
+        else:
+            first_frame_idx = min(p['frame_idx'] for p in prompts) if prompts else 0
+            # the minimum of the maximum frame_idx of all objects grouped by id
+            last_frame_idx = min(max(p['frame_idx'] for p in prompts if p['obj_id'] == obj_id) for obj_id in all_obj_ids)
 
         frames_count, duration = self._get_fps(context)
         fps = frames_count / duration
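The multi-object branch above is easier to read with concrete numbers; the sketch below (illustrative values only) shows why propagation starts at the earliest of the per-object latest prompt frames:

```python
# Prompts for two objects (frame_idx is 0-based, as produced by get_prompts).
prompts = [
    {'obj_id': 'a', 'frame_idx': 0},
    {'obj_id': 'a', 'frame_idx': 12},   # object 'a' last prompted at frame 12
    {'obj_id': 'b', 'frame_idx': 5},    # object 'b' last prompted at frame 5
]
all_obj_ids = {'a', 'b'}

first_frame_idx = min(p['frame_idx'] for p in prompts)        # 0
last_frame_idx = min(                                         # 5
    max(p['frame_idx'] for p in prompts if p['obj_id'] == obj_id)
    for obj_id in all_obj_ids
)
# Tracking resumes at frame 5, the latest frame where *every* object already
# has a prompt, so no object is propagated without guidance.
```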
@@ -244,7 +312,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
                      f'last frame index: {last_frame_idx}, '
                      f'obj_ids: {obj_ids}')
 
-        frames_to_track = MAX_FRAMES_TO_TRACK
+        frames_to_track = min(MAX_FRAMES_TO_TRACK, frames_count - last_frame_idx)
 
         # Split the video into frames
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -257,7 +325,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
             frames = list(self.split_frames(
                 video_path, temp_dir,
                 start_frame=first_frame_idx,
-                end_frame=last_frame_idx + frames_to_track + 1
+                end_frame=last_frame_idx + frames_to_track
             ))
             height, width, _ = frames[0][1].shape
             logger.debug(f'Video width={width}, height={height}')
@@ -266,24 +334,54 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
             inference_state = get_inference_state(temp_dir)
             predictor.reset_state(inference_state)
 
-            for prompt in prompts:
-                # multiply points by the frame size
-                prompt['points'][:, 0] *= width
-                prompt['points'][:, 1] *= height
-
-                _, out_obj_ids, out_mask_logits = predictor.add_new_points(
-                    inference_state=inference_state,
-                    frame_idx=prompt['frame_idx'],
-                    obj_id=obj_ids[prompt['obj_id']],
-                    points=prompt['points'],
-                    labels=prompt['labels']
-                )
+            # Group prompts by 'obj_id'
+            prompt_id_dict = defaultdict(list)
+            for prompt in prompts:
+                prompt_id_dict[prompt['obj_id']].append(prompt)
+
+            # For each object id, find the highest prompted frame index
+            highest_frames = [max(p['frame_idx'] for p in obj_prompts)
+                              for obj_prompts in prompt_id_dict.values() if obj_prompts]
 
-            sequence = []
+            # Get the minimum of the highest frame indices
+            prompt_idx = min(highest_frames) if highest_frames else None
 
-            debug_dir = './debug-frames'
-            os.makedirs(debug_dir, exist_ok=True)
+            for prompt in prompts:
+                frame_idx = prompt['frame_idx'] - first_frame_idx
+                # SAM2 does not propagate to frames that precede a later prompt, so every object must be
+                # prompted at the same (earliest common) frame; skip prompts that come after it
+                if frame_idx > prompt_idx:
+                    logger.warning(f'Prompt at frame {frame_idx} comes after the common prompt frame {prompt_idx}, skipping')
+                    continue
+
+                if PROMPT_TYPE == 'point':
+                    # multiply points by the frame size
+                    prompt['points'][:, 0] *= width
+                    prompt['points'][:, 1] *= height
+                    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+                        inference_state=inference_state,
+                        frame_idx=frame_idx,
+                        obj_id=obj_ids[prompt['obj_id']],
+                        points=prompt['points'],
+                        labels=prompt['labels']
+                    )
+                elif PROMPT_TYPE == 'box':
+                    # scale the normalized xyxy box to pixel coordinates
+                    prompt['points'][0] *= width
+                    prompt['points'][1] *= height
+                    prompt['points'][2] *= width
+                    prompt['points'][3] *= height
+                    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+                        inference_state=inference_state,
+                        frame_idx=frame_idx,
+                        obj_id=obj_ids[prompt['obj_id']],
+                        box=prompt['points'],
+                    )
+
+            if DEBUG:
+                debug_dir = './debug-frames'
+                os.makedirs(debug_dir, exist_ok=True)
+
+            sequences = dict()
 
             logger.info(f'Propagating in video from frame {last_frame_idx} to {last_frame_idx + frames_to_track}')
             for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
                 inference_state=inference_state,
@@ -294,12 +392,16 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
                 for i, out_obj_id in enumerate(out_obj_ids):
                     mask = (out_mask_logits[i] > 0.0).cpu().numpy()
 
-                    # to debug, save the mask as an image
-                    # self.dump_image_with_mask(frames[out_frame_idx][1], mask, f'{debug_dir}/{out_frame_idx:05d}_{out_obj_id}.jpg', obj_id=out_obj_id, random_color=True)
+                    if DEBUG:
+                        # save the mask as an image for debugging
+                        self.dump_image_with_mask(frames[out_frame_idx][1], mask, f'{debug_dir}/{out_frame_idx:05d}_{out_obj_id}.jpg', obj_id=out_obj_id, random_color=True)
 
                     bbox = self.convert_mask_to_bbox(mask)
                     if bbox:
-                        sequence.append({
+                        obj_id = next((k for k, v in obj_ids.items() if v == out_obj_id), None)
+                        sequences[obj_id] = sequences.get(obj_id, [])
+                        sequences[obj_id].append({
                             'frame': real_frame_idx + 1,
                             # 'x': bbox['x'] / width * 100,
                             # 'y': bbox['y'] / height * 100,
@@ -313,24 +415,71 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
                             'rotation': 0,
                             'time': out_frame_idx / fps
                         })
 
-            context_result_sequence = context['result'][0]['value']['sequence']
-
-            prediction = PredictionValue(
-                result=[{
+            result = []
+            for obj_id in all_obj_ids:
+                # skip objects that produced no mask in the tracked range (avoids a KeyError below)
+                if obj_id not in sequences:
+                    continue
+                # find the context to use by searching the drafts by obj_id
+                context_result_sequence = next((ctx['value']['sequence'] for ctx in drafts["result"] if ctx['id'] == obj_id), [])
+                # keep the old sequence only for the frames before the first frame of the new sequence
+                # and after the last frame of the new sequence
+                new_sequence = [s for s in context_result_sequence if s['frame'] < sequences[obj_id][0]['frame']] + \
+                               sequences[obj_id] + \
+                               [s for s in context_result_sequence if s['frame'] >= sequences[obj_id][-1]['frame']]
+                # keep the old labels: take them from the context if present, otherwise from the drafts
+                labels = next((ctx['value'].get('labels', None) for ctx in context["result"] if ctx['id'] == obj_id), None) or \
+                         next((ctx['value'].get('labels', None) for ctx in drafts["result"] if ctx['id'] == obj_id), None)
+                result.append({
                     'value': {
                         'framesCount': frames_count,
                         'duration': duration,
-                        'sequence': context_result_sequence + sequence,
+                        'sequence': new_sequence,
+                        'labels': labels if labels else []
                     },
                     'from_name': 'box',
                     'to_name': 'video',
                     'type': 'videorectangle',
                     'origin': 'manual',
-                    # TODO: current limitation is tracking only one object
-                    'id': list(all_obj_ids)[0]
-                }]
+                    'id': obj_id
+                })
+
+            prediction = PredictionValue(
+                model_version=MODEL_CHECKPOINT,
+                score=1.0,
+                result=result
             )
             logger.debug(f'Prediction: {prediction.model_dump()}')
 
+            if DEBUG:
+                with open('prediction.json', 'w') as f:
+                    json.dump(prediction.model_dump(), f)
+
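The sequence merge in the loop above splices the freshly tracked frames into whatever the user already had for that object; a small worked example of the same slicing (illustrative frame numbers only):

```python
# Boxes the user already had for this object (one dict per frame).
old = [{'frame': f} for f in (1, 2, 3, 4, 5, 20, 21)]
# Boxes produced by this tracking run.
new = [{'frame': f} for f in (4, 5, 6, 7, 8, 9, 10)]

merged = (
    [s for s in old if s['frame'] < new[0]['frame']]      # frames 1-3 kept as-is
    + new                                                 # frames 4-10 replaced by tracking
    + [s for s in old if s['frame'] >= new[-1]['frame']]  # frames 20-21 kept as-is
)
print([s['frame'] for s in merged])  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 21]
```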
+            if ANNOTATION_WORKAROUND:
+                # This is a workaround that updates the annotation through the Label Studio API, because returning
+                # the result only in the model response makes the UI show every object with the same label,
+                # even when the labels differ per object.
+                client = LabelStudio(
+                    api_key=LABEL_STUDIO_API_KEY,
+                )
+                if len(tasks[0]['annotations']) == 0:
+                    logger.debug('Creating new annotation')
+                    ann = client.annotations.create(
+                        id=task_id,
+                        result=result,
+                        task=tasks[0]['id'],
+                        project=tasks[0]['project']
+                    )
+                    client.annotations.get(id=ann.id)
+                else:
+                    logger.debug(f'Updating annotation: {tasks[0]["annotations"][0]["id"]}')
+                    ann = client.annotations.update(
+                        id=tasks[0]['annotations'][0]['id'],
+                        result=result,
+                        task=task_id,
+                        project=tasks[0]['project']
+                    )  # without this update, the UI always shows every object with the same label
+                # convert the annotation back to a draft with a POST request to /api/annotations/{id}/convert-to-draft
+                url = f'{os.getenv("LABEL_STUDIO_URL")}/api/annotations/{ann.id}/convert-to-draft'
+                headers = {
+                    'Authorization': f'Token {os.getenv("LABEL_STUDIO_API_KEY")}'
+                }
+                response = requests.post(url, headers=headers)
+                response.raise_for_status()  # surface HTTP errors instead of failing silently
 
         return ModelResponse(predictions=[prediction])
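For reference, each entry appended to `result` (and therefore each region in the returned prediction and in the annotation created by the workaround) has roughly the following shape. The values are illustrative, and the x/y/width/height keys of the sequence items are filled in by code that is unchanged in this diff:

```python
# Illustrative example of a single 'videorectangle' region produced by predict().
region = {
    'value': {
        'framesCount': 300,        # total number of frames in the video
        'duration': 10.0,          # video duration in seconds
        'sequence': [
            # one entry per tracked frame; x/y/width/height are percentages of the frame size
            {'frame': 6, 'x': 12.5, 'y': 20.0, 'width': 40.0, 'height': 18.0,
             'rotation': 0, 'time': 0.2},
        ],
        'labels': ['Car'],         # carried over from the draft or the existing annotation
    },
    'from_name': 'box',
    'to_name': 'video',
    'type': 'videorectangle',
    'origin': 'manual',
    'id': 'kcUij9vYAZ',            # the Label Studio region id of the tracked object
}
```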