Skip to content

Commit

Permalink
adapted pytorch handler to changes in yoeo api (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaagut authored Jan 20, 2024
2 parents aaba145 + d4ebd68 commit f29d061
Showing 1 changed file with 13 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,10 @@ def __init__(self,
try:
from yoeo import models as torch_models
from yoeo import detect as torch_detect
from yoeo.utils.dataclasses import GroupConfig as torch_GroupConfig

self.torch_detect = torch_detect
self.torch_group_config = torch_GroupConfig
except ImportError:
raise ImportError("Could not import yoeo. The selected handler requires this package.")

Expand All @@ -385,13 +387,23 @@ def __init__(self,

self._conf_thresh: float = config["yoeo_conf_threshold"]
self._nms_thresh: float = config["yoeo_nms_threshold"]
self._group_config: torch_GroupConfig = self._update_group_config()

logger.debug(f"Leaving {self.__class__.__name__} constructor")

def _update_group_config(self):
robot_class_ids = self.get_robot_class_ids()

return self.torch_group_config(
group_ids=robot_class_ids,
surrogate_id=robot_class_ids[0]
)

def configure(self, config: Dict) -> None:
super().configure(config)
self._conf_thresh = config["yoeo_conf_threshold"]
self._nms_thresh = config["yoeo_nms_threshold"]
self._group_config = self._update_group_config()

@staticmethod
def model_files_exist(model_directory: str) -> bool:
Expand All @@ -403,7 +415,7 @@ def _compute_new_prediction_for(self, image: np.ndarray) -> Tuple[np.ndarray, np
image,
conf_thres=self._conf_thresh,
nms_thres=self._nms_thresh,
robot_class_ids=self.get_robot_class_ids())
group_config=self._group_config)

segmentation = self._postprocess_segmentation(segmentation)

Expand Down

0 comments on commit f29d061

Please sign in to comment.