From b4db42186445f455be9326bc3a997f68d1be1ef7 Mon Sep 17 00:00:00 2001 From: Diego Canez Date: Wed, 4 Sep 2024 20:05:53 +0200 Subject: [PATCH] feat: add torch.export support to detector_postprocess and Boxes --- detectron2/modeling/postprocessing.py | 5 ++++- detectron2/structures/boxes.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/detectron2/modeling/postprocessing.py b/detectron2/modeling/postprocessing.py index 84512606a4..910ce02072 100644 --- a/detectron2/modeling/postprocessing.py +++ b/detectron2/modeling/postprocessing.py @@ -2,6 +2,7 @@ import torch from torch.nn import functional as F +from detectron2.layers.wrappers import check_if_dynamo_compiling from detectron2.structures import Instances, ROIMasks @@ -55,7 +56,9 @@ def detector_postprocess( output_boxes.scale(scale_x, scale_y) output_boxes.clip(results.image_size) - results = results[output_boxes.nonempty()] + if not check_if_dynamo_compiling(): + # If we're tracing with Dynamo, we can't guard on a data-dependent condition + results = results[output_boxes.nonempty()] if results.has("pred_masks"): if isinstance(results.pred_masks, ROIMasks): diff --git a/detectron2/structures/boxes.py b/detectron2/structures/boxes.py index fd396f6864..eea80fcf8c 100644 --- a/detectron2/structures/boxes.py +++ b/detectron2/structures/boxes.py @@ -6,6 +6,8 @@ import torch from torch import device +from detectron2.layers.wrappers import check_if_dynamo_compiling + _RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray] @@ -147,7 +149,13 @@ def __init__(self, tensor: torch.Tensor): if not isinstance(tensor, torch.Tensor): tensor = torch.as_tensor(tensor, dtype=torch.float32, device=torch.device("cpu")) else: - tensor = tensor.to(torch.float32) + if tensor.dtype != torch.float32: + # If we're tracing with Dynamo, this will prevent `tensor` to be mutated. + # This if-statement works at least for the case where needless recasting is done. + tensor = tensor.to(torch.float32) + # If some use case still fails with 'cannot mutate tensors with frozen storage' + # we can add `tensor = tensor.clone()` here. + # Reference: https://github.com/pytorch/pytorch/issues/127571 if tensor.numel() == 0: # Use reshape, so we don't end up creating a new tensor that does not depend on # the inputs (and consequently confuses jit) @@ -188,7 +196,9 @@ def clip(self, box_size: Tuple[int, int]) -> None: Args: box_size (height, width): The clipping box's size. """ - assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!" + if not check_if_dynamo_compiling(): + # If we're tracing with Dynamo we can't guard on a data-dependent expression + assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!" h, w = box_size x1 = self.tensor[:, 0].clamp(min=0, max=w) y1 = self.tensor[:, 1].clamp(min=0, max=h)