diff --git a/label_anything/readme.md b/label_anything/readme.md index 91a660ee..754f4609 100644 --- a/label_anything/readme.md +++ b/label_anything/readme.md @@ -378,3 +378,33 @@ When finished, we can get the model test visualization. On the left is the annot With the semi-automated annotation function of Label-Studio, users can complete object segmentation and detection by simply clicking the mouse during the annotation process, greatly improving the efficiency of annotation. Some of the code was borrowed from Pull Request ID 253 of label-studio-ml-backend. Thank you to the author for their contribution. Also, thanks to fellow community member [ATang0729](https://github.com/ATang0729) for re-labeling the meow dataset for script testing, and [JimmyMa99](https://github.com/JimmyMa99) for the conversion script, config template, and documentation Optimization. + +## (beta)🚀 SAM backend inference using onnx runtime🚀 (optional) + +We use onnx runtime for SAM back-end inference to improve the speed of SAM inference, tested on a 3090, which takes 4.6s with pytorch and 0.24s with onnx runtime. + +First download the converted onnx from huggingface. + +```shell +cd path/to/playground/label_anything +wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx +wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx +``` + +Then turn on back-end reasoning. + +```shell +cd path/to/playground/label_anything + +label-studio-ml start sam --port 8003 --with \ +sam_config=vit_b \ +sam_checkpoint_file=. /sam_vit_b_01ec64.pth \ +out_mask=True \ +out_bbox=True \ +device=cuda:0 \ +onnx=True \ +# device=cuda:0 for GPU inference, if cpu inference is used, replace cuda:0 with cpu +# out_poly=True returns the annotation of the external polygon +``` + +⚠ Currently only sam_vit_b is supported. diff --git a/label_anything/readme_zh.md b/label_anything/readme_zh.md index d27f8a2a..b667f7f5 100644 --- a/label_anything/readme_zh.md +++ b/label_anything/readme_zh.md @@ -384,5 +384,35 @@ python tools/test.py data/my_set/mask-rcnn_r50_fpn.py path/of/your/checkpoint -- 到此半自动化标注就完成了, 通过 Label-Studio 的半自动化标注功能,可以让用户在标注过程中,通过点击一下鼠标,就可以完成目标的分割和检测,大大提高了标注效率。部分代码借鉴自 label-studio-ml-backend ID 为 253 的 Pull Request,感谢作者的贡献。同时感谢社区同学 [ATang0729](https://github.com/ATang0729) 为脚本测试重新标注了喵喵数据集,以及 [JimmyMa99](https://github.com/JimmyMa99) 同学提供的转换脚本、 config 模板以及文档优化。 +## (测试阶段)🚀使用 onnx runtime 进行 SAM 后端推理🚀(可选) + +我们使用 onnx runtime 进行 SAM 后端推理以提升 SAM 的推理速度,在一张 3090 上测试,使用 pytorch 需要 4.6s ,使用 onnx runtime 只要 0.24s。 + +首先下载 huggingface 上转换好的 onnx。 + +```shell +cd path/to/playground/label_anything +wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx +wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx +#其他版本可以在 https://github.com/vietanhdev/anylabeling-assets/releases/tag/v0.2.0 下载 +``` + +接着开启后端推理。 + +```shell +cd path/to/playground/label_anything + +label-studio-ml start sam --port 8003 --with \ +out_mask=True \ +out_bbox=True \ +device=cuda:0 \ +onnx=True \ +onnx_encoder_file='encoder.onnx' \ +onnx_decoder_file='decoder.onnx' +# device=cuda:0 为使用 GPU 推理,如果使用 cpu 推理,将 cuda:0 替换为 cpu +# out_poly=True 返回外接多边形的标注 +``` + + diff --git a/label_anything/sam/mmdetection.py b/label_anything/sam/mmdetection.py index b7149a78..708fb54c 100644 --- a/label_anything/sam/mmdetection.py +++ b/label_anything/sam/mmdetection.py @@ -7,6 +7,7 @@ import numpy as np from label_studio_converter import brush import torch +from torch.nn import functional as F import cv2 @@ -19,10 +20,45 @@ # from mmdet.apis import inference_detector, init_detector from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator +from segment_anything.utils.transforms import ResizeLongestSide import random import string logger = logging.getLogger(__name__) + +import onnxruntime +import time + +def load_my_onnx(encoder_model_abs_path,decoder_model_abs_path): + # !wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx + # !wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx + # if onnx_config == 'vit_b': + # encoder_model_abs_path = "models/segment_anything_vit_b_encoder_quant.onnx" + # decoder_model_abs_path = "models/segment_anything_vit_b_decoder_quant.onnx" + # elif onnx_config == 'vit_l': + # encoder_model_abs_path = "models/segment_anything_vit_l_encoder_quant.onnx" + # decoder_model_abs_path = "models/segment_anything_vit_l_decoder_quant.onnx" + # elif onnx_config == 'vit_h': + # encoder_model_abs_path = "models/segment_anything_vit_h_encoder_quant.onnx" + # decoder_model_abs_path = "models/segment_anything_vit_h_decoder_quant.onnx" + + providers = onnxruntime.get_available_providers() + if providers: + logging.info( + "Available providers for ONNXRuntime: %s", ", ".join(providers) + ) + else: + logging.warning("No available providers for ONNXRuntime") + encoder_session = onnxruntime.InferenceSession( + encoder_model_abs_path, providers=providers + ) + decoder_session = onnxruntime.InferenceSession( + decoder_model_abs_path, providers=providers + ) + + return encoder_session,decoder_session + + def load_my_model(device="cuda:0",sam_config="vit_b",sam_checkpoint_file="sam_vit_b_01ec64.pth"): """ Loads the Segment Anything model on initializing Label studio, so if you call it outside MyModel it doesn't load every time you try to make a prediction @@ -50,11 +86,18 @@ def __init__(self, out_poly=False, score_threshold=0.5, device='cpu', + onnx=False, + onnx_encoder_file=None, + onnx_decoder_file=None, **kwargs): super(MMDetection, self).__init__(**kwargs) - PREDICTOR=load_my_model(device,sam_config,sam_checkpoint_file) + self.onnx=onnx + if self.onnx: + PREDICTOR=load_my_onnx(onnx_encoder_file,onnx_decoder_file) + else: + PREDICTOR=load_my_model(device,sam_config) self.PREDICTOR = PREDICTOR self.out_mask = out_mask @@ -132,6 +175,79 @@ def __init__(self, # self.model = init_detector(config_file, checkpoint_file, device=device) self.score_thresh = score_threshold +#################################################################################################### + + def pre_process(self, image): + image_size = 1024 + transform = ResizeLongestSide(image_size) + + input_image = transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device="cpu") + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + x = (input_image_torch - pixel_mean) / pixel_std + h, w = x.shape[-2:] + padh = image_size - h + padw = image_size - w + x = F.pad(x, (0, padw, 0, padh)) + x = x.numpy() + + encoder_inputs = { + "x": x, + } + return encoder_inputs, image.shape[:2] + + def run_encoder(self, encoder_inputs): + output = self.encoder_session.run(None, encoder_inputs) + image_embedding = output[0] + return image_embedding + + def run_decoder( + self, image_embedding, input_prompt,img_size): + (original_height,original_width)=img_size + points=input_prompt['points'] + masks=input_prompt['mask'] + boxes=input_prompt['boxes'] + labels=input_prompt['label'] + + image_size = 1024 + transform = ResizeLongestSide(image_size) + if boxes is not None: + onnx_box_coords = boxes.reshape(2, 2) + input_labels = np.array([2,3]) + + onnx_coord = np.concatenate([onnx_box_coords, np.array([[0.0, 0.0]])], axis=0)[None, :, :] + onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[None, :].astype(np.float32) + elif points is not None: + input_point=points + input_label = np.array([1]) + onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] + onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) + + onnx_coord = transform.apply_coords(onnx_coord, img_size).astype(np.float32) + + onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) + onnx_has_mask_input = np.zeros(1, dtype=np.float32) + + + decoder_inputs = { + "image_embeddings": image_embedding, + "point_coords": onnx_coord, + "point_labels": onnx_label, + "mask_input": onnx_mask_input, + "has_mask_input": onnx_has_mask_input, + "orig_im_size": np.array( + img_size, dtype=np.float32 + ), + } + masks, _, _ = self.decoder_session.run(None, decoder_inputs) + # masks = masks[0, 0, :, :] # Only get 1 mask + masks = masks > 0.0 + # masks = masks.reshape(img_size) + return masks +########################################################################################## + def _get_image_url(self, task): image_url = task['data'].get( self.value) or task['data'].get(DATA_UNDEFINED_NAME) @@ -155,9 +271,8 @@ def _get_image_url(self, task): return image_url def predict(self, tasks, **kwargs): - - predictor = self.PREDICTOR - + #共用区域 + start = time.time() results = [] assert len(tasks) == 1 task = tasks[0] @@ -170,54 +285,99 @@ def predict(self, tasks, **kwargs): # image = cv2.imread(f"./{split}") image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - predictor.set_image(image) - prompt_type = kwargs['context']['result'][0]['type'] original_height = kwargs['context']['result'][0]['original_height'] original_width = kwargs['context']['result'][0]['original_width'] + ############################################# + if self.onnx: + self.encoder_session,self.decoder_session=self.PREDICTOR + encoder_inputs,_ = self.pre_process(image) + + input_prompt={} + + input_prompt['boxes']=input_prompt['mask']=input_prompt['points']=input_prompt['label']=None + if prompt_type == 'keypointlabels': + # getting x and y coordinates of the keypoint + x = kwargs['context']['result'][0]['value']['x'] * original_width / 100 + y = kwargs['context']['result'][0]['value']['y'] * original_height / 100 + output_label = kwargs['context']['result'][0]['value']['labels'][0] + + input_prompt['points']=np.array([[x, y]]) + input_prompt['label']=np.array([1]) + + + if prompt_type == 'rectanglelabels': + + x = kwargs['context']['result'][0]['value']['x'] * original_width / 100 + y = kwargs['context']['result'][0]['value']['y'] * original_height / 100 + w = kwargs['context']['result'][0]['value']['width'] * original_width / 100 + h = kwargs['context']['result'][0]['value']['height'] * original_height / 100 + + output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0] + + + input_prompt['boxes']=np.array([x, y, x+w, y+h]) + + input_prompt['label'] = np.array([2,3]) + + + #encoder + image_embedding = self.run_encoder(encoder_inputs) + masks = self.run_decoder(image_embedding,input_prompt,\ + (original_height,original_width)) + masks = masks[0].astype(np.uint8) + # mask = masks.astype(np.uint8) + # shapes = self.post_process(masks, resized_ratio) + else: + predictor = self.PREDICTOR - if prompt_type == 'keypointlabels': - # getting x and y coordinates of the keypoint - x = kwargs['context']['result'][0]['value']['x'] * original_width / 100 - y = kwargs['context']['result'][0]['value']['y'] * original_height / 100 - output_label = kwargs['context']['result'][0]['value']['labels'][0] - + predictor.set_image(image) + - masks, scores, logits = predictor.predict( - point_coords=np.array([[x, y]]), - # box=np.array([x.cpu() for x in bbox[:4]]), - point_labels=np.array([1]), - multimask_output=False, - ) - if prompt_type == 'rectanglelabels': + if prompt_type == 'keypointlabels': + # getting x and y coordinates of the keypoint + x = kwargs['context']['result'][0]['value']['x'] * original_width / 100 + y = kwargs['context']['result'][0]['value']['y'] * original_height / 100 + output_label = kwargs['context']['result'][0]['value']['labels'][0] - x = kwargs['context']['result'][0]['value']['x'] * original_width / 100 - y = kwargs['context']['result'][0]['value']['y'] * original_height / 100 - w = kwargs['context']['result'][0]['value']['width'] * original_width / 100 - h = kwargs['context']['result'][0]['value']['height'] * original_height / 100 + masks, scores, logits = predictor.predict( + point_coords=np.array([[x, y]]), + # box=np.array([x.cpu() for x in bbox[:4]]), + point_labels=np.array([1]), + multimask_output=False, + ) - output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0] - masks, scores, logits = predictor.predict( - # point_coords=np.array([[x, y]]), - box=np.array([x, y, x+w, y+h]), - point_labels=np.array([1]), - multimask_output=False, - ) + if prompt_type == 'rectanglelabels': + x = kwargs['context']['result'][0]['value']['x'] * original_width / 100 + y = kwargs['context']['result'][0]['value']['y'] * original_height / 100 + w = kwargs['context']['result'][0]['value']['width'] * original_width / 100 + h = kwargs['context']['result'][0]['value']['height'] * original_height / 100 - mask = masks[0].astype(np.uint8) # each mask has shape [H, W] - # converting the mask from the model to RLE format which is usable in Label Studio + output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0] - # 找到轮廓 - contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + masks, scores, logits = predictor.predict( + # point_coords=np.array([[x, y]]), + box=np.array([x, y, x+w, y+h]), + point_labels=np.array([1]), + multimask_output=False, + ) + + # 找到轮廓 + mask = masks[0].astype(np.uint8) # each mask has shape [H, W] + # converting the mask from the model to RLE format which is usable in Label Studio + contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + end = time.time() + print(end-start) +######################## # 计算外接矩形