diff --git a/configs/base/dfine_hgnetv2.yml b/configs/base/dfine_hgnetv2.yml
index e9de6d1b..51fe1b42 100644
--- a/configs/base/dfine_hgnetv2.yml
+++ b/configs/base/dfine_hgnetv2.yml
@@ -22,7 +22,7 @@ no_aug_epoch: 0
 
 HGNetv2:
   pretrained: True
-  local_model_dir: ../RT-DETR-main/D-FINE/weight/hgnetv2/
+  local_model_dir: ./weight/hgnetv2/
 
 HybridEncoder:
   in_channels: [512, 1024, 2048]
diff --git a/configs/deim_dfine/deim_hgnetv2_n_coco.yml b/configs/deim_dfine/deim_hgnetv2_n_coco.yml
index 62db245d..5da8ada6 100644
--- a/configs/deim_dfine/deim_hgnetv2_n_coco.yml
+++ b/configs/deim_dfine/deim_hgnetv2_n_coco.yml
@@ -3,7 +3,7 @@ __include__: [
   '../base/deim.yml'
 ]
 
-output_dir: ./deim_outputs/deim_hgnetv2_n_coco
+output_dir: ./outputs/deim_hgnetv2_n_coco
 
 optimizer:
   type: AdamW
diff --git a/configs/deim_dfine/dfine_hgnetv2_n_coco.yml b/configs/deim_dfine/dfine_hgnetv2_n_coco.yml
index c45e6535..8346fca4 100644
--- a/configs/deim_dfine/dfine_hgnetv2_n_coco.yml
+++ b/configs/deim_dfine/dfine_hgnetv2_n_coco.yml
@@ -68,7 +68,6 @@ optimizer:
 # Increase to search for the optimal ema
 epoches: 160 # 148 + 4n
 train_dataloader:
-  total_batch_size: 128
   dataset:
     transforms:
       policy:
@@ -76,7 +75,4 @@ train_dataloader:
   collate_fn:
     stop_epoch: 148
     ema_restart_decay: 0.9999
-    base_size_repeat: ~
-
-val_dataloader:
-  total_batch_size: 256
+    base_size_repeat: ~
\ No newline at end of file
diff --git a/engine/backbone/hgnetv2.py b/engine/backbone/hgnetv2.py
index a26f38b0..917bf128 100644
--- a/engine/backbone/hgnetv2.py
+++ b/engine/backbone/hgnetv2.py
@@ -12,6 +12,7 @@ from .common import FrozenBatchNorm2d
 from ..core import register
 
 import logging
+from ..misc.dist_utils import get_rank, is_dist_available_and_initialized
 
 # Constants for initialization
 kaiming_normal_ = nn.init.kaiming_normal_
@@ -495,11 +496,12 @@ def __init__(self,
                     print(f"Loaded stage1 {name} HGNetV2 from local file.")
                 else:
                     # If the file doesn't exist locally, download from the URL
-                    if torch.distributed.get_rank() == 0:
+                    if get_rank() == 0:
                         print(GREEN + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection." + RESET)
                         print(GREEN + "Please check your network connection. Or download the model manually from " + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
                         state = torch.hub.load_state_dict_from_url(download_url, map_location='cpu', model_dir=local_model_dir)
-                        torch.distributed.barrier()
+                        if is_dist_available_and_initialized():
+                            torch.distributed.barrier()
                     else:
                         torch.distributed.barrier()
                         state = torch.load(local_model_dir)
@@ -509,7 +511,7 @@ def __init__(self,
                 self.load_state_dict(state)
 
             except (Exception, KeyboardInterrupt) as e:
-                if torch.distributed.get_rank() == 0:
+                if get_rank() == 0:
                     print(f"{str(e)}")
                     logging.error(RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET)
                     logging.error(GREEN + "Please check your network connection. Or download the model manually from " \
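
Note on the engine/backbone/hgnetv2.py change: the patch swaps the bare torch.distributed.get_rank() / barrier() calls for the helpers imported from engine/misc/dist_utils, so loading pretrained HGNetV2 weights no longer crashes in single-process (non-distributed) runs. Below is a minimal sketch of that guarded rank-0 download pattern, assuming get_rank() falls back to 0 and is_dist_available_and_initialized() reports whether a process group exists; the function names in the sketch are illustrative, not the repository's code.

    import torch
    import torch.distributed as tdist


    def dist_ready() -> bool:
        # True only when torch.distributed is built in and init_process_group() has run.
        return tdist.is_available() and tdist.is_initialized()


    def rank() -> int:
        # Plain single-GPU / CPU runs have no process group; treat them as rank 0.
        return tdist.get_rank() if dist_ready() else 0


    def fetch_pretrained(download_url: str, model_dir: str) -> dict:
        # Rank 0 downloads the checkpoint; the other ranks wait at the barrier and
        # then read the file rank 0 cached in model_dir. When no process group
        # exists, the barrier is skipped and the download simply runs once.
        if rank() == 0:
            state = torch.hub.load_state_dict_from_url(
                download_url, map_location='cpu', model_dir=model_dir)
            if dist_ready():
                tdist.barrier()
        else:
            tdist.barrier()
            state = torch.hub.load_state_dict_from_url(
                download_url, map_location='cpu', model_dir=model_dir)
        return state

The barrier in the non-zero-rank branch stays unconditional because a rank other than 0 can only exist once the process group has been initialized.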