feat: add torch.export support to detector_postprocess and Boxes #5361

Open · wants to merge 3 commits into base: main

5 changes: 4 additions & 1 deletion detectron2/modeling/postprocessing.py
@@ -2,6 +2,7 @@
 import torch
 from torch.nn import functional as F

+from detectron2.layers.wrappers import check_if_dynamo_compiling
 from detectron2.structures import Instances, ROIMasks


@@ -55,7 +56,9 @@ def detector_postprocess(
     output_boxes.scale(scale_x, scale_y)
     output_boxes.clip(results.image_size)

-    results = results[output_boxes.nonempty()]
+    if not check_if_dynamo_compiling():
+        # If we're tracing with Dynamo, we can't guard on a data-dependent condition
+        results = results[output_boxes.nonempty()]

     if results.has("pred_masks"):
         if isinstance(results.pred_masks, ROIMasks):
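Note (not part of the diff): the guard above exists because using output_boxes.nonempty() as a boolean index produces a result whose shape depends on tensor values, which Dynamo/torch.export cannot trace through without a data-dependent guard. A minimal, self-contained sketch of the failure mode, using a hypothetical FilterNonEmpty module rather than detectron2 code:

import torch


class FilterNonEmpty(torch.nn.Module):
    """Hypothetical stand-in for results[output_boxes.nonempty()]."""

    def forward(self, boxes: torch.Tensor) -> torch.Tensor:
        # Keep boxes with positive width and height.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        # Boolean-mask indexing: the output shape depends on the *values* in `boxes`.
        return boxes[keep]


boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 5.0, 5.0]])
print(FilterNonEmpty()(boxes).shape)  # torch.Size([1, 4]) in eager mode

# Tracing the same module, e.g. torch.export.export(FilterNonEmpty(), (boxes,)),
# typically fails with a data-dependent / unbacked-symint error, which is why the
# diff above skips the filter when check_if_dynamo_compiling() returns True.
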
14 changes: 12 additions & 2 deletions detectron2/structures/boxes.py
@@ -6,6 +6,8 @@
 import torch
 from torch import device

+from detectron2.layers.wrappers import check_if_dynamo_compiling
+
 _RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]


@@ -147,7 +149,13 @@ def __init__(self, tensor: torch.Tensor):
         if not isinstance(tensor, torch.Tensor):
             tensor = torch.as_tensor(tensor, dtype=torch.float32, device=torch.device("cpu"))
         else:
-            tensor = tensor.to(torch.float32)
+            if tensor.dtype != torch.float32:
+                # If we're tracing with Dynamo, this prevents `tensor` from being mutated.
+                # This if-statement covers at least the case where the recast would be a needless no-op.
+                tensor = tensor.to(torch.float32)
+            # If some use case still fails with "cannot mutate tensors with frozen storage",
+            # we can add `tensor = tensor.clone()` here.
+            # Reference: https://github.com/pytorch/pytorch/issues/127571
         if tensor.numel() == 0:
             # Use reshape, so we don't end up creating a new tensor that does not depend on
             # the inputs (and consequently confuses jit)
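Note (my summary, not text from the PR): the dtype check simply skips the cast when it would be a no-op; the linked PyTorch issue associates that redundant .to(torch.float32) with the "cannot mutate tensors with frozen storage" failure under export. The same pattern in isolation, as a hedged sketch with a hypothetical helper name:

import torch


def to_float32_if_needed(tensor: torch.Tensor) -> torch.Tensor:
    # Cast only when the dtype actually differs, mirroring the guarded cast in
    # Boxes.__init__ above. If a use case still hits the frozen-storage error,
    # the diff's comment suggests `tensor = tensor.clone()` as a fallback.
    if tensor.dtype != torch.float32:
        tensor = tensor.to(torch.float32)
    return tensor


print(to_float32_if_needed(torch.zeros(2, 4, dtype=torch.int64)).dtype)  # torch.float32
print(to_float32_if_needed(torch.zeros(2, 4)).dtype)  # already float32, returned unchanged
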
@@ -188,7 +196,9 @@ def clip(self, box_size: Tuple[int, int]) -> None:
         Args:
             box_size (height, width): The clipping box's size.
         """
-        assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
+        if not check_if_dynamo_compiling():
+            # If we're tracing with Dynamo, we can't guard on a data-dependent expression
+            assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
         h, w = box_size
         x1 = self.tensor[:, 0].clamp(min=0, max=w)
         y1 = self.tensor[:, 1].clamp(min=0, max=h)
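For reference, check_if_dynamo_compiling (imported in both files) wraps PyTorch's own "am I being traced?" check. The sketch below is an assumption about what such a helper typically looks like (using torch.compiler.is_compiling(), not a quote of detectron2.layers.wrappers), plus a standalone version of the guarded clip to show the eager-mode behavior:

import torch
from typing import Tuple


def check_if_dynamo_compiling_sketch() -> bool:
    # Rough equivalent: True while Dynamo traces the code (torch.compile /
    # torch.export), False in ordinary eager execution.
    return torch.compiler.is_compiling()


def clip_sketch(boxes: torch.Tensor, box_size: Tuple[int, int]) -> torch.Tensor:
    # Mirrors Boxes.clip above: run the data-dependent finiteness assert only
    # when executing eagerly, then clamp coordinates to the image size.
    if not check_if_dynamo_compiling_sketch():
        assert torch.isfinite(boxes).all(), "Box tensor contains infinite or NaN!"
    h, w = box_size
    x1 = boxes[:, 0].clamp(min=0, max=w)
    y1 = boxes[:, 1].clamp(min=0, max=h)
    x2 = boxes[:, 2].clamp(min=0, max=w)
    y2 = boxes[:, 3].clamp(min=0, max=h)
    return torch.stack((x1, y1, x2, y2), dim=-1)


print(clip_sketch(torch.tensor([[-5.0, -5.0, 30.0, 30.0]]), (20, 20)))
# tensor([[ 0.,  0., 20., 20.]])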