Merge pull request #9 from torch-points3d/metrics
Metrics
Showing 15 changed files with 443 additions and 42 deletions.
@@ -0,0 +1,2 @@
_target_: torch_points3d.metrics.segmentation.segmentation_tracker.SegmentationTracker
num_classes: ${dataset.cfg.num_classes}
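This new conf file wires the tracker into Hydra: `_target_` names the class to build, and `${dataset.cfg.num_classes}` is resolved by OmegaConf interpolation against the dataset config at compose time. A sketch of how such a node is typically instantiated — the `tracker` key and the inline config below are assumptions for illustration, not shown in this diff:

    # Sketch only: the "tracker" config key is an assumption, not part of this PR.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "dataset": {"cfg": {"num_classes": 2}},
            "tracker": {
                "_target_": "torch_points3d.metrics.segmentation.segmentation_tracker.SegmentationTracker",
                "num_classes": "${dataset.cfg.num_classes}",
            },
        }
    )
    tracker = instantiate(cfg.tracker)  # builds SegmentationTracker(num_classes=2)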
Empty file.
@@ -0,0 +1,59 @@
import os
import sys

import numpy as np
import pytest
import torch

DIR = os.path.dirname(os.path.realpath(__file__))
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.append(".")

from torch_points3d.metrics.segmentation.metrics import compute_intersection_union_per_class
from torch_points3d.metrics.segmentation.metrics import compute_average_intersection_union
from torch_points3d.metrics.segmentation.metrics import compute_overall_accuracy
from torch_points3d.metrics.segmentation.metrics import compute_mean_class_accuracy


def test_compute_intersection_union_per_class():
    matrix = torch.tensor([[4, 1], [2, 10]])
    iou, _ = compute_intersection_union_per_class(matrix)
    miou = compute_average_intersection_union(matrix)
    np.testing.assert_allclose(iou[0].item(), 4 / (4.0 + 1.0 + 2.0))
    np.testing.assert_allclose(iou[1].item(), 10 / (10.0 + 1.0 + 2.0))
    np.testing.assert_allclose(iou.mean().item(), miou.item())


def test_compute_overall_accuracy():
    list_matrix = [
        torch.tensor([[4, 1], [2, 10]]).float(),
        torch.tensor([[4, 1], [2, 10]]).int(),
        torch.tensor([[0, 0], [0, 0]]).float(),
    ]
    list_answer = [
        (4.0 + 10.0) / (4.0 + 10.0 + 1.0 + 2.0),
        (4.0 + 10.0) / (4.0 + 10.0 + 1.0 + 2.0),
        0.0,
    ]
    for i in range(len(list_matrix)):
        acc = compute_overall_accuracy(list_matrix[i])
        if isinstance(acc, torch.Tensor):
            np.testing.assert_allclose(acc.item(), list_answer[i])
        else:
            np.testing.assert_allclose(acc, list_answer[i])


def test_compute_mean_class_accuracy():
    matrix = torch.tensor([[4, 1], [2, 10]]).float()
    macc = compute_mean_class_accuracy(matrix)
    np.testing.assert_allclose(macc.item(), (4 / 5 + 10 / 12) * 0.5)


@pytest.mark.parametrize(
    "missing_as_one, answer",
    [pytest.param(False, (0.5 + 0.5) / 2), pytest.param(True, (0.5 + 1 + 0.5) / 3)],
)
def test_getMeanIoUMissing(missing_as_one, answer):
    matrix = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0]])
    np.testing.assert_allclose(
        compute_average_intersection_union(matrix, missing_as_one=missing_as_one).item(), answer
    )
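For reference, every quantity these tests pin down can be read directly off a confusion matrix C with rows as ground truth and columns as predictions. A minimal NumPy sketch of the same formulas (an illustration of the math, not the library implementation):

    # Illustrates the formulas asserted above; not the library code.
    import numpy as np

    C = np.array([[4, 1], [2, 10]], dtype=float)  # rows: ground truth, cols: prediction
    tp = np.diag(C)
    iou = tp / (C.sum(axis=0) + C.sum(axis=1) - tp)  # per-class IoU: TP / (TP + FP + FN)
    oacc = tp.sum() / C.sum()                        # overall accuracy: trace / total
    macc = (tp / C.sum(axis=1)).mean()               # mean class accuracy: mean per-class recall

    print(iou)   # [4/7, 10/13]
    print(oacc)  # 14/17
    print(macc)  # (4/5 + 10/12) / 2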
@@ -0,0 +1,120 @@
import os
import sys

import numpy as np
import pytest
import torch
from torch_geometric.data import Data

DIR = os.path.dirname(os.path.realpath(__file__))
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.append(".")

from torch_points3d.metrics.segmentation.segmentation_tracker import SegmentationTracker


class MockDataset:
    INV_OBJECT_LABEL = {0: "first", 1: "wall", 2: "not", 3: "here", 4: "hoy"}
    pos = torch.tensor([[1, 0, 0], [2, 0, 0], [3, 0, 0], [-1, 0, 0]]).float()
    test_label = torch.tensor([1, 1, 0, 0])

    def __init__(self):
        self.num_classes = 2

    @property
    def test_data(self):
        return Data(pos=self.pos, y=self.test_label)

    def has_labels(self, stage):
        return True


class MockModel:
    def __init__(self):
        self.iter = 0
        self.losses = [
            {"loss_1": 1, "loss_2": 2},
            {"loss_1": 2, "loss_2": 2},
            {"loss_1": 1, "loss_2": 2},
            {"loss_1": 1, "loss_2": 2},
        ]
        self.outputs = [
            torch.tensor([[0, 1], [0, 1]]),
            torch.tensor([[1, 0], [1, 0]]),
            torch.tensor([[1, 0], [1, 0]]),
            torch.tensor([[1, 0], [1, 0], [1, 0]]),
        ]
        self.labels = [torch.tensor([1, 1]), torch.tensor([1, 1]), torch.tensor([1, 1]), torch.tensor([0, 0, -100])]
        self.batch_idx = [torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([0, 0, 1])]

    def get_input(self):
        return Data(pos=MockDataset.pos[:2, :], origin_id=torch.tensor([0, 1]))

    def get_output(self):
        return self.outputs[self.iter].float()

    def get_labels(self):
        return self.labels[self.iter]

    def get_current_losses(self):
        return self.losses[self.iter]

    def get_batch(self):
        return self.batch_idx[self.iter]

    @property
    def device(self):
        return "cpu"


def test_forward():
    tracker = SegmentationTracker(num_classes=2, stage="train")
    model = MockModel()
    output = {"preds": model.get_output(), "labels": model.get_labels()}
    losses = model.get_current_losses()
    metrics = tracker(output, losses)

    for k in ["train_acc", "train_miou", "train_macc"]:
        np.testing.assert_allclose(metrics[k], 100, rtol=1e-5)
    model.iter += 1
    output = {"preds": model.get_output(), "labels": model.get_labels()}
    losses = model.get_current_losses()
    metrics = tracker(output, losses)
    metrics = tracker.finalise()
    for k in ["train_acc", "train_macc"]:
        assert metrics[k] == 50
    np.testing.assert_allclose(metrics["train_miou"], 25, atol=1e-5)
    assert metrics["train_loss_1"] == 1.5

    tracker.reset("test")
    model.iter += 1
    output = {"preds": model.get_output(), "labels": model.get_labels()}
    losses = model.get_current_losses()
    metrics = tracker(output, losses)
    for name in ["test_acc", "test_miou", "test_macc"]:
        np.testing.assert_allclose(metrics[name].item(), 0, atol=1e-5)


@pytest.mark.parametrize("finalise", [pytest.param(True), pytest.param(False)])
def test_ignore_label(finalise):
    tracker = SegmentationTracker(num_classes=2, ignore_label=-100)
    tracker.reset("test")
    model = MockModel()
    model.iter = 3
    output = {"preds": model.get_output(), "labels": model.get_labels()}
    losses = model.get_current_losses()
    metrics = tracker(output, losses)
    if not finalise:
        for k in ["test_acc", "test_miou", "test_macc"]:
            np.testing.assert_allclose(metrics[k], 100)
    else:
        tracker.finalise()
        with pytest.raises(RuntimeError):
            tracker(output)
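Taken together, these tests describe the tracker's calling convention: pass a dict of `preds`/`labels` plus an optional loss dict per step, call `finalise()` at the end of a stage, and `reset(stage)` to start a new one. A minimal end-to-end sketch, reusing the `MockModel` above so it stays self-contained:

    # Minimal use of the tracker in a loop, mirroring test_forward above.
    model = MockModel()
    tracker = SegmentationTracker(num_classes=2, stage="train")
    for step in range(2):
        model.iter = step
        batch_out = {"preds": model.get_output(), "labels": model.get_labels()}
        live = tracker(batch_out, model.get_current_losses())  # running metrics per step
    epoch_summary = tracker.finalise()  # aggregated metrics, e.g. epoch_summary["train_miou"]
    tracker.reset("test")               # begin tracking a fresh stage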
Empty file.
@@ -0,0 +1,59 @@
from typing import Any, Dict, Optional

import torch
from torch import nn
from torchmetrics import AverageMeter


class BaseTracker(nn.Module):
    """PyTorch module that manages the losses and metrics for a given stage."""

    def __init__(self, stage: str = "train"):
        super().__init__()
        self.stage: str = stage
        self._finalised: bool = False
        self.loss_metrics: nn.ModuleDict = nn.ModuleDict()

    def track(self, output_model, *args, **kwargs) -> Dict[str, Any]:
        raise NotImplementedError

    def track_loss(self, losses: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        out_loss = dict()
        for key, loss in losses.items():
            loss_key = f"{self.stage}_{key}"
            if loss_key not in self.loss_metrics.keys():
                self.loss_metrics[loss_key] = AverageMeter().to(loss)
            val = self.loss_metrics[loss_key](loss)
            out_loss[loss_key] = val
        return out_loss

    def forward(
        self, output_model: Dict[str, Any], losses: Optional[Dict[str, torch.Tensor]] = None, *args, **kwargs
    ) -> Dict[str, Any]:
        if self._finalised:
            raise RuntimeError("Cannot track new values with a finalised tracker, you need to reset it first")
        tracked_metric = self.track(output_model, *args, **kwargs)
        if losses is not None:
            tracked_loss = self.track_loss(losses)
            tracked_results = dict(**tracked_loss, **tracked_metric)
        else:
            tracked_results = tracked_metric
        return tracked_results

    def _finalise(self) -> Dict[str, Any]:
        raise NotImplementedError("Method that aggregates the metrics at the end of a stage")

    def finalise(self) -> Dict[str, Any]:
        metrics = self._finalise()
        self._finalised = True
        loss_metrics = self.get_final_loss_metrics()
        final_metrics = {**loss_metrics, **metrics}
        return final_metrics

    def get_final_loss_metrics(self):
        metrics = dict()
        for key, m in self.loss_metrics.items():
            metrics[key] = m.compute()
        self.loss_metrics = nn.ModuleDict()
        return metrics
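The contract BaseTracker leaves to subclasses is small: implement `track()` to update metric state per step and return live values, and `_finalise()` to return stage-level aggregates. A minimal hypothetical subclass, assuming a torchmetrics version contemporary with the `AverageMeter` import above (the class and its use of `Accuracy` are illustrative, not the SegmentationTracker shipped in this PR):

    # Hypothetical subclass for illustration only.
    from typing import Any, Dict

    import torchmetrics


    class AccuracyTracker(BaseTracker):
        def __init__(self, num_classes: int, stage: str = "train"):
            super().__init__(stage=stage)
            self.accuracy = torchmetrics.Accuracy(num_classes=num_classes)

        def track(self, output_model: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
            preds, labels = output_model["preds"], output_model["labels"]
            acc = self.accuracy(preds, labels)  # updates running state, returns batch value
            return {f"{self.stage}_acc": acc}

        def _finalise(self) -> Dict[str, Any]:
            return {f"{self.stage}_acc": self.accuracy.compute()}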
Empty file.