diff --git a/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py b/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py index 9ae7fbe9c..e0d75941e 100644 --- a/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py +++ b/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py @@ -375,8 +375,10 @@ def __init__(self, try: from yoeo import models as torch_models from yoeo import detect as torch_detect + from yoeo.utils.dataclasses import GroupConfig as torch_GroupConfig self.torch_detect = torch_detect + self.torch_group_config = torch_GroupConfig except ImportError: raise ImportError("Could not import yoeo. The selected handler requires this package.") @@ -385,13 +387,23 @@ def __init__(self, self._conf_thresh: float = config["yoeo_conf_threshold"] self._nms_thresh: float = config["yoeo_nms_threshold"] + self._group_config: torch_GroupConfig = self._update_group_config() logger.debug(f"Leaving {self.__class__.__name__} constructor") + def _update_group_config(self): + robot_class_ids = self.get_robot_class_ids() + + return self.torch_group_config( + group_ids=robot_class_ids, + surrogate_id=robot_class_ids[0] + ) + def configure(self, config: Dict) -> None: super().configure(config) self._conf_thresh = config["yoeo_conf_threshold"] self._nms_thresh = config["yoeo_nms_threshold"] + self._group_config = self._update_group_config() @staticmethod def model_files_exist(model_directory: str) -> bool: @@ -403,7 +415,7 @@ def _compute_new_prediction_for(self, image: np.ndarray) -> Tuple[np.ndarray, np image, conf_thres=self._conf_thresh, nms_thres=self._nms_thresh, - robot_class_ids=self.get_robot_class_ids()) + group_config=self._group_config) segmentation = self._postprocess_segmentation(segmentation)