Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
liuning-scu-cn committed Aug 12, 2019
1 parent 449d90c commit aaff087
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
17 changes: 16 additions & 1 deletion core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,22 @@ def erase_feature_maps(self, atten_map_normed, feature_maps, threshold, flag=Fal
return erased_feature_maps

def fix_params(self, is_training=True):
pass
for p in self.backbone.parameters():
p.requires_grad = is_training
for p in self.erase.parameters():
p.requires_grad = is_training
for p in self.fc7.parameters():
p.requires_grad = is_training
for p in self.cls.parameters():
p.requires_grad = is_training
for p in self.cls_direction.parameters():
p.requires_grad = is_training
for p in self.erase_fc7.parameters():
p.requires_grad = is_training
for p in self.erase_cls.parameters():
p.requires_grad = is_training
for p in self.erase_cls_direction.parameters():
p.requires_grad = is_training

def get_loss(self, logits, labels, direction):

Expand Down
10 changes: 8 additions & 2 deletions finetune_tiger_cnn8.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def main():
feature_size = 1024

net = tiger_cnn8(classes=107)
ignore_params = list(map(id, net.fuse_fc7.parameters()))
ignore_params = list(map(id, net.fc7.parameters()))
ignore_params += list(map(id, net.cls.parameters()))
ignore_params += list(map(id, net.cls_direction.parameters()))
ignore_params += list(map(id, net.erase_fc7.parameters()))
ignore_params += list(map(id, net.erase_cls.parameters()))
ignore_params += list(map(id, net.erase_cls_direction.parameters()))
ignore_params += list(map(id, net.fuse_fc7.parameters()))
ignore_params += list(map(id, net.fuse_cls.parameters()))
ignore_params += list(map(id, net.fuse_cls_direction.parameters()))
base_params = filter(lambda p: id(p) not in ignore_params, net.parameters())
Expand All @@ -67,7 +73,7 @@ def main():
exp_lr_scheduler = StepLRScheduler(optimizer=optimizer, decay_t=20, decay_rate=0.1, warmup_lr_init=1e-5, warmup_t=3)

net.load_state_dict(torch.load('./model/tiger_cnn3/model.ckpt')['net_state_dict'])
net.fix_params(is_training=False)
# net.fix_params(is_training=False)
net = net.cuda()
if multi_gpus:
net = nn.DataParallel(net)
Expand Down

0 comments on commit aaff087

Please sign in to comment.