diff --git a/models/detr.py b/models/detr.py index 23c2376da..f25a3d7e1 100644 --- a/models/detr.py +++ b/models/detr.py @@ -331,8 +331,7 @@ def build(args): if args.masks: model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) matcher = build_matcher(args) - weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} - weight_dict['loss_giou'] = args.giou_loss_coef + weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef, 'loss_giou': args.giou_loss_coef} if args.masks: weight_dict["loss_mask"] = args.mask_loss_coef weight_dict["loss_dice"] = args.dice_loss_coef