diff --git a/configs/yolov12/hyp.scratch.l.yaml b/configs/yolov12/hyp.scratch.l.yaml new file mode 100644 index 00000000..fe99c71b --- /dev/null +++ b/configs/yolov12/hyp.scratch.l.yaml @@ -0,0 +1,67 @@ +optimizer: + optimizer: sgd + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv12Loss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + +data: + num_parallel_workers: 8 + + # multi-stage data augment + train_transforms: { + stage_epochs: [ 590, 10 ], + trans_list: [ + [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.5, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: mixup, alpha: 32.0, beta: 32.0, prob: 0.15, pre_transform: [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.5, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, ] + }, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ], + [ + {func_name: letterbox, scaleup: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ]] + } + + test_transforms: [ + {func_name: letterbox, scaleup: False, only_image: True}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ] diff --git a/configs/yolov12/hyp.scratch.m.yaml b/configs/yolov12/hyp.scratch.m.yaml new file mode 100644 index 00000000..b10fdaa6 --- /dev/null +++ b/configs/yolov12/hyp.scratch.m.yaml @@ -0,0 +1,67 @@ +optimizer: + optimizer: sgd + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv12Loss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + +data: + num_parallel_workers: 8 + + # multi-stage data augment + train_transforms: { + stage_epochs: [ 590, 10 ], + trans_list: [ + [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.4, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: mixup, alpha: 32.0, beta: 32.0, prob: 0.15, pre_transform: [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.4, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, ] + }, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ], + [ + {func_name: letterbox, scaleup: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ]] + } + + test_transforms: [ + {func_name: letterbox, scaleup: False, only_image: True}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ] diff --git a/configs/yolov12/hyp.scratch.n.yaml b/configs/yolov12/hyp.scratch.n.yaml new file mode 100644 index 00000000..f3fedae1 --- /dev/null +++ b/configs/yolov12/hyp.scratch.n.yaml @@ -0,0 +1,61 @@ +optimizer: + optimizer: sgd + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv12Loss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + +data: + num_parallel_workers: 8 + + # multi-stage data augment + train_transforms: { + stage_epochs: [ 590, 10 ], + trans_list: [ + [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.1, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.5, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ], + [ + {func_name: letterbox, scaleup: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.5, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ]] + } + + test_transforms: [ + {func_name: letterbox, scaleup: False, only_image: True}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ] diff --git a/configs/yolov12/hyp.scratch.s.yaml b/configs/yolov12/hyp.scratch.s.yaml new file mode 100644 index 00000000..9c7e17d2 --- /dev/null +++ b/configs/yolov12/hyp.scratch.s.yaml @@ -0,0 +1,67 @@ +optimizer: + optimizer: sgd + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv12Loss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + +data: + num_parallel_workers: 8 + + # multi-stage data augment + train_transforms: { + stage_epochs: [ 590, 10 ], + trans_list: [ + [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.15, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: mixup, alpha: 32.0, beta: 32.0, prob: 0.05, pre_transform: [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.15, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, ] + }, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ], + [ + {func_name: letterbox, scaleup: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ]] + } + + test_transforms: [ + {func_name: letterbox, scaleup: False, only_image: True}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ] diff --git a/configs/yolov12/hyp.scratch.x.yaml b/configs/yolov12/hyp.scratch.x.yaml new file mode 100644 index 00000000..3d04e57f --- /dev/null +++ b/configs/yolov12/hyp.scratch.x.yaml @@ -0,0 +1,67 @@ +optimizer: + optimizer: sgd + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0005 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv12Loss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + +data: + num_parallel_workers: 8 + + # multi-stage data augment + train_transforms: { + stage_epochs: [ 590, 10 ], + trans_list: [ + [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.6, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: mixup, alpha: 32.0, beta: 32.0, prob: 0.2, pre_transform: [ + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.6, sorted: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, ] + }, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ], + [ + {func_name: letterbox, scaleup: True}, + {func_name: resample_segments}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: albumentations}, + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4}, + {func_name: fliplr, prob: 0.5}, + {func_name: label_norm, xyxy2xywh_: True}, + {func_name: label_pad, padding_size: 160, padding_value: -1}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ]] + } + + test_transforms: [ + {func_name: letterbox, scaleup: False, only_image: True}, + {func_name: image_norm, scale: 255.}, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True} + ] diff --git a/configs/yolov12/yolov12-base.yaml b/configs/yolov12/yolov12-base.yaml new file mode 100644 index 00000000..cd406f34 --- /dev/null +++ b/configs/yolov12/yolov12-base.yaml @@ -0,0 +1,50 @@ +epochs: 600 # total train epochs +per_batch_size: 32 # 32 * 8 = 256 +img_size: 640 +conf_free: True +iou_thres: 0.7 +ms_loss_scaler: dynamic +ms_loss_scaler_value: 65536.0 +clip_grad: True +anchor_base: False +opencv_threads_num: 0 # opencv: disable threading optimizations + +network: + model_name: yolov12 + nc: 80 # number of classes + reg_max: 16 + + stride: [8, 16, 32] + + # YOLOv12n backbone + backbone: + # [from, repeats, module, args] + - [-1, 1, ConvNormAct, [64, 3, 2]] # 0-P1/2 + - [-1, 1, ConvNormAct, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, ConvNormAct, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, ConvNormAct, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, ConvNormAct, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] + + # YOLOv12n head + head: + - [-1, 1, Upsample, [None, 2, 'nearest']] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, A2C2f, [512, False, -1]] # 11 + + - [-1, 1, Upsample, [None, 2, 'nearest']] + - [[-1, 4], 1, Concat, [1] ] # cat backbone P3 + - [-1, 2, A2C2f, [256, False, -1]] # 14 + + - [-1, 1, ConvNormAct, [256, 3, 2]] + - [[ -1, 11], 1, Concat, [1]] # cat head P4 + - [-1, 2, A2C2f, [512, False, -1]] # 17 + + - [-1, 1, ConvNormAct, [512, 3, 2]] + - [[-1, 8], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large) + + - [[14, 17, 20], 1, YOLOv12Head, [nc, reg_max, stride]] # Detect(P3, P4, P5) diff --git a/configs/yolov12/yolov12l.yaml b/configs/yolov12/yolov12l.yaml new file mode 100644 index 00000000..6f1ec377 --- /dev/null +++ b/configs/yolov12/yolov12l.yaml @@ -0,0 +1,12 @@ +__BASE__: [ + '../coco.yaml', + './hyp.scratch.l.yaml', + './yolov12-base.yaml', +] + +overflow_still_update: False +network: + scale: l + depth_multiple: 1.00 # model depth multiple + width_multiple: 1.00 # layer channel multiple + max_channels: 512 diff --git a/configs/yolov12/yolov12m.yaml b/configs/yolov12/yolov12m.yaml new file mode 100644 index 00000000..cf787ed7 --- /dev/null +++ b/configs/yolov12/yolov12m.yaml @@ -0,0 +1,12 @@ +__BASE__: [ + '../coco.yaml', + './hyp.scratch.m.yaml', + './yolov12-base.yaml', +] + +overflow_still_update: False +network: + scale: m + depth_multiple: 0.50 # model depth multiple + width_multiple: 1.00 # layer channel multiple + max_channels: 512 diff --git a/configs/yolov12/yolov12n.yaml b/configs/yolov12/yolov12n.yaml new file mode 100644 index 00000000..b3c16375 --- /dev/null +++ b/configs/yolov12/yolov12n.yaml @@ -0,0 +1,12 @@ +__BASE__: [ + '../coco.yaml', + './hyp.scratch.n.yaml', + './yolov12-base.yaml', +] + +overflow_still_update: False +network: + scale: n + depth_multiple: 0.50 # model depth multiple + width_multiple: 0.25 # layer channel multiple + max_channels: 1024 \ No newline at end of file diff --git a/configs/yolov12/yolov12s.yaml b/configs/yolov12/yolov12s.yaml new file mode 100644 index 00000000..c6633ea2 --- /dev/null +++ b/configs/yolov12/yolov12s.yaml @@ -0,0 +1,12 @@ +__BASE__: [ + '../coco.yaml', + './hyp.scratch.s.yaml', + './yolov12-base.yaml', +] + +overflow_still_update: False +network: + scale: s + depth_multiple: 0.50 # model depth multiple + width_multiple: 0.50 # layer channel multiple + max_channels: 1024 diff --git a/configs/yolov12/yolov12x.yaml b/configs/yolov12/yolov12x.yaml new file mode 100644 index 00000000..849885e2 --- /dev/null +++ b/configs/yolov12/yolov12x.yaml @@ -0,0 +1,12 @@ +__BASE__: [ + '../coco.yaml', + './hyp.scratch.x.yaml', + './yolov12-base.yaml', +] + +overflow_still_update: False +network: + scale: x + depth_multiple: 1.00 # model depth multiple + width_multiple: 1.50 # layer channel multiple + max_channels: 512 diff --git a/mindyolo/data/dataset.py b/mindyolo/data/dataset.py index ef9feba3..bf1bc1a8 100644 --- a/mindyolo/data/dataset.py +++ b/mindyolo/data/dataset.py @@ -308,7 +308,8 @@ def __getitem__(self, index): _trans = ori_trans.copy() func_name, prob = _trans.pop("func_name"), _trans.pop("prob", 1.0) if func_name == 'copy_paste': - sample = self.copy_paste(sample, prob) + sorted = _trans.pop("sorted", False) + sample = self.copy_paste(sample, prob, sorted) elif random.random() < prob: if func_name == "albumentations" and getattr(self, "albumentations", None) is None: self.albumentations = Albumentations(size=self.img_size, **_trans) @@ -321,6 +322,8 @@ def __getitem__(self, index): sample['img'] = np.ascontiguousarray(sample['img']) if self.is_training: train_sample = [] + if len(sample['segments']) > 0 and not self.return_segments: + sample['segments'] = np.nan for col_name in self.column_names_getitem: if sample.get(col_name) is None: train_sample.append(np.nan) @@ -574,7 +577,7 @@ def resample_segments(self, sample, n=1000): sample['segments'] = segments return sample - def copy_paste(self, sample, probability=0.5): + def copy_paste(self, sample, probability=0.5, sorted=False): # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) bbox_format, segment_format = sample['bbox_format'], sample['segment_format'] assert bbox_format == 'ltrb', f'The bbox format should be ltrb, but got {bbox_format}' @@ -586,12 +589,16 @@ def copy_paste(self, sample, probability=0.5): segments = sample['segments'] n = len(segments) - if probability and n: - h, w, _ = img.shape # height, width, channels - im_new = np.zeros(img.shape, np.uint8) + if len(segments) == 0 or probability == 0: + return sample + + h, w, _ = img.shape # height, width, channels + im_new = np.zeros(img.shape, np.uint8) + + if not sorted: for j in random.sample(range(n), k=round(probability * n)): c, l, s = cls[j], bboxes[j], segments[j] - box = w - l[2], l[1], w - l[0], l[3] + box = np.array([[w - l[2], l[1], w - l[0], l[3]]], dtype=np.float32) ioa = bbox_ioa(box, bboxes) # intersection over area if (ioa < 0.30).all(): # allow 30% obscuration of existing labels cls = np.concatenate((cls, [c]), 0) @@ -601,11 +608,31 @@ def copy_paste(self, sample, probability=0.5): else: segments = np.concatenate((segments, [np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)]), 0) cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + else: + bboxes2 = bboxes.copy() + bboxes2[:, 0] = w - bboxes[:, 2] + bboxes2[:, 2] = w - bboxes[:, 0] + + ioa = bbox_ioa(bboxes2, bboxes) # intersection over area, (N, M) + indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, ) allow 30% obscuration of existing labels + + n = len(indexes) + sorted_idx = np.argsort(ioa.max(1)[indexes]) + indexes = indexes[sorted_idx] + for j in indexes[: round(probability * n)]: + c, s = cls[j], segments[j] + cls = np.concatenate((cls, [c]), 0) + bboxes = np.concatenate((bboxes, [bboxes2[j]]), 0) + if isinstance(segments, list): + segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) + else: + segments = np.concatenate((segments, [np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)]), 0) + cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) - result = cv2.bitwise_and(src1=img, src2=im_new) - result = cv2.flip(result, 1) # augment segments (flip left-right) - i = result > 0 # pixels to replace - img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug + result = cv2.bitwise_and(src1=img, src2=im_new) + result = cv2.flip(result, 1) # augment segments (flip left-right) + i = result > 0 # pixels to replace + img[i] = result[i] sample['img'] = img sample['cls'] = cls @@ -725,7 +752,8 @@ def mixup(self, sample, alpha: 32.0, beta: 32.0, pre_transform=None): _trans = ori_trans.copy() func_name, prob = _trans.pop("func_name"), _trans.pop("prob", 1.0) if func_name == 'copy_paste': - sample2 = self.copy_paste(sample2, prob) + sorted = _trans.pop("sorted", False) + sample2 = self.copy_paste(sample2, prob, sorted) elif random.random() < prob: if func_name == "albumentations" and getattr(self, "albumentations", None) is None: self.albumentations = Albumentations(size=self.img_size, **_trans) @@ -753,8 +781,6 @@ def pastein(self, sample, num_sample=30): assert bbox_format == 'ltrb', f'The bbox format should be ltrb, but got {bbox_format}' assert not self.return_segments, "pastein currently does not support seg data." assert not self.return_keypoints, "pastein currently does not support keypoint data." - sample.pop('segments', None) - sample.pop('keypoints', None) image = sample['img'] cls = sample['cls'] @@ -786,7 +812,7 @@ def pastein(self, sample, num_sample=30): xmax = min(w, xmin + mask_w) ymax = min(h, ymin + mask_h) - box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32) + box = np.array([[xmin, ymin, xmax, ymax]], dtype=np.float32) if len(bboxes): ioa = bbox_ioa(box, bboxes) # intersection over area else: diff --git a/mindyolo/data/utils.py b/mindyolo/data/utils.py index 70e006c2..f1f03ed4 100644 --- a/mindyolo/data/utils.py +++ b/mindyolo/data/utils.py @@ -57,24 +57,36 @@ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1): return mask -def bbox_ioa(box1, box2): - # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2 - box2 = box2.transpose() +def bbox_ioa(box1, box2, iou=False, eps=1e-7): + """ + Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format. + + Args: + box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes. + box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes. + iou (bool): Calculate the standard IoU if True else return inter_area/box2_area. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + Returns: + (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area. + """ # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] - b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + b1_x1, b1_y1, b1_x2, b1_y2 = box1.T + b2_x1, b2_y1, b2_x2, b2_y2 = box2.T # Intersection area - inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * ( - np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1) + inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * ( + np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1) ).clip(0) - # box2 area - box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16 + # Box2 area + area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + if iou: + box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + area = area + box1_area[:, None] - inter_area # Intersection over box2 area - return inter_area / box2_area + return inter_area / (area + eps) def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): diff --git a/mindyolo/models/__init__.py b/mindyolo/models/__init__.py index 566ca765..8b888ccb 100644 --- a/mindyolo/models/__init__.py +++ b/mindyolo/models/__init__.py @@ -1,10 +1,11 @@ from . import (heads, initializer, layers, losses, model_factory, yolov3, - yolov4, yolov5, yolov7, yolov8, yolov9, yolov10) + yolov4, yolov5, yolov7, yolov8, yolov9, yolov10, yolov12) __all__ = [] __all__.extend(heads.__all__) __all__.extend(layers.__all__) __all__.extend(losses.__all__) +__all__.extend(yolov12.__all__) __all__.extend(yolov9.__all__) __all__.extend(yolov10.__all__) __all__.extend(yolov8.__all__) @@ -29,4 +30,5 @@ from .yolov8 import * from .yolov9 import * from .yolov10 import * +from .yolov12 import * from .yolox import * diff --git a/mindyolo/models/heads/__init__.py b/mindyolo/models/heads/__init__.py index 6a910599..5dd9ff21 100644 --- a/mindyolo/models/heads/__init__.py +++ b/mindyolo/models/heads/__init__.py @@ -7,6 +7,7 @@ from .yolov9_head import * from .yolox_head import * from .yolov10_head import * +from .yolov12_head import * __all__ = [ "YOLOv3Head", @@ -16,5 +17,6 @@ "YOLOv8Head", "YOLOv8SegHead", "YOLOXHead", "YOLOv9Head", - "YOLOv10Head" + "YOLOv10Head", + "YOLOv12Head", ] diff --git a/mindyolo/models/heads/yolov12_head.py b/mindyolo/models/heads/yolov12_head.py new file mode 100644 index 00000000..94d03c5b --- /dev/null +++ b/mindyolo/models/heads/yolov12_head.py @@ -0,0 +1,119 @@ +import math +import numpy as np + +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import Parameter, Tensor, nn, ops + +from ..layers import DFL, ConvNormAct, Identity, DWConv +from ..layers.utils import meshgrid + +class YOLOv12Head(nn.Cell): + # YOLOv12 Detect head for detection models + def __init__(self, nc=80, reg_max=16, stride=(), ch=(), sync_bn=False): # detection layer + super().__init__() + + assert isinstance(stride, (tuple, list)) and len(stride) > 0 + assert isinstance(ch, (tuple, list)) and len(ch) > 0 + + self.nc = nc # number of classes + self.nl = len(ch) # number of detection layers + self.reg_max = reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + self.no = nc + self.reg_max * 4 # number of outputs per anchor + self.stride = Parameter(Tensor(stride, ms.int32), requires_grad=False) + + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels + self.cv2 = nn.CellList( + [ + nn.SequentialCell( + [ + ConvNormAct(x, c2, 3, sync_bn=sync_bn), + ConvNormAct(c2, c2, 3, sync_bn=sync_bn), + nn.Conv2d(c2, 4 * self.reg_max, 1, has_bias=True), + ] + ) + for x in ch + ] + ) + self.cv3 = nn.CellList( + [ + nn.SequentialCell( + [ + nn.SequentialCell( + [ + DWConv(x, x, 3, sync_bn=sync_bn), + ConvNormAct(x, c3, 1, sync_bn=sync_bn), + ] + ), + nn.SequentialCell([ + DWConv(c3, c3, 3, sync_bn=sync_bn), + ConvNormAct(c3, c3, 1, sync_bn=sync_bn), + ] + ), + nn.Conv2d(c3, self.nc, 1, has_bias=True) + ] + ) + for x in ch + ] + ) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else Identity() + + def construct(self, x): + # Performs forward pass of the YOLOv12Head module. Returns predicted bounding boxes and class probabilities + shape = x[0].shape # BCHW + out = () + for i in range(self.nl): + out += (ops.concat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1),) + + p = None + if not self.training: + _anchors, _strides = self.make_anchors(out, self.stride, 0.5) + _anchors, _strides = _anchors.swapaxes(0, 1), _strides.swapaxes(0, 1) + _x = () + for i in range(len(out)): + _x += (out[i].view(shape[0], self.no, -1),) + _x = ops.concat(_x, 2) + box, cls = _x[:, : self.reg_max * 4, :], _x[:, self.reg_max * 4 : self.reg_max * 4 + self.nc, :] + # box, cls = ops.concat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1) + dbox = self.dist2bbox(self.dfl(box), ops.expand_dims(_anchors, 0), xywh=True, axis=1) * _strides + p = ops.concat((dbox, ops.Sigmoid()(cls)), 1) + p = ops.transpose(p, (0, 2, 1)) # (bs, no-84, nbox) -> (bs, nbox, no-84) + + return out if self.training else (p, out) + + @staticmethod + def make_anchors(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = (), () + dtype = feats[0].dtype + for i, stride in enumerate(strides): + _, _, h, w = feats[i].shape + sx = mnp.arange(w, dtype=dtype) + grid_cell_offset # shift x + sy = mnp.arange(h, dtype=dtype) + grid_cell_offset # shift y + # FIXME: Not supported on a specific model of machine + sy, sx = meshgrid((sy, sx), indexing="ij") + anchor_points += (ops.stack((sx, sy), -1).view(-1, 2),) + stride_tensor += (ops.ones((h * w, 1), dtype) * stride,) + return ops.concat(anchor_points), ops.concat(stride_tensor) + + @staticmethod + def dist2bbox(distance, anchor_points, xywh=True, axis=-1): + """Transform distance(ltrb) to box(xywh or xyxy).""" + lt, rb = ops.split(distance, split_size_or_sections=2, axis=axis) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return ops.concat((c_xy, wh), axis) # xywh bbox + return ops.concat((x1y1, x2y2), axis) # xyxy bbox + + def initialize_biases(self): + # Initialize Detect() biases, WARNING: requires stride availability + m = self + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + s = s.asnumpy() + a[-1].bias = ops.assign(a[-1].bias, Tensor(np.ones(a[-1].bias.shape), ms.float32)) + b_np = b[-1].bias.data.asnumpy() + b_np[: m.nc] = math.log(5 / m.nc / (640 / int(s)) ** 2) + b[-1].bias = ops.assign(b[-1].bias, Tensor(b_np, ms.float32)) \ No newline at end of file diff --git a/mindyolo/models/initializer.py b/mindyolo/models/initializer.py index 39f42f35..59caf752 100644 --- a/mindyolo/models/initializer.py +++ b/mindyolo/models/initializer.py @@ -1,9 +1,9 @@ import math -from mindspore import nn +from mindspore import nn, Parameter from mindspore.common import initializer as init -__all__ = ["initialize_defult"] +__all__ = ["initialize_defult", "trunc_normal_"] def initialize_defult(model): @@ -43,3 +43,8 @@ def _calculate_fan_in_and_fan_out(shape): fan_out = num_output_fmaps * receptive_field_size return fan_in, fan_out + +def trunc_normal_(tensor: Parameter, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0): + tensor.set_data( + init.initializer(init.TruncatedNormal(std, mean, a, b), tensor.shape, tensor.dtype) + ) \ No newline at end of file diff --git a/mindyolo/models/layers/__init__.py b/mindyolo/models/layers/__init__.py index 307b411b..d6f379ad 100644 --- a/mindyolo/models/layers/__init__.py +++ b/mindyolo/models/layers/__init__.py @@ -46,4 +46,6 @@ "SCDown", "PSA", "C2fCIB", + "C3k2", + "A2C2f", ] diff --git a/mindyolo/models/layers/bottleneck.py b/mindyolo/models/layers/bottleneck.py index c2b192ac..6eadd6f2 100644 --- a/mindyolo/models/layers/bottleneck.py +++ b/mindyolo/models/layers/bottleneck.py @@ -1,7 +1,8 @@ -from mindspore import nn, ops +from mindspore import nn, ops, Parameter +from mindspore.common.initializer import Constant, initializer from .conv import ConvNormAct, DWConvNormAct, RepConv - +from ..initializer import trunc_normal_ class Bottleneck(nn.Cell): # Standard bottleneck @@ -37,7 +38,7 @@ def construct(self, x): class C3(nn.Cell): # CSP Bottleneck with 3 convolutions - def __init__(self, c1, c2, n=1, shortcut=True, e=0.5, momentum=0.97, eps=1e-3, sync_bn=False): + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, momentum=0.97, eps=1e-3, sync_bn=False): super(C3, self).__init__() c_ = int(c2 * e) # hidden channels self.conv1 = ConvNormAct(c1, c_, 1, 1, momentum=momentum, eps=eps, sync_bn=sync_bn) @@ -45,7 +46,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, e=0.5, momentum=0.97, eps=1e-3, s self.conv3 = ConvNormAct(2 * c_, c2, 1, momentum=momentum, eps=eps, sync_bn=sync_bn) # act=FReLU(c2) self.m = nn.SequentialCell( [ - Bottleneck(c_, c_, shortcut, k=(1, 3), e=1.0, momentum=momentum, eps=eps, sync_bn=sync_bn) + Bottleneck(c_, c_, shortcut, k=(1, 3), g=(1, g), e=1.0, momentum=momentum, eps=eps, sync_bn=sync_bn) for _ in range(n) ] ) @@ -293,3 +294,147 @@ def __init__( ] ) +class C3k2(C2f): + # CSP Bottleneck with 2 convolutions + def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True, sync_bn=False): + # ch_in, ch_out, number, c3k, expansion, groups, shortcut, sync_bn + super().__init__(c1, c2, n, shortcut, g, e, sync_bn=sync_bn) + self.m = nn.CellList( + [ + C3k(self.c, self.c, 2, shortcut, g, sync_bn=sync_bn) if c3k else Bottleneck(self.c, self.c, shortcut, k=(3, 3), g=(1, g), sync_bn=sync_bn) + for _ in range(n) + ] + ) + +class C3k(C3): + # C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks. + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3, sync_bn=False): + super().__init__(c1, c2, n, shortcut, g, e, sync_bn=sync_bn) + c_ = int(c2 * e) # hidden channels + self.m = nn.SequentialCell( + [ + Bottleneck(c_, c_, shortcut, k=(k, k), g=(1, g), e=1.0, sync_bn=sync_bn) + for _ in range(n) + ] + ) + +class AAttn(nn.Cell): + """ + Area-attention module for YOLO models, providing efficient attention mechanisms. + """ + def __init__(self, dim, num_heads, area=1): + """ + Initializes an Area-attention module for YOLO models. + + Args: + dim (int): Number of hidden channels; + num_heads (int): Number of heads into which the attention mechanism is divided; + area (int, optional): Number of areas the feature map is divided. Defaults to 1. + """ + super().__init__() + self.area = area + + self.num_heads = num_heads + self.head_dim = head_dim = dim // num_heads + all_head_dim = head_dim * self.num_heads + + self.qkv = ConvNormAct(dim, all_head_dim * 3, k=1, act=False) + self.proj = ConvNormAct(all_head_dim, dim, k=1, act=False) + self.pe = ConvNormAct(all_head_dim, dim, k=7, s=1, p=3, g=dim, act=False) + + def construct(self, x): + B, C, H, W = x.shape + N = H * W + + qkv = ops.transpose(ops.flatten(self.qkv(x), start_dim=2), (0, 2, 1)) + if self.area > 1: + qkv = qkv.reshape(B * self.area, N // self.area, C * 3) + B, N, _ = qkv.shape + + q, k, v = ops.transpose(qkv.view(B, N, self.num_heads, self.head_dim * 3), (0, 2, 3, 1)).split([self.head_dim, self.head_dim, self.head_dim], axis=2) + attn = ( + (ops.transpose(q, (0, 1, 3, 2)) @ k) * (self.head_dim**-0.5) + ) + attn = ops.softmax(attn, -1) + x = v @ ops.transpose(attn, (0, 1, 3, 2)) + x = ops.transpose(x, (0, 3, 1, 2)) + v = ops.transpose(v, (0, 3, 1, 2)) + + if self.area > 1: + x = x.reshape(B // self.area, N * self.area, C) + v = v.reshape(B // self.area, N * self.area, C) + B, N, _ = x.shape + + x = ops.transpose(x.reshape(B, H, W, C), (0, 3, 1, 2)) + v = ops.transpose(v.reshape(B, H, W, C), (0, 3, 1, 2)) + + x = x + self.pe(v) + return self.proj(x) + +class ABlock(nn.Cell): + """ + Area-attention block module for efficient feature extraction in YOLO models. + + This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps. + It uses a novel area-based attention approach that is more efficient than traditional self-attention while + maintaining effectiveness. + """ + def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1): + super().__init__() + + self.attn = AAttn(dim, num_heads=num_heads, area=area) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.SequentialCell(ConvNormAct(dim, mlp_hidden_dim, k=1), ConvNormAct(mlp_hidden_dim, dim, k=1, act=False)) + + def _init_weights(self, m): + # Initialize weights using a truncated normal distribution + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + m.bias.set_data(initializer(Constant(0), shape=m.bias.shape, dtype=m.bias.dtype)) + + def construct(self, x): + # Performs a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor. + x = x + self.attn(x) + return x + self.mlp(x) + +class A2C2f(nn.Cell): + """ + Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms. + This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature + processing. It supports both area-attention and standard convolution modes. + """ + def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True): + super().__init__() + c_ = int(c2 * e) # hidden channels + assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32." + num_heads = c_ // 32 + + self.cv1 = ConvNormAct(c1, c_, k=1, s=1) + self.cv2 = ConvNormAct((1 + n) * c_, c2, k=1) + + init_values = 0.01 + self.gamma = Parameter(init_values * ops.ones((c2)), requires_grad=True) if a2 and residual else None + + self.m = nn.CellList( + [ + nn.SequentialCell( + [ + ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2) + ] + ) if a2 else C3k(c_, c_, 2, shortcut, g) + for _ in range(n) + ]) + + def construct(self, x): + # Performs a forward pass through R-ELAN layer. + x1 = self.cv1(x) + y = (x1, ) + for i in range(len(self.m)): + m = self.m[i] + out = m(y[-1]) + y += (out,) + y = self.cv2(ops.concat(y, 1)) + if self.gamma is not None: + return x + self.gamma.view(1, -1, 1, 1) * y + return y \ No newline at end of file diff --git a/mindyolo/models/layers/conv.py b/mindyolo/models/layers/conv.py index 7c1bb683..9c5bf676 100644 --- a/mindyolo/models/layers/conv.py +++ b/mindyolo/models/layers/conv.py @@ -1,3 +1,5 @@ +import math + from mindspore import nn, ops from .common import Identity @@ -262,3 +264,9 @@ def construct(self, x): x2 = self.max_pool2d(x2) x2 = self.cv2(x2) return ops.cat((x1, x2), 1) + +class DWConv(ConvNormAct): + """Depth-wise convolution.""" + def __init__(self, c1, c2, k=1, s=1, d=1, act=True, sync_bn=False): # ch_in, ch_out, kernel, stride, dilation, activation + """Initialize Depth-wise convolution with given parameters.""" + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act, sync_bn=sync_bn) \ No newline at end of file diff --git a/mindyolo/models/losses/__init__.py b/mindyolo/models/losses/__init__.py index c64c2ef1..0d2bad2b 100644 --- a/mindyolo/models/losses/__init__.py +++ b/mindyolo/models/losses/__init__.py @@ -1,5 +1,5 @@ from . import (loss_factory, yolov3_loss, yolov4_loss, yolov5_loss, - yolov7_loss, yolov8_loss, yolov9_loss, yolov10_loss) + yolov7_loss, yolov8_loss, yolov9_loss, yolov10_loss, yolov12_loss) from .loss_factory import * from .yolov3_loss import * from .yolov4_loss import * @@ -8,6 +8,7 @@ from .yolov8_loss import * from .yolov9_loss import * from .yolov10_loss import * +from .yolov12_loss import * from .yolox_loss import * __all__ = [] @@ -18,4 +19,5 @@ __all__.extend(yolov8_loss.__all__) __all__.extend(yolov9_loss.__all__) __all__.extend(yolov10_loss.__all__) +__all__.extend(yolov12_loss.__all__) __all__.extend(loss_factory.__all__) diff --git a/mindyolo/models/losses/yolov12_loss.py b/mindyolo/models/losses/yolov12_loss.py new file mode 100644 index 00000000..29536e4d --- /dev/null +++ b/mindyolo/models/losses/yolov12_loss.py @@ -0,0 +1,18 @@ +from mindspore import nn, ops + +from .yolov8_loss import YOLOv8Loss +from mindyolo.models.registry import register_model + +__all__ = ["YOLOv12Loss"] + +@register_model +class YOLOv12Loss(nn.Cell): + def __init__(self, box, cls, dfl, stride, nc, reg_max=16, tal_topk=10, **kwargs): + super().__init__() + self.loss = YOLOv8Loss(box, cls, dfl, stride, nc, reg_max=reg_max, tal_topk=tal_topk, **kwargs) + # branch name returned by lossitem for print + self.loss_item_name = ["loss", "lbox", "lcls", "dfl"] + + def construct(self, feats, targets, imgs): + # YOLOV12 Loss + return self.loss(feats, targets, imgs) \ No newline at end of file diff --git a/mindyolo/models/model_factory.py b/mindyolo/models/model_factory.py index 2092acd0..07e79602 100644 --- a/mindyolo/models/model_factory.py +++ b/mindyolo/models/model_factory.py @@ -155,6 +155,8 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) SCDown, PSA, C2fCIB, + C3k2, + A2C2f, ): c1, c2 = ch[f], args[0] if max_channels: @@ -182,9 +184,13 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) ADown ): kwargs["sync_bn"] = sync_bn - if m in (DownC, SPPCSPC, C3, C2f, DWC3, C2fCIB): + if m in (DownC, SPPCSPC, C3, C2f, DWC3, C2fCIB, C3k2, A2C2f): args.insert(2, n) # number of repeats n = 1 + if m is C3k2 and d.get("scale") in "mlx": + args[3] = True + if m is A2C2f and d.get("scale") in "lx": + args.extend((True, 1.2)) elif m in (nn.BatchNorm2d, nn.SyncBatchNorm): args = [ch[f]] elif m in (Concat,): @@ -195,7 +201,7 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) args.append([ch[x] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) - elif m in (YOLOv10Head, YOLOv9Head, YOLOv8Head, YOLOv8SegHead, YOLOXHead): # head of anchor free + elif m in (YOLOv12Head, YOLOv10Head, YOLOv9Head, YOLOv8Head, YOLOv8SegHead, YOLOXHead): # head of anchor free args.append([ch[x] for x in f]) if m in (YOLOv8SegHead,): args[3] = math.ceil(min(args[3], max_channels) * gw / 8) * 8 diff --git a/mindyolo/models/yolov12.py b/mindyolo/models/yolov12.py new file mode 100644 index 00000000..a8c65ef2 --- /dev/null +++ b/mindyolo/models/yolov12.py @@ -0,0 +1,59 @@ +import numpy as np + +import mindspore as ms +from mindspore import Tensor, nn + +from mindyolo.models.layers.bottleneck import A2C2f, ABlock, C3k +from mindyolo.models.heads.yolov12_head import YOLOv12Head +from mindyolo.models.model_factory import build_model_from_cfg +from mindyolo.models.registry import register_model + +__all__ = ["YOLOv12", "yolov12"] + + +def _cfg(url="", **kwargs): + return {"url": url, **kwargs} + + +default_cfgs = {"yolov12": _cfg(url="")} + + +class YOLOv12(nn.Cell): + def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False): + super(YOLOv12, self).__init__() + self.cfg = cfg + self.stride = Tensor(np.array(cfg.stride), ms.int32) + self.stride_max = int(max(self.cfg.stride)) + ch, nc = in_channels, num_classes + + self.nc = nc # override yaml value + self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) + self.names = [str(i) for i in range(nc)] # default names + + self.initialize_weights() + + def construct(self, x): + return self.model(x) + + def initialize_weights(self): + # reset parameter for Detect Head + for i in range(len(self.model.model)-1): + m = self.model.model[i] + if isinstance(m, A2C2f) and isinstance(m.m[0], C3k): + pass + elif isinstance(m, A2C2f) and isinstance(m.m[0][0], ABlock): + for i in range(len(m.m)): + for j in range(len(m.m[0])): + m.m[i][j].apply(m.m[i][j]._init_weights) + + m = self.model.model[-1] + if isinstance(m, YOLOv12Head): + m.initialize_biases() + m.dfl.initialize_conv_weight() + + +@register_model +def yolov12(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv12: + """Get yolov12 model.""" + model = YOLOv12(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) + return model