DETR XAI #4184

Merged · 7 commits · Jan 23, 2025

Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4017>)
- Add D-Fine Detection Algorithm
(<https://github.com/openvinotoolkit/training_extensions/pull/4142>)
- Add DETR XAI Explain Mode
(<https://github.com/openvinotoolkit/training_extensions/pull/4184>)

### Enhancements

4 changes: 3 additions & 1 deletion docs/source/guide/tutorials/base/explain.rst
@@ -32,6 +32,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho

(otx) ...$ otx explain --work_dir otx-workspace \
--dump True # Whether to save saliency map images or not
--explain_config.postprocess True # Resizes and applies colormap to the saliency map

.. tab-item:: CLI (with config)

@@ -41,6 +42,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho
--data_root data/wgisd \
--checkpoint otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt \
--dump True # Whether to save saliency map images or not
--explain_config.postprocess True # Resizes and applies colormap to the saliency map

.. tab-item:: API

@@ -49,7 +51,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho
engine.explain(
checkpoint="<checkpoint-path>",
datamodule=OTXDataModule(...), # The data module to use for predictions
explain_config=ExplainConfig(postprocess=True),
explain_config=ExplainConfig(postprocess=True), # Resizes and applies colormap to the saliency map
dump=True # Whether to save saliency map images or not
)

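A minimal end-to-end sketch of the API path above, including saving the returned maps to disk, could look like the following. The `Engine` constructor arguments, the `ExplainConfig` import path, and the per-image `saliency_map` layout are assumptions for illustration rather than guarantees of the tutorial.

```python
import cv2
import numpy as np

from otx.engine import Engine
from otx.core.config.explain import ExplainConfig  # import path assumed

# Workspace, data root and checkpoint reuse the tutorial placeholders.
engine = Engine(work_dir="otx-workspace", data_root="data/wgisd")

predictions = engine.explain(
    checkpoint="otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt",
    explain_config=ExplainConfig(postprocess=True),  # maps come back resized and colormapped
    dump=False,  # save manually below instead of using the built-in dump
)

# Assumption: explain() yields per-image predictions, each carrying a stack of
# per-class saliency images after postprocessing.
for img_idx, pred in enumerate(predictions):
    for cls_idx, cls_map in enumerate(np.asarray(pred.saliency_map)):
        cv2.imwrite(f"saliency_img{img_idx}_cls{cls_idx}.png", cls_map)
```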
56 changes: 56 additions & 0 deletions src/otx/algo/detection/d_fine.py
@@ -157,6 +157,9 @@ def _customize_inputs(
)
targets.append({"boxes": scaled_bboxes, "labels": ll})

if self.explain_mode:
return {"entity": entity}

return {
"images": entity.images,
"targets": targets,
@@ -185,6 +188,33 @@ def _customize_outputs(
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

if self.explain_mode:
if not isinstance(outputs, dict):
msg = f"Model output should be a dict, but got {type(outputs)}."
raise ValueError(msg)

if "feature_vector" not in outputs:
msg = "No feature vector in the model output."
raise ValueError(msg)

if "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)

saliency_map = outputs["saliency_map"].detach().cpu().numpy()
feature_vector = outputs["feature_vector"].detach().cpu().numpy()

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
labels=labels,
feature_vector=feature_vector,
saliency_map=saliency_map,
)

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
@@ -306,3 +336,29 @@ def _optimization_config(self) -> dict[str, Any]:
},
},
}

@staticmethod
def _forward_explain_detection(
self, # noqa: ANN001
entity: DetBatchDataEntity,
mode: str = "tensor", # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward function for explainable detection model."""
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

raw_logits = DETR.split_and_reshape_logits(
backbone_feats,
predictions["raw_logits"],
)

saliency_map = self.explain_fn(raw_logits)
feature_vector = self.feature_vector_fn(backbone_feats)
predictions.update(
{
"feature_vector": feature_vector,
"saliency_map": saliency_map,
},
)

return predictions
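`self.explain_fn` and `self.feature_vector_fn` used above come from OTX's shared explainability hooks rather than from this file. As a rough sketch of the kind of tensors they are expected to produce from the reshaped per-level logits and the backbone features (the sigmoid, resize and average-pool choices below are illustrative assumptions, not the actual OTX implementation):

```python
import torch
import torch.nn.functional as F


def sketch_saliency_from_logits(per_level_logits: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Turn per-level class logits (B, num_classes, H_i, W_i) into one (B, num_classes, H, W) map.

    Assumption: class probabilities are resized to the largest level and averaged
    across pyramid levels; the real OTX hook may differ.
    """
    target_hw = per_level_logits[0].shape[-2:]
    maps = [
        F.interpolate(level.sigmoid(), size=target_hw, mode="bilinear", align_corners=False)
        for level in per_level_logits
    ]
    saliency = torch.stack(maps, dim=0).mean(dim=0)
    # Rescale every class map to [0, 255] for visualization.
    flat = saliency.flatten(2)
    mins = flat.min(-1).values[..., None, None]
    maxs = flat.max(-1).values[..., None, None]
    return ((saliency - mins) / (maxs - mins + 1e-6) * 255).to(torch.uint8)


def sketch_feature_vector(backbone_feats: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Global-average-pool each level and concatenate into a single (B, sum(C_i)) vector."""
    return torch.cat([feat.mean(dim=(-2, -1)) for feat in backbone_feats], dim=1)
```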
40 changes: 32 additions & 8 deletions src/otx/algo/detection/detectors/detection_transformer.py
@@ -1,11 +1,10 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Base DETR model implementations."""

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
@@ -96,22 +95,47 @@ def export(
explain_mode: bool = False,
) -> dict[str, Any] | tuple[list[Any], list[Any], list[Any]]:
"""Exports the model."""
backbone_feats = self.encoder(self.backbone(batch_inputs))
predictions = self.decoder(backbone_feats, explain_mode=True)
results = self.postprocess(
self._forward_features(batch_inputs),
predictions,
[meta["img_shape"] for meta in batch_img_metas],
deploy_mode=True,
)

if explain_mode:
# TODO(Eugene): Implement explain mode for DETR model.
warnings.warn("Explain mode is not supported for DETR model. Return dummy values.", stacklevel=2)
raw_logits = self.split_and_reshape_logits(backbone_feats, predictions["raw_logits"])
feature_vector = self.feature_vector_fn(backbone_feats)
saliency_map = self.explain_fn(raw_logits)
xai_output = {
"feature_vector": torch.zeros(1, 1),
"saliency_map": torch.zeros(1),
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}
results.update(xai_output) # type: ignore[union-attr]
return results

@staticmethod
def split_and_reshape_logits(
backbone_feats: tuple[Tensor, ...],
raw_logits: Tensor,
) -> tuple[Tensor, ...]:
"""Splits and reshapes raw logits for explain mode.

Args:
backbone_feats (tuple[Tensor,...]): Tuple of backbone features.
raw_logits (Tensor): Raw logits.

Returns:
tuple[Tensor,...]: The reshaped logits.
"""
splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
# Permute and split logits in one line
raw_logits = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
return tuple(
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
)
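To make the tensor bookkeeping above concrete, a small worked example with dummy shapes (batch 2, 80 classes, a three-level pyramid; all values illustrative) recovers per-level class maps from the flattened logits:

```python
import torch

from otx.algo.detection.detectors.detection_transformer import DETR

batch, num_classes = 2, 80
# Three pyramid levels, e.g. strides 8/16/32 on a 512x512 input.
backbone_feats = (
    torch.randn(batch, 256, 64, 64),
    torch.randn(batch, 256, 32, 32),
    torch.randn(batch, 256, 16, 16),
)
total_positions = sum(f.shape[-2] * f.shape[-1] for f in backbone_feats)  # 4096 + 1024 + 256 = 5376

# The decoder emits raw logits flattened over all spatial positions.
raw_logits = torch.randn(batch, total_positions, num_classes)

per_level = DETR.split_and_reshape_logits(backbone_feats, raw_logits)
print([tuple(t.shape) for t in per_level])
# [(2, 80, 64, 64), (2, 80, 32, 32), (2, 80, 16, 16)]
```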

def postprocess(
self,
outputs: dict[str, Tensor],
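Since `export()` now attaches real `feature_vector` and `saliency_map` tensors instead of dummy zeros, an exported model can be queried for them at inference time. A hedged sketch with the OpenVINO runtime follows; the model path, the 640x640 NCHW input layout, and the output tensor names surviving export are assumptions to verify against the actual IR:

```python
import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model(core.read_model("exported_model.xml"), "CPU")

# Assumption: a single NCHW float input and the "saliency_map" / "feature_vector"
# output names from this PR are preserved in the exported graph.
image = np.random.rand(1, 3, 640, 640).astype(np.float32)
results = compiled(image)

saliency_map = results[compiled.output("saliency_map")]
feature_vector = results[compiled.output("feature_vector")]
print(saliency_map.shape, feature_vector.shape)
```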
31 changes: 27 additions & 4 deletions src/otx/algo/detection/heads/dfine_decoder.py
@@ -723,7 +723,7 @@ def _get_decoder_input(
enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
content = torch.concat([denoising_logits, content], dim=1)

return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits

def _select_topk(
self,
@@ -762,8 +762,22 @@ def _select_topk(

return topk_memory, topk_logits, topk_anchors

def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]:
"""Forward pass of the DFine Transformer module."""
def forward(
self,
feats: Tensor,
targets: list[dict[str, Tensor]] | None = None,
explain_mode: bool = False,
) -> dict[str, Tensor]:
"""Forward function of the D-FINE Decoder Transformer Module.

Args:
feats (Tensor): Feature maps.
targets (list[dict[str, Tensor]] | None, optional): target annotations. Defaults to None.
explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.

Returns:
dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
"""
# input projection and embedding
memory, spatial_shapes = self._get_encoder_input(feats)

@@ -781,7 +795,13 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
else:
denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input(
(
init_ref_contents,
init_ref_points_unact,
enc_topk_bboxes_list,
enc_topk_logits_list,
raw_logits,
) = self._get_decoder_input(
memory,
spatial_shapes,
denoising_logits,
@@ -858,6 +878,9 @@ def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None)
"pred_boxes": out_bboxes[-1],
}

if explain_mode:
out["raw_logits"] = raw_logits

return out

@torch.jit.unused
35 changes: 26 additions & 9 deletions src/otx/algo/detection/heads/rtdetr_decoder.py
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""RTDETR decoder, modified from https://github.com/lyuwenyu/RT-DETR."""
@@ -546,10 +546,10 @@ def _get_decoder_input(

output_memory = self.enc_output(memory)

enc_outputs_class = self.enc_score_head(output_memory)
enc_outputs_logits = self.enc_score_head(output_memory)
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
_, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1)

reference_points_unact = enc_outputs_coord_unact.gather(
dim=1,
@@ -560,9 +560,9 @@
if denoising_bbox_unact is not None:
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)

enc_topk_logits = enc_outputs_class.gather(
enc_topk_logits = enc_outputs_logits.gather(
dim=1,
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]),
index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
)

# extract region features
@@ -575,10 +575,24 @@
if denoising_class is not None:
target = torch.concat([denoising_class, target], 1)

return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits

def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor:
"""Forward pass of the RTDETRTransformer module."""
def forward(
self,
feats: torch.Tensor,
targets: list[dict[str, torch.Tensor]] | None = None,
explain_mode: bool = False,
) -> dict[str, torch.Tensor]:
"""Forward function of RTDETRTransformer.

Args:
feats (Tensor): Input features.
targets (List[Dict[str, Tensor]]): List of target dictionaries.
explain_mode (bool): Whether to return raw logits for explanation.

Returns:
dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
"""
# input projection and embedding
(memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)

@@ -596,7 +610,7 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
else:
denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input(
memory,
spatial_shapes,
denoising_class,
@@ -630,6 +644,9 @@ def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] |
out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
out["dn_meta"] = dn_meta

if explain_mode:
out["raw_logits"] = raw_logits

return out

@torch.jit.unused
58 changes: 57 additions & 1 deletion src/otx/algo/detection/rtdetr.py
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""RTDetr model implementations."""
@@ -128,6 +128,9 @@ def _customize_inputs(
)
targets.append({"boxes": scaled_bboxes, "labels": ll})

if self.explain_mode:
return {"entity": entity}

return {
"images": entity.images,
"targets": targets,
@@ -156,6 +159,33 @@ def _customize_outputs(
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

if self.explain_mode:
if not isinstance(outputs, dict):
msg = f"Model output should be a dict, but got {type(outputs)}."
raise ValueError(msg)

if "feature_vector" not in outputs:
msg = "No feature vector in the model output."
raise ValueError(msg)

if "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)

saliency_map = outputs["saliency_map"].detach().cpu().numpy()
feature_vector = outputs["feature_vector"].detach().cpu().numpy()

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
labels=labels,
feature_vector=feature_vector,
saliency_map=saliency_map,
)

return DetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
@@ -271,3 +301,29 @@ def _exporter(self) -> OTXModelExporter:
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for RT-DETR."""
return {"model_type": "transformer"}

@staticmethod
def _forward_explain_detection(
self, # noqa: ANN001
entity: DetBatchDataEntity,
mode: str = "tensor", # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward function for explainable detection model."""
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

raw_logits = DETR.split_and_reshape_logits(
backbone_feats,
predictions["raw_logits"],
)

saliency_map = self.explain_fn(raw_logits)
feature_vector = self.feature_vector_fn(backbone_feats)
predictions.update(
{
"feature_vector": feature_vector,
"saliency_map": saliency_map,
},
)

return predictions
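When `ExplainConfig(postprocess=False)` is used, the saliency maps carried by the prediction entity stay as raw per-class activations. A hedged helper for overlaying one class map onto the original image; the array layouts and the `preds.saliency_map[0][class_idx]` indexing in the usage comment are assumptions for illustration:

```python
import cv2
import numpy as np


def overlay_saliency(image_bgr: np.ndarray, class_map: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend one raw class saliency map over a BGR image (layouts assumed, see above)."""
    cmap = class_map.astype(np.float32)
    cmap = (cmap - cmap.min()) / (cmap.max() - cmap.min() + 1e-6) * 255.0
    cmap = cv2.resize(cmap.astype(np.uint8), (image_bgr.shape[1], image_bgr.shape[0]))
    heat = cv2.applyColorMap(cmap, cv2.COLORMAP_JET)
    return cv2.addWeighted(image_bgr, 1.0 - alpha, heat, alpha, 0.0)


# Hypothetical usage with a batch prediction entity produced in explain mode:
# overlay = overlay_saliency(original_bgr_image, preds.saliency_map[0][class_idx])
```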