From 62653151650c7dd023487bf053833fda5250439a Mon Sep 17 00:00:00 2001 From: Andrzej Pacuk Date: Tue, 27 Dec 2022 10:11:26 +0000 Subject: [PATCH] fix initialization of Accuracy to work with new version of torchmetrics --- hydra_lightning/tasks/example/modules/base_adv.py | 7 ++++--- .../tasks/example/modules/resnet_robust_train.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/hydra_lightning/tasks/example/modules/base_adv.py b/hydra_lightning/tasks/example/modules/base_adv.py index 6c7f8f1..1aa7d29 100644 --- a/hydra_lightning/tasks/example/modules/base_adv.py +++ b/hydra_lightning/tasks/example/modules/base_adv.py @@ -15,9 +15,10 @@ class BaseAdvModule(BaseModule): def initialize_model(self): self.cross_entropy_criterion = torch.nn.CrossEntropyLoss() - self.train_accuracy = Accuracy() - self.std_val_accuracy = Accuracy() - self.adv_val_accuracy = Accuracy() + cfg: DictConfig = self.hparams.model_config + self.train_accuracy = Accuracy(task='multiclass', num_classes=cfg.num_classes) + self.std_val_accuracy = Accuracy(task='multiclass', num_classes=cfg.num_classes) + self.adv_val_accuracy = Accuracy(task='multiclass', num_classes=cfg.num_classes) # for robby.input_transforms.PGD @staticmethod diff --git a/hydra_lightning/tasks/example/modules/resnet_robust_train.py b/hydra_lightning/tasks/example/modules/resnet_robust_train.py index 3f2a23e..16b2920 100644 --- a/hydra_lightning/tasks/example/modules/resnet_robust_train.py +++ b/hydra_lightning/tasks/example/modules/resnet_robust_train.py @@ -14,9 +14,9 @@ class ResnetRobustTrainModule(BaseAdvModule): def initialize_model(self): super().initialize_model() - self.adv_train_accuracy = Accuracy() - cfg: DictConfig = self.hparams.model_config + self.adv_train_accuracy = Accuracy(task='multiclass', num_classes=cfg.num_classes) + self.backbone, num_features = get_robust_backbone(cfg.arch, cfg.eps) self.head = nn.Linear(num_features, cfg.num_classes)