Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/base/dfine_hgnetv2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ no_aug_epoch: 0

HGNetv2:
pretrained: True
local_model_dir: ../RT-DETR-main/D-FINE/weight/hgnetv2/
local_model_dir: ./weight/hgnetv2/

HybridEncoder:
in_channels: [512, 1024, 2048]
Expand Down
2 changes: 1 addition & 1 deletion configs/deim_dfine/deim_hgnetv2_n_coco.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ __include__: [
'../base/deim.yml'
]

output_dir: ./deim_outputs/deim_hgnetv2_n_coco
output_dir: ./outputs/deim_hgnetv2_n_coco

optimizer:
type: AdamW
Expand Down
6 changes: 1 addition & 5 deletions configs/deim_dfine/dfine_hgnetv2_n_coco.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,11 @@ optimizer:
# Increase to search for the optimal ema
epoches: 160 # 148 + 4n
train_dataloader:
total_batch_size: 128
dataset:
transforms:
policy:
epoch: 148
collate_fn:
stop_epoch: 148
ema_restart_decay: 0.9999
base_size_repeat: ~

val_dataloader:
total_batch_size: 256
base_size_repeat: ~
8 changes: 5 additions & 3 deletions engine/backbone/hgnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .common import FrozenBatchNorm2d
from ..core import register
import logging
from ..misc.dist_utils import get_rank, is_dist_available_and_initialized

# Constants for initialization
kaiming_normal_ = nn.init.kaiming_normal_
Expand Down Expand Up @@ -495,11 +496,12 @@ def __init__(self,
print(f"Loaded stage1 {name} HGNetV2 from local file.")
else:
# If the file doesn't exist locally, download from the URL
if torch.distributed.get_rank() == 0:
if get_rank() == 0:
print(GREEN + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection." + RESET)
print(GREEN + "Please check your network connection. Or download the model manually from " + RESET + f"{download_url}" + GREEN + " to " + RESET + f"{local_model_dir}." + RESET)
state = torch.hub.load_state_dict_from_url(download_url, map_location='cpu', model_dir=local_model_dir)
torch.distributed.barrier()
if is_dist_available_and_initialized():
torch.distributed.barrier()
else:
torch.distributed.barrier()
state = torch.load(local_model_dir)
Expand All @@ -509,7 +511,7 @@ def __init__(self,
self.load_state_dict(state)

except (Exception, KeyboardInterrupt) as e:
if torch.distributed.get_rank() == 0:
if get_rank() == 0:
print(f"{str(e)}")
logging.error(RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET)
logging.error(GREEN + "Please check your network connection. Or download the model manually from " \
Expand Down