Skip to content

Commit

Permalink
Merge remote-tracking branch 'bitbots_vision/master' into monorepo
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Jan 20, 2024
2 parents f969bbe + f29d061 commit b077830
Showing 1 changed file with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,10 @@ def __init__(
try:
from yoeo import detect as torch_detect
from yoeo import models as torch_models
from yoeo.utils.dataclasses import GroupConfig as torch_GroupConfig

self.torch_detect = torch_detect
self.torch_group_config = torch_GroupConfig
except ImportError as e:
raise ImportError("Could not import yoeo. The selected handler requires this package.") from e

Expand All @@ -396,13 +398,20 @@ def __init__(

self._conf_thresh: float = config["yoeo_conf_threshold"]
self._nms_thresh: float = config["yoeo_nms_threshold"]
self._group_config: torch_GroupConfig = self._update_group_config()

logger.debug(f"Leaving {self.__class__.__name__} constructor")

def _update_group_config(self):
robot_class_ids = self.get_robot_class_ids()

return self.torch_group_config(group_ids=robot_class_ids, surrogate_id=robot_class_ids[0])

def configure(self, config: dict) -> None:
super().configure(config)
self._conf_thresh = config["yoeo_conf_threshold"]
self._nms_thresh = config["yoeo_nms_threshold"]
self._group_config = self._update_group_config()

@staticmethod
def model_files_exist(model_directory: str) -> bool:
Expand All @@ -416,7 +425,7 @@ def _compute_new_prediction_for(self, image: np.ndarray) -> tuple[np.ndarray, np
image,
conf_thres=self._conf_thresh,
nms_thres=self._nms_thresh,
robot_class_ids=self.get_robot_class_ids(),
group_config=self._group_config,
)

segmentation = self._postprocess_segmentation(segmentation)
Expand Down

0 comments on commit b077830

Please sign in to comment.