diff --git a/core/model.py b/core/model.py
index 6905b09..b1e551d 100755
--- a/core/model.py
+++ b/core/model.py
@@ -815,7 +815,22 @@ def erase_feature_maps(self, atten_map_normed, feature_maps, threshold, flag=False):
         return erased_feature_maps
 
     def fix_params(self, is_training=True):
-        pass
+        for p in self.backbone.parameters():
+            p.requires_grad = is_training
+        for p in self.erase.parameters():
+            p.requires_grad = is_training
+        for p in self.fc7.parameters():
+            p.requires_grad = is_training
+        for p in self.cls.parameters():
+            p.requires_grad = is_training
+        for p in self.cls_direction.parameters():
+            p.requires_grad = is_training
+        for p in self.erase_fc7.parameters():
+            p.requires_grad = is_training
+        for p in self.erase_cls.parameters():
+            p.requires_grad = is_training
+        for p in self.erase_cls_direction.parameters():
+            p.requires_grad = is_training
 
 
     def get_loss(self, logits, labels, direction):
diff --git a/finetune_tiger_cnn8.py b/finetune_tiger_cnn8.py
index 5cd447a..2f7a063 100755
--- a/finetune_tiger_cnn8.py
+++ b/finetune_tiger_cnn8.py
@@ -54,7 +54,13 @@ def main():
     feature_size = 1024
 
     net = tiger_cnn8(classes=107)
-    ignore_params = list(map(id, net.fuse_fc7.parameters()))
+    ignore_params = list(map(id, net.fc7.parameters()))
+    ignore_params += list(map(id, net.cls.parameters()))
+    ignore_params += list(map(id, net.cls_direction.parameters()))
+    ignore_params += list(map(id, net.erase_fc7.parameters()))
+    ignore_params += list(map(id, net.erase_cls.parameters()))
+    ignore_params += list(map(id, net.erase_cls_direction.parameters()))
+    ignore_params += list(map(id, net.fuse_fc7.parameters()))
     ignore_params += list(map(id, net.fuse_cls.parameters()))
     ignore_params += list(map(id, net.fuse_cls_direction.parameters()))
     base_params = filter(lambda p: id(p) not in ignore_params, net.parameters())
@@ -67,7 +73,7 @@ def main():
     exp_lr_scheduler = StepLRScheduler(optimizer=optimizer, decay_t=20, decay_rate=0.1, warmup_lr_init=1e-5, warmup_t=3)
 
     net.load_state_dict(torch.load('./model/tiger_cnn3/model.ckpt')['net_state_dict'])
-    net.fix_params(is_training=False)
+    # net.fix_params(is_training=False)
     net = net.cuda()
     if multi_gpus:
         net = nn.DataParallel(net)
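
Note on the pattern these hunks implement: fix_params toggles requires_grad on every sub-module so the whole network can be frozen or unfrozen in one call, while the ignore_params bookkeeping separates the freshly attached heads from the pretrained trunk so each optimizer group can get its own learning rate. A minimal self-contained sketch of that recipe, with illustrative module names ('backbone', 'head') rather than the ones from this repo:

    import torch
    import torch.nn as nn

    # Stand-in network: 'backbone' plays the pretrained trunk,
    # 'head' plays the newly initialised classifier layers.
    net = nn.ModuleDict({
        'backbone': nn.Linear(8, 8),
        'head': nn.Linear(8, 2),
    })

    # Equivalent of fix_params(is_training=False): stop gradients
    # to the pretrained trunk.
    for p in net['backbone'].parameters():
        p.requires_grad = False

    # Equivalent of the ignore_params filtering: collect the heads'
    # parameter ids, then put everything else in a lower-LR group.
    head_ids = set(map(id, net['head'].parameters()))
    base_params = [p for p in net.parameters() if id(p) not in head_ids]
    optimizer = torch.optim.SGD([
        {'params': base_params, 'lr': 1e-3},
        {'params': net['head'].parameters(), 'lr': 1e-2},
    ], momentum=0.9)

Since the last hunk comments out net.fix_params(is_training=False), the trunk ends up not frozen at all in this patch; it simply trains at the lower base learning rate while the heads train at the higher one.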