diff --git a/configs/co_detr/README.md b/configs/co_detr/README.md
new file mode 100644
index 0000000000..2dc370f3ae
--- /dev/null
+++ b/configs/co_detr/README.md
@@ -0,0 +1,39 @@
+# Co-DETR
+
+## Introduction
+
+
+Co-DETR is a transformer-based object detector built on DETR: it trains the DETR query head together with collaborative auxiliary heads (an RPN head, an RoI head and an ATSS head) to provide additional supervision during training. We reproduce the model from the paper.
+
+
+## Model Zoo
+
+| Backbone | Model | Images/GPU  | Inf time (fps) | Box AP | Config | Download |
+|:------:|:--------:|:--------:|:--------------:|:------:|:------:|:--------:|
+| R-50 | Co-DETR  | 4 | --- | 42.3 | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/co_detr/co_detr_r50_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams) |
+
+**Notes:**
+
+- Co-DETR is trained on the COCO train2017 dataset and evaluated on val2017; the reported metric is `mAP(IoU=0.5:0.95)`.
+- Co-DETR is trained for 500 epochs on 8 GPUs.
+
+Multi-GPU training:
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/co_detr/co_detr_r50_1x_coco.yml --fleet
+```
+
+## Citations
+```
+@inproceedings{detr,
+  author    = {Nicolas Carion and
+               Francisco Massa and
+               Gabriel Synnaeve and
+               Nicolas Usunier and
+               Alexander Kirillov and
+               Sergey Zagoruyko},
+  title     = {End-to-End Object Detection with Transformers},
+  booktitle = {ECCV},
+  year      = {2020}
+}
+```
diff --git a/configs/co_detr/_base_/co_detr_r50.yml b/configs/co_detr/_base_/co_detr_r50.yml
new file mode 100644
index 0000000000..d4a3721e40
--- /dev/null
+++ b/configs/co_detr/_base_/co_detr_r50.yml
@@ -0,0 +1,187 @@
+architecture: CO_DETR
+# pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vb_normal_pretrained.pdparams
+pretrain_weights: /home/aistudio/co_deformable_detr_r50_1x_coco.pdparams
+
+# model settings
+num_dec_layer: &num_dec_layer 6
+lambda_2: &lambda_2 2.0
+
+CO_DETR:
+  backbone: ResNet
+  neck: ChannelMapper
+  query_head: CoDeformDETRHead
+  rpn_head: RPNHead
+  roi_head: Co_RoiHead
+  bbox_head:
+    name: CoATSSHead
+    num_classes: 80
+    in_channels: 256
+    stacked_convs: 1
+    feat_channels: 256
+    anchor_generator: 
+      name: CoAnchorGenerator
+      octave_base_scale: 8
+      scales_per_octave: 1
+      aspect_ratios: [1.0]
+      strides: [8, 16, 32, 64, 128]
+    assigner: 
+      name: ATSSAssigner
+      topk: 9
+    loss_cls: 
+      name: Weighted_FocalLoss
+      use_sigmoid: true
+      gamma: 2.0
+      alpha: 0.25
+    loss_bbox: 
+      name: GIoULoss
+
+
+ResNet:
+  # index 0 stands for res2
+  depth: 50
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1,2,3]
+  lr_mult_list: [0.0, 0.1, 0.1, 0.1]
+  num_stages: 4
+
+ChannelMapper:
+  in_channels: [512, 1024, 2048]
+  kernel_size: 1
+  out_channels: 256
+  norm_type: "gn"
+  norm_groups: 32
+  act: None
+  num_outs: 4
+  
+
+CoDeformDETRHead:
+  num_query: 300
+  num_classes: 80
+  in_channels: 2048
+  sync_cls_avg_factor: True
+  with_box_refine: True
+  as_two_stage: True
+  mixed_selection: True
+  transformer:
+    name: CoDeformableDetrTransformer
+    num_co_heads: 2
+    as_two_stage: True
+    mixed_selection: True
+    encoder:
+      name: CoTransformerEncoder
+      num_layers: *num_dec_layer
+      out_channel: 256
+      encoder_layer:
+        name: TransformerEncoderLayer
+        d_model: 256
+        attn:
+          name: MSDeformableAttention
+          embed_dim: 256
+          num_heads: 8
+          num_levels: 4
+          num_points: 4
+        dim_feedforward: 2048
+        dropout: 0.0
+    decoder:
+      name: CoDeformableDetrTransformerDecoder
+      num_layers: *num_dec_layer
+      return_intermediate: True
+      look_forward_twice: True
+      decoder_layer:
+        name: PETR_TransformerDecoderLayer
+        d_model: 256
+        dim_feedforward: 2048
+        dropout: 0.0
+        self_attn:
+          name: MultiHeadAttention
+          embed_dim: 256
+          num_heads: 8
+          dropout: 0.0
+        cross_attn:
+          name: MSDeformableAttention
+          embed_dim: 256
+  positional_encoding:
+    name: PositionEmbedding
+    num_pos_feats: 128
+    normalize: true
+    offset: -0.5
+  loss_cls:
+    name: Weighted_FocalLoss
+    use_sigmoid: true
+    gamma: 2.0
+    alpha: 0.25
+    loss_weight: 2.0
+  loss_bbox:
+    name: L1Loss
+    loss_weight: 5.0
+  loss_iou:
+    name: GIoULoss
+    loss_weight: 2.0
+  assigner:
+    name: HungarianAssigner
+    cls_cost:
+      name: FocalLossCost
+      weight: 2.0
+    reg_cost:
+      name: BBoxL1Cost
+      weight: 5.0
+      box_format: xywh
+    iou_cost:
+      name: IoUCost
+      iou_mode: giou
+      weight: 2.0
+  test_cfg:
+    max_per_img: 100
+    score_thr: 0.0
+    nms: false
+  nms: 
+    name: MultiClassNMS
+    keep_top_k: 100
+    score_threshold: 0.05
+    nms_threshold: 0.6
+
+RPNHead:
+  loss_rpn_bbox: L1Loss
+  in_channel: 256
+  anchor_generator: 
+    name: RetinaAnchorGenerator
+    octave_base_scale: 4
+    scales_per_octave: 3
+    aspect_ratios: [0.5, 1.0, 2.0]
+    strides: [8.0, 16.0, 32.0, 64.0, 128.0]
+  rpn_target_assign:
+    batch_size_per_im: 256
+    fg_fraction: 0.5
+    negative_overlap: 0.3
+    positive_overlap: 0.7
+    use_random: True
+  train_proposal:
+    min_size: 0.0
+    nms_thresh: 0.7
+    pre_nms_top_n: 4000
+    post_nms_top_n: 1000
+    topk_after_collect: True
+  test_proposal:
+    min_size: 0.0
+    nms_thresh: 0.7
+    pre_nms_top_n: 1000
+    post_nms_top_n: 1000
+
+Co_RoiHead:
+  in_channel: 256
+  num_classes: 80
+  head: TwoFCHead
+  roi_extractor:
+    resolution: 7
+    sampling_ratio: 0
+    aligned: True
+  bbox_assigner: 
+    name: BBoxAssigner
+    batch_size_per_im: 512
+    bg_thresh: 0.5
+    fg_thresh: 0.5
+    fg_fraction: 0.25
+    use_random: True
+  bbox_loss: 
+    name: GIoULoss
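
Each `name:` key above selects a class registered in PaddleDetection's workspace, and its sibling keys become constructor arguments. As a minimal sketch of how the composed config turns into a model (assuming the standard `ppdet` workspace API; the config path is the one added by this patch):

```python
from ppdet.core.workspace import load_config, create

# Parse the YAML, resolving its _BASE_ includes into one global config.
cfg = load_config('configs/co_detr/co_detr_r50_1x_coco.yml')

# cfg.architecture is 'CO_DETR'; create() recursively instantiates the
# backbone, neck, query head, RPN head and RoI head declared above.
model = create(cfg.architecture)
```
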
diff --git a/configs/co_detr/_base_/co_detr_reader.yml b/configs/co_detr/_base_/co_detr_reader.yml
new file mode 100644
index 0000000000..6f10ab454d
--- /dev/null
+++ b/configs/co_detr/_base_/co_detr_reader.yml
@@ -0,0 +1,47 @@
+worker_num: 0
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomFlip: {prob: 0.5}
+  - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
+                    transforms2: [
+                        RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
+                        RandomSizeCrop: { min_size: 384, max_size: 600 },
+                        RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ] }
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - NormalizeBox: {}
+  - BboxXYXY2XYWH: {}
+  - Permute: {}
+  batch_transforms:
+  - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
+  batch_size: 2
+  shuffle: false
+  drop_last: true
+  collate_batch: false
+  use_shared_memory: false
+
+
+EvalReader:
+  sample_transforms:
+    - Decode: {}
+    # - PETR_Resize: {img_scale: [[800, 1333]], keep_ratio: True}
+    - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+    - NormalizeImage:
+        mean: [0.485,0.456,0.406]
+        std: [0.229, 0.224,0.225]
+        is_scale: true
+    - Permute: {}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+
+TestReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {target_size: [800, 1333], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
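
The `RandomSelect` op above implements DETR-style multi-scale augmentation: each sample is either resized directly, or resized, randomly cropped, then resized again. A plain-Python sketch of the selection logic (hedged: the real op composes registered transform objects, and the 0.5 selection probability is an assumed default):

```python
import random

def random_select(sample, transforms1, transforms2, p=0.5):
    # Apply one of the two augmentation branches, chosen at random per
    # sample, mirroring the TrainReader's RandomSelect op.
    branch = transforms1 if random.random() < p else transforms2
    for t in branch:
        sample = t(sample)
    return sample
```
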
diff --git a/configs/co_detr/_base_/optimizer_1x.yml b/configs/co_detr/_base_/optimizer_1x.yml
new file mode 100644
index 0000000000..13528c5eba
--- /dev/null
+++ b/configs/co_detr/_base_/optimizer_1x.yml
@@ -0,0 +1,16 @@
+epoch: 500
+
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [400]
+    use_warmup: false
+
+OptimizerBuilder:
+  clip_grad_by_norm: 0.1
+  regularizer: false
+  optimizer:
+    type: AdamW
+    weight_decay: 0.0001
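
The schedule above holds the base learning rate until epoch 400 and then decays it once by `gamma`. A plain-Python rendering of the resulting curve (a sketch, not the ppdet `PiecewiseDecay` scheduler itself):

```python
def piecewise_lr(epoch, base_lr=0.0001, milestones=(400,), gamma=0.1):
    # The LR is multiplied by gamma at every milestone already passed:
    # 1e-4 for epochs [0, 400), 1e-5 from epoch 400 onwards.
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

print(piecewise_lr(10), piecewise_lr(450))  # 0.0001, then 1e-05
```
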
diff --git a/configs/co_detr/co_detr_r50_1x_coco.yml b/configs/co_detr/co_detr_r50_1x_coco.yml
new file mode 100644
index 0000000000..be03708020
--- /dev/null
+++ b/configs/co_detr/co_detr_r50_1x_coco.yml
@@ -0,0 +1,9 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/optimizer_1x.yml',
+  '_base_/co_detr_r50.yml',
+  '_base_/co_detr_reader.yml',
+]
+weights: /home/aistudio/co_deformable_detr_r50_1x_coco.pdparams
+find_unused_parameters: True
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index d22df32d85..4803b97e42 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -45,6 +45,7 @@
 from . import detr_ssod
 from . import multi_stream_detector
 from . import clrnet
+from . import co_detr
 
 from .meta_arch import *
 from .faster_rcnn import *
@@ -68,6 +69,7 @@
 from .gfl import *
 from .picodet import *
 from .detr import *
+from .co_detr import *
 from .sparse_rcnn import *
 from .tood import *
 from .retinanet import *
diff --git a/ppdet/modeling/architectures/co_detr.py b/ppdet/modeling/architectures/co_detr.py
new file mode 100644
index 0000000000..c007b73b4d
--- /dev/null
+++ b/ppdet/modeling/architectures/co_detr.py
@@ -0,0 +1,218 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import numpy as np
+from .meta_arch import BaseArch
+from ppdet.core.workspace import register, create
+
+__all__ = ['CO_DETR']
+# Collaborative DETR: the Co-Deformable-DETR and Co-DINO variants share this architecture.
+
+def bbox2result(bboxes, labels, num_classes):
+    """Convert detection results to a list of numpy arrays.
+
+    Args:
+        bboxes (paddle.Tensor | np.ndarray): shape (n, 5)
+        labels (paddle.Tensor | np.ndarray): shape (n, )
+        num_classes (int): class number, including background class
+
+    Returns:
+        list(ndarray): bbox results of each class
+    """
+    if bboxes.shape[0] == 0:
+        return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
+    else:
+        if isinstance(bboxes, paddle.Tensor):
+            bboxes = bboxes.detach().cpu().numpy()
+            labels = labels.detach().cpu().numpy()
+        return [bboxes[labels == i, :] for i in range(num_classes)]
+    
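+# Example (hypothetical inputs): with num_classes=3 and two detections
+# labelled 0 and 2, bbox2result(np.zeros((2, 5)), np.array([0, 2]), 3)
+# returns three per-class arrays with 1, 0 and 1 rows respectively.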
+
+@register
+class CO_DETR(BaseArch):
+    __category__ = 'architecture'
+    __inject__ = ['bbox_head']
+
+    def __init__(self,
+                 backbone,
+                 neck=None,
+                 query_head=None,
+                 rpn_head=None,
+                 roi_head=None,
+                 bbox_head=None,
+                 with_pos_coord=True,
+                 with_attn_mask=True,
+                 ):
+        super(CO_DETR, self).__init__()
+        self.backbone = backbone
+        # always define with_neck so extract_feat() does not hit an
+        # AttributeError when no neck is configured
+        self.with_neck = neck is not None
+        self.neck = neck
+        self.query_head = query_head
+        self.rpn_head = rpn_head
+        self.roi_head = roi_head
+        self.bbox_head = bbox_head
+        self.deploy = False
+        self.with_pos_coord = with_pos_coord
+        self.with_attn_mask = with_attn_mask
+    
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        backbone = create(cfg['backbone'])
+        kwargs = {'input_shape': backbone.out_shape}
+        neck = cfg['neck'] and create(cfg['neck'], **kwargs)
+        # out_shape = neck and neck.out_shape or backbone.out_shape
+        query_head = create(cfg['query_head'])
+        out_shape = query_head.transformer.encoder.out_shape
+        kwargs = {'input_shape': out_shape}
+        rpn_head = create(cfg['rpn_head'], **kwargs)
+        roi_head = create(cfg['roi_head'], **kwargs)
+        return {
+            'backbone': backbone,
+            'neck': neck,
+            'query_head': query_head,
+            'rpn_head': rpn_head,
+            'roi_head':roi_head,
+        }
+        
+    def extract_feat(self, img):
+        """Directly extract features from the backbone+neck."""
+        x = self.backbone(img)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+    
+    def get_inputs(self):
+        img_metas = []
+        gt_bboxes = []
+        gt_labels = []
+
+        for idx, im_shape in enumerate(self.inputs['im_shape']):
+            img_meta = {
+                'img_shape': im_shape.astype("int32").tolist() + [1, ],
+                'batch_input_shape': self.inputs['image'].shape[-2:],
+                'pad_mask': self.inputs['pad_mask'][idx],
+            }
+            img_metas.append(img_meta)
+
+            gt_labels.append(self.inputs['gt_class'][idx])
+            gt_bboxes.append(self.inputs['gt_bbox'][idx])
+
+        return img_metas, gt_bboxes, gt_labels
+    
+
+    def get_pred(self):
+        img = self.inputs['image']
+        batch_size, _, height, width = img.shape
+        img_metas = [
+            dict(
+                batch_input_shape=(height, width),
+                img_shape=(height, width, 3),
+                scale_factor=self.inputs['scale_factor'][i])
+            for i in range(batch_size)
+        ]
+        
+        x = self.extract_feat(self.inputs)
+        bbox = self.query_head.simple_test(
+            x, img_metas, rescale=True)
+        # per-image detection counts, then flatten all boxes into one tensor
+        bbox_num = paddle.to_tensor([b.shape[0] for b in bbox])
+        bbox = paddle.concat(bbox, axis=0)
+
+        return {'bbox': bbox, 'bbox_num': bbox_num}
+
+    def get_loss(self):
+        """
+        Args:
+            img (Tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
+                image in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            gt_areas (list[Tensor]): mask areas corresponding to each box.
+            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        img = self.inputs['image']
+        batch_size, _, height, width = img.shape
+        img_metas, gt_bboxes, gt_labels = self.get_inputs()
+        # self.inputs is a dict, so use .get() rather than getattr()
+        gt_bboxes_ignore = self.inputs.get('gt_bboxes_ignore', None)
+        x = self.extract_feat(self.inputs)
+        losses = dict()
+        # DETR encoder and decoder forward
+        if self.query_head is not None:
+            bbox_losses, x = self.query_head.forward_train(x, img_metas, gt_bboxes,
+                                                          gt_labels, gt_bboxes_ignore)
+            losses.update(bbox_losses)
+
+        if self.rpn_head is not None:
+            rois, rois_num, rpn_loss = self.rpn_head(x, self.inputs)
+            losses.update(rpn_loss)
+            
+        positive_coords = []
+        if self.roi_head is not None:
+            roi_losses, _ = self.roi_head(x, rois, rois_num,
+                                self.inputs)
+            if self.with_pos_coord:
+                positive_coords.append(roi_losses.pop('pos_coords'))
+            else: 
+                if 'pos_coords' in roi_losses.keys():
+                    tmp = roi_losses.pop('pos_coords')  
+            losses.update(roi_losses)
+
+        # if self.bbox_head is not None:
+        #     bbox_losses = self.bbox_head.forward_train(x,img_metas,gt_bboxes,gt_labels,)
+        #     if self.with_pos_coord:
+        #         positive_coords.append(bbox_losses.pop('pos_coords'))
+        #     else: 
+        #         if 'pos_coords' in bbox_losses.keys():
+        #             tmp = bbox_losses.pop('pos_coords')  
+        #     losses.update(bbox_losses)
+        
+        if self.with_pos_coord and len(positive_coords) > 0:
+            for i, pos_coords in enumerate(positive_coords):
+                bbox_losses = self.query_head.forward_train_aux(
+                    x, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore,
+                    pos_coords, i)
+                if bbox_losses is not None:
+                    losses.update(bbox_losses)
+
+        # reduce every loss component to a single scalar; list entries come
+        # from per-level or per-layer losses
+        total_loss = 0
+        for v in losses.values():
+            total_loss += sum(v) if isinstance(v, list) else v
+        return {'loss': total_loss}
+    
\ No newline at end of file
diff --git a/ppdet/modeling/assigners/hungarian_assigner.py b/ppdet/modeling/assigners/hungarian_assigner.py
index 154c27ce97..e08286f953 100644
--- a/ppdet/modeling/assigners/hungarian_assigner.py
+++ b/ppdet/modeling/assigners/hungarian_assigner.py
@@ -24,8 +24,9 @@
 import paddle
 
 from ppdet.core.workspace import register
+from ppdet.modeling.assigners.pose_utils import bbox_cxcywh_to_xyxy
 
-__all__ = ['PoseHungarianAssigner', 'PseudoSampler']
+__all__ = ["PoseHungarianAssigner", "PseudoSampler", "HungarianAssigner"]
 
 
 class AssignResult:
@@ -72,11 +73,11 @@ def get_extra_property(self, key):
     def info(self):
         """dict: a dictionary of info about the object"""
         basic_info = {
-            'num_gts': self.num_gts,
-            'num_preds': self.num_preds,
-            'gt_inds': self.gt_inds,
-            'max_overlaps': self.max_overlaps,
-            'labels': self.labels,
+            "num_gts": self.num_gts,
+            "num_preds": self.num_preds,
+            "gt_inds": self.gt_inds,
+            "max_overlaps": self.max_overlaps,
+            "labels": self.labels,
         }
         basic_info.update(self._extra_properties)
         return basic_info
@@ -105,24 +106,19 @@ class PoseHungarianAssigner:
         oks_weight (int | float, optional): The scale factor for regression
             oks cost. Default 1.0.
     """
-    __inject__ = ['cls_cost', 'kpt_cost', 'oks_cost']
 
-    def __init__(self,
-                 cls_cost='ClassificationCost',
-                 kpt_cost='KptL1Cost',
-                 oks_cost='OksCost'):
+    __inject__ = ["cls_cost", "kpt_cost", "oks_cost"]
+
+    def __init__(
+        self, cls_cost="ClassificationCost", kpt_cost="KptL1Cost", oks_cost="OksCost"
+    ):
         self.cls_cost = cls_cost
         self.kpt_cost = kpt_cost
         self.oks_cost = oks_cost
 
-    def assign(self,
-               cls_pred,
-               kpt_pred,
-               gt_labels,
-               gt_keypoints,
-               gt_areas,
-               img_meta,
-               eps=1e-7):
+    def assign(
+        self, cls_pred, kpt_pred, gt_labels, gt_keypoints, gt_areas, img_meta, eps=1e-7
+    ):
         """Computes one-to-one matching based on the weighted costs.
 
         This method assign each query prediction to a ground truth or
@@ -157,52 +153,50 @@ def assign(self,
             :obj:`AssignResult`: The assigned result.
         """
         num_gts, num_kpts = gt_keypoints.shape[0], kpt_pred.shape[0]
-        if not gt_keypoints.astype('bool').any():
+        if not gt_keypoints.astype("bool").any():
             num_gts = 0
 
         # 1. assign -1 by default
-        assigned_gt_inds = paddle.full((num_kpts, ), -1, dtype="int64")
-        assigned_labels = paddle.full((num_kpts, ), -1, dtype="int64")
+        assigned_gt_inds = paddle.full((num_kpts,), -1, dtype="int64")
+        assigned_labels = paddle.full((num_kpts,), -1, dtype="int64")
         if num_gts == 0 or num_kpts == 0:
             # No ground truth or keypoints, return empty assignment
             if num_gts == 0:
                 # No ground truth, assign all to background
                 assigned_gt_inds[:] = 0
-            return AssignResult(
-                num_gts, assigned_gt_inds, None, labels=assigned_labels)
-        img_h, img_w, _ = img_meta['img_shape']
+            return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
+        img_h, img_w, _ = img_meta["img_shape"]
         factor = paddle.to_tensor(
-            [img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype).reshape(
-                (1, -1))
+            [img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype
+        ).reshape((1, -1))
 
         # 2. compute the weighted costs
         # classification cost
         cls_cost = self.cls_cost(cls_pred, gt_labels)
 
         # keypoint regression L1 cost
-        gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1,
-                                                     3))
+        gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1, 3))
         valid_kpt_flag = gt_keypoints_reshape[..., -1]
-        kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1,
-                                                          2))
-        normalize_gt_keypoints = gt_keypoints_reshape[
-            ..., :2] / factor[:, :2].unsqueeze(0)
-        kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints,
-                                 valid_kpt_flag)
+        kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, 2))
+        normalize_gt_keypoints = gt_keypoints_reshape[..., :2] / factor[
+            :, :2
+        ].unsqueeze(0)
+        kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints, valid_kpt_flag)
         # keypoint OKS cost
-        kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1,
-                                                          2))
+        kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, 2))
         kpt_pred_tmp = kpt_pred_tmp * factor[:, :2].unsqueeze(0)
-        oks_cost = self.oks_cost(kpt_pred_tmp, gt_keypoints_reshape[..., :2],
-                                 valid_kpt_flag, gt_areas)
+        oks_cost = self.oks_cost(
+            kpt_pred_tmp, gt_keypoints_reshape[..., :2], valid_kpt_flag, gt_areas
+        )
         # weighted sum of above three costs
         cost = cls_cost + kpt_cost + oks_cost
 
         # 3. do Hungarian matching on CPU using linear_sum_assignment
         cost = cost.detach().cpu()
         if linear_sum_assignment is None:
-            raise ImportError('Please run "pip install scipy" '
-                              'to install scipy first.')
+            raise ImportError(
+                'Please run "pip install scipy" ' "to install scipy first."
+            )
         matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
         matched_row_inds = paddle.to_tensor(matched_row_inds)
         matched_col_inds = paddle.to_tensor(matched_col_inds)
@@ -212,20 +206,19 @@ def assign(self,
         assigned_gt_inds[:] = 0
         # assign foregrounds based on matching results
         assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
-        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][
-            ..., 0].astype("int64")
-        return AssignResult(
-            num_gts, assigned_gt_inds, None, labels=assigned_labels)
+        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][..., 0].astype(
+            "int64"
+        )
+        return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
 
 
 class SamplingResult:
-    """Bbox sampling result.
-    """
+    """Bbox sampling result."""
 
-    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
-                 gt_flags):
+    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
         self.pos_inds = pos_inds
         self.neg_inds = neg_inds
+
         if pos_inds.size > 0:
             self.pos_bboxes = bboxes[pos_inds]
             self.neg_bboxes = bboxes[neg_inds]
@@ -238,15 +231,15 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                 # hack for index error case
                 assert self.pos_assigned_gt_inds.numel() == 0
                 self.pos_gt_bboxes = paddle.zeros(
-                    gt_bboxes.shape, dtype=gt_bboxes.dtype).reshape((-1, 4))
+                    gt_bboxes.shape, dtype=gt_bboxes.dtype
+                ).reshape((-1, 4))
             else:
                 if len(gt_bboxes.shape) < 2:
                     gt_bboxes = gt_bboxes.reshape((-1, 4))
 
                 self.pos_gt_bboxes = paddle.index_select(
-                    gt_bboxes,
-                    self.pos_assigned_gt_inds.astype('int64'),
-                    axis=0)
+                    gt_bboxes, self.pos_assigned_gt_inds.astype("int64"), axis=0
+                )
 
             if assign_result.labels is not None:
                 self.pos_gt_labels = assign_result.labels[pos_inds]
@@ -260,23 +253,23 @@ def bboxes(self):
 
     def __nice__(self):
         data = self.info.copy()
-        data['pos_bboxes'] = data.pop('pos_bboxes').shape
-        data['neg_bboxes'] = data.pop('neg_bboxes').shape
+        data["pos_bboxes"] = data.pop("pos_bboxes").shape
+        data["neg_bboxes"] = data.pop("neg_bboxes").shape
         parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
-        body = '    ' + ',\n    '.join(parts)
-        return '{\n' + body + '\n}'
+        body = "    " + ",\n    ".join(parts)
+        return "{\n" + body + "\n}"
 
     @property
     def info(self):
         """Returns a dictionary of info about the object."""
         return {
-            'pos_inds': self.pos_inds,
-            'neg_inds': self.neg_inds,
-            'pos_bboxes': self.pos_bboxes,
-            'neg_bboxes': self.neg_bboxes,
-            'pos_is_gt': self.pos_is_gt,
-            'num_gts': self.num_gts,
-            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+            "pos_inds": self.pos_inds,
+            "neg_inds": self.neg_inds,
+            "pos_bboxes": self.pos_bboxes,
+            "neg_bboxes": self.neg_bboxes,
+            "pos_is_gt": self.pos_is_gt,
+            "num_gts": self.num_gts,
+            "pos_assigned_gt_inds": self.pos_assigned_gt_inds,
         }
 
 
@@ -306,11 +299,146 @@ def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs):
         Returns:
             :obj:`SamplingResult`: sampler results
         """
-        pos_inds = paddle.nonzero(
-            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1)
-        neg_inds = paddle.nonzero(
-            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1)
-        gt_flags = paddle.zeros([bboxes.shape[0]], dtype='int32')
-        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
-                                         assign_result, gt_flags)
+
+        pos_inds = paddle.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1)
+        neg_inds = paddle.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(
+            -1
+        )
+        gt_flags = paddle.zeros([bboxes.shape[0]], dtype="int32")
+        sampling_result = SamplingResult(
+            pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags
+        )
         return sampling_result
+
+
+@register
+class HungarianAssigner:
+    """Computes one-to-one matching between predictions and ground truth.
+
+    This class computes an assignment between the targets and the predictions
+    based on the costs. The costs are weighted sum of three components:
+    classification cost, regression L1 cost and regression iou cost. The
+    targets don't include the no_object, so generally there are more
+    predictions than targets. After the one-to-one matching, the un-matched
+    are treated as backgrounds. Thus each query prediction will be assigned
+    with `0` or a positive integer indicating the ground truth index:
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        cls_cost (object, optional): Injected classification cost.
+            Default 'ClassificationCost'.
+        reg_cost (object, optional): Injected regression L1 cost.
+            Default 'BBoxL1Cost'.
+        iou_cost (object, optional): Injected IoU cost; its `iou_mode` may be
+            "iou" (intersection over union), "iof" (intersection over
+            foreground) or "giou" (generalized intersection over union).
+            Default 'IoUCost'.
+
+    __inject__ = ["cls_cost", "reg_cost", "iou_cost"]
+
+    def __init__(
+        self, cls_cost="ClassificationCost", reg_cost="BBoxL1Cost", iou_cost="IoUCost"
+    ):
+        self.cls_cost = cls_cost
+        self.reg_cost = reg_cost
+        self.iou_cost = iou_cost
+
+    def assign(
+        self,
+        bbox_pred,
+        cls_pred,
+        gt_bboxes,
+        gt_labels,
+        img_meta,
+        gt_bboxes_ignore=None,
+        eps=1e-7,
+    ):
+        """Computes one-to-one matching based on the weighted costs.
+
+        This method assign each query prediction to a ground truth or
+        background. The `assigned_gt_inds` with -1 means don't care,
+        0 means negative sample, and positive number is the index (1-based)
+        of assigned gt.
+        The assignment is done in the following steps, the order matters.
+
+        1. assign every prediction to -1
+        2. compute the weighted costs
+        3. do Hungarian matching on CPU based on the costs
+        4. assign all to 0 (background) first, then for each matched pair
+           between predictions and gts, treat this prediction as foreground
+           and assign the corresponding gt index (plus 1) to it.
+
+        Args:
+            bbox_pred (Tensor): Predicted boxes with normalized coordinates
+                (cx, cy, w, h), which are all in range [0, 1]. Shape
+                [num_query, 4].
+            cls_pred (Tensor): Predicted classification logits, shape
+                [num_query, num_class].
+            gt_bboxes (Tensor): Ground truth boxes with unnormalized
+                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+            img_meta (dict): Meta information for current image.
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`. Default None.
+            eps (int | float, optional): A value added to the denominator for
+                numerical stability. Default 1e-7.
+
+        Returns:
+            :obj:`AssignResult`: The assigned result.
+        """
+        assert (
+            gt_bboxes_ignore is None
+        ), "Only case when gt_bboxes_ignore is None is supported."
+        num_gts, num_bboxes = gt_bboxes.shape[0], bbox_pred.shape[0]
+
+        # 1. assign -1 by default
+        assigned_gt_inds = paddle.full((num_bboxes,), -1, dtype="int64")
+        assigned_labels = paddle.full((num_bboxes,), -1, dtype="int64")
+        if num_gts == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            if num_gts == 0:
+                # No ground truth, assign all to background
+                assigned_gt_inds[:] = 0
+            return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
+        img_h, img_w, _ = img_meta["img_shape"]
+        factor = paddle.to_tensor(
+            [img_w, img_h, img_w, img_h], dtype=gt_bboxes.dtype
+        ).unsqueeze(0)
+
+        # 2. compute the weighted costs
+        # classification cost
+        cls_cost = self.cls_cost(cls_pred, gt_labels)
+        # regression L1 cost
+        normalize_gt_bboxes = gt_bboxes / factor
+        reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
+        # regression IoU cost; GIoU is used by default, following official DETR.
+        bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
+        iou_cost = self.iou_cost(bboxes, gt_bboxes)
+        # weighted sum of above three costs
+        cost = cls_cost + reg_cost + iou_cost
+
+        # 3. do Hungarian matching on CPU using linear_sum_assignment
+        cost = cost.detach().cpu()
+        if linear_sum_assignment is None:
+            raise ImportError(
+                'Please run "pip install scipy" to install scipy first.'
+            )
+        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+        matched_row_inds = paddle.to_tensor(matched_row_inds)
+        matched_col_inds = paddle.to_tensor(matched_col_inds)
+
+        # 4. assign backgrounds and foregrounds
+        # assign all indices to backgrounds first
+        assigned_gt_inds[:] = 0
+        # assign foregrounds based on matching results
+        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][..., 0].astype(
+            "int64"
+        )
+
+        return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
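
Both assigners share the same output convention: after matching, `0` marks background and `i + 1` marks a match with the i-th ground truth. A toy demonstration of that convention using SciPy directly (the cost values are made up):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# 4 query predictions x 2 ground-truth boxes; lower cost = better match.
cost = np.array([[0.9, 0.1],
                 [0.4, 0.6],
                 [0.2, 0.8],
                 [0.7, 0.3]])
row, col = linear_sum_assignment(cost)  # optimal one-to-one matching

assigned = np.zeros(cost.shape[0], dtype=np.int64)  # 0 = background
assigned[row] = col + 1                             # 1-based gt indices
print(assigned)  # [2 0 1 0]: query 0 -> gt 1, query 2 -> gt 0
```
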
diff --git a/ppdet/modeling/assigners/pose_utils.py b/ppdet/modeling/assigners/pose_utils.py
index 313215a4dd..422b92eb3a 100644
--- a/ppdet/modeling/assigners/pose_utils.py
+++ b/ppdet/modeling/assigners/pose_utils.py
@@ -21,8 +21,10 @@
 import paddle.nn.functional as F
 
 from ppdet.core.workspace import register
+from ppdet.data.transform.atss_assigner import bbox_overlaps
+from ppdet.modeling.transformers.utils import bbox_xyxy_to_cxcywh
 
-__all__ = ['KptL1Cost', 'OksCost', 'ClassificationCost']
+__all__ = ["KptL1Cost", "OksCost", "ClassificationCost", "BBoxL1Cost", "IoUCost"]
 
 
 def masked_fill(x, mask, value):
@@ -63,17 +65,18 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag):
                 kpt_cost.append(kpt_pred.sum() * 0)
             kpt_pred_tmp = kpt_pred.clone()
             valid_flag = valid_kpt_flag[i] > 0
-            valid_flag_expand = valid_flag.unsqueeze(0).unsqueeze(-1).expand_as(
-                kpt_pred_tmp)
+            valid_flag_expand = (
+                valid_flag.unsqueeze(0).unsqueeze(-1).expand_as(kpt_pred_tmp)
+            )
             if not valid_flag_expand.all():
                 kpt_pred_tmp = masked_fill(kpt_pred_tmp, ~valid_flag_expand, 0)
             cost = F.pairwise_distance(
                 kpt_pred_tmp.reshape((kpt_pred_tmp.shape[0], -1)),
-                gt_keypoints[i].reshape((-1, )).unsqueeze(0),
+                gt_keypoints[i].reshape((-1,)).unsqueeze(0),
                 p=1,
-                keepdim=True)
-            avg_factor = paddle.clip(
-                valid_flag.astype('float32').sum() * 2, 1.0)
+                keepdim=True,
+            )
+            avg_factor = paddle.clip(valid_flag.astype("float32").sum() * 2, 1.0)
             cost = cost / avg_factor
             kpt_cost.append(cost)
         kpt_cost = paddle.concat(kpt_cost, axis=1)
@@ -94,21 +97,56 @@ class OksCost(object):
     def __init__(self, num_keypoints=17, weight=1.0):
         self.weight = weight
         if num_keypoints == 17:
-            self.sigmas = np.array(
-                [
-                    .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
-                    1.07, .87, .87, .89, .89
-                ],
-                dtype=np.float32) / 10.0
+            self.sigmas = (
+                np.array(
+                    [
+                        0.26,
+                        0.25,
+                        0.25,
+                        0.35,
+                        0.35,
+                        0.79,
+                        0.79,
+                        0.72,
+                        0.72,
+                        0.62,
+                        0.62,
+                        1.07,
+                        1.07,
+                        0.87,
+                        0.87,
+                        0.89,
+                        0.89,
+                    ],
+                    dtype=np.float32,
+                )
+                / 10.0
+            )
         elif num_keypoints == 14:
-            self.sigmas = np.array(
-                [
-                    .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89,
-                    .89, .79, .79
-                ],
-                dtype=np.float32) / 10.0
+            self.sigmas = (
+                np.array(
+                    [
+                        0.79,
+                        0.79,
+                        0.72,
+                        0.72,
+                        0.62,
+                        0.62,
+                        1.07,
+                        1.07,
+                        0.87,
+                        0.87,
+                        0.89,
+                        0.89,
+                        0.79,
+                        0.79,
+                    ],
+                    dtype=np.float32,
+                )
+                / 10.0
+            )
         else:
-            raise ValueError(f'Unsupported keypoints number {num_keypoints}')
+            raise ValueError(f"Unsupported keypoints number {num_keypoints}")
 
     def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas):
         """
@@ -125,17 +163,17 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas):
             paddle.Tensor: oks_cost value with weight.
         """
         sigmas = paddle.to_tensor(self.sigmas)
-        variances = (sigmas * 2)**2
+        variances = (sigmas * 2) ** 2
 
         oks_cost = []
         assert len(gt_keypoints) == len(gt_areas)
         for i in range(len(gt_keypoints)):
             if gt_keypoints[i].size == 0:
                 oks_cost.append(kpt_pred.sum() * 0)
-            squared_distance = \
-                (kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0)) ** 2 + \
-                (kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2
-            vis_flag = (valid_kpt_flag[i] > 0).astype('int')
+            squared_distance = (
+                kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0)
+            ) ** 2 + (kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2
+            vis_flag = (valid_kpt_flag[i] > 0).astype("int")
             vis_ind = vis_flag.nonzero(as_tuple=False)[:, 0]
             num_vis_kpt = vis_ind.shape[0]
             # assert num_vis_kpt > 0
@@ -145,10 +183,8 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas):
             area = gt_areas[i]
 
             squared_distance0 = squared_distance / (area * variances * 2)
-            squared_distance0 = paddle.index_select(
-                squared_distance0, vis_ind, axis=1)
-            squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1,
-                                                                   keepdim=True)
+            squared_distance0 = paddle.index_select(squared_distance0, vis_ind, axis=1)
+            squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1, keepdim=True)
             oks = squared_distance1 / num_vis_kpt
             # The 1 is a constant that doesn't change the matching, so omitted.
             oks_cost.append(-oks)
@@ -160,11 +196,11 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas):
 class ClassificationCost:
     """ClsSoftmaxCost.
 
-     Args:
-         weight (int | float, optional): loss_weight
+    Args:
+        weight (int | float, optional): loss_weight
     """
 
-    def __init__(self, weight=1.):
+    def __init__(self, weight=1.0):
         self.weight = weight
 
     def __call__(self, cls_pred, gt_labels):
@@ -190,21 +226,16 @@ def __call__(self, cls_pred, gt_labels):
 class FocalLossCost:
     """FocalLossCost.
 
-     Args:
-         weight (int | float, optional): loss_weight
-         alpha (int | float, optional): focal_loss alpha
-         gamma (int | float, optional): focal_loss gamma
-         eps (float, optional): default 1e-12
-         binary_input (bool, optional): Whether the input is binary,
-            default False.
+    Args:
+        weight (int | float, optional): loss_weight
+        alpha (int | float, optional): focal_loss alpha
+        gamma (int | float, optional): focal_loss gamma
+        eps (float, optional): default 1e-12
+        binary_input (bool, optional): Whether the input is binary,
+           default False.
     """
 
-    def __init__(self,
-                 weight=1.,
-                 alpha=0.25,
-                 gamma=2,
-                 eps=1e-12,
-                 binary_input=False):
+    def __init__(self, weight=1.0, alpha=0.25, gamma=2, eps=1e-12, binary_input=False):
         self.weight = weight
         self.alpha = alpha
         self.gamma = gamma
@@ -224,14 +255,18 @@ def _focal_loss_cost(self, cls_pred, gt_labels):
         if gt_labels.size == 0:
             return cls_pred.sum() * 0
         cls_pred = F.sigmoid(cls_pred)
-        neg_cost = -(1 - cls_pred + self.eps).log() * (
-            1 - self.alpha) * cls_pred.pow(self.gamma)
-        pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
-            1 - cls_pred).pow(self.gamma)
+        neg_cost = (
+            -(1 - cls_pred + self.eps).log()
+            * (1 - self.alpha)
+            * cls_pred.pow(self.gamma)
+        )
+        pos_cost = (
+            -(cls_pred + self.eps).log() * self.alpha * (1 - cls_pred).pow(self.gamma)
+        )
 
         cls_cost = paddle.index_select(
-            pos_cost, gt_labels, axis=1) - paddle.index_select(
-                neg_cost, gt_labels, axis=1)
+            pos_cost, gt_labels, axis=1
+        ) - paddle.index_select(neg_cost, gt_labels, axis=1)
         return cls_cost * self.weight
 
     def _mask_focal_loss_cost(self, cls_pred, gt_labels):
@@ -250,13 +285,18 @@ def _mask_focal_loss_cost(self, cls_pred, gt_labels):
         gt_labels = gt_labels.flatten(1).float()
         n = cls_pred.shape[1]
         cls_pred = F.sigmoid(cls_pred)
-        neg_cost = -(1 - cls_pred + self.eps).log() * (
-            1 - self.alpha) * cls_pred.pow(self.gamma)
-        pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
-            1 - cls_pred).pow(self.gamma)
-
-        cls_cost = paddle.einsum('nc,mc->nm', pos_cost, gt_labels) + \
-            paddle.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
+        neg_cost = (
+            -(1 - cls_pred + self.eps).log()
+            * (1 - self.alpha)
+            * cls_pred.pow(self.gamma)
+        )
+        pos_cost = (
+            -(cls_pred + self.eps).log() * self.alpha * (1 - cls_pred).pow(self.gamma)
+        )
+
+        cls_cost = paddle.einsum("nc,mc->nm", pos_cost, gt_labels) + paddle.einsum(
+            "nc,mc->nm", neg_cost, (1 - gt_labels)
+        )
         return cls_cost / n * self.weight
 
     def __call__(self, cls_pred, gt_labels):
@@ -273,3 +313,86 @@ def __call__(self, cls_pred, gt_labels):
             return self._mask_focal_loss_cost(cls_pred, gt_labels)
         else:
             return self._focal_loss_cost(cls_pred, gt_labels)
+
+
+def bbox_cxcywh_to_xyxy(bbox):
+    """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
+
+    Args:
+        bbox (Tensor): Shape (n, 4) for bboxes.
+
+    Returns:
+        Tensor: Converted bboxes.
+    """
+    cx, cy, w, h = paddle.split(bbox, (1, 1, 1, 1), axis=-1)
+    bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
+    return paddle.concat(bbox_new, axis=-1)
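+# e.g. a 2x2 box centred at the origin:
+#     bbox_cxcywh_to_xyxy(paddle.to_tensor([[0., 0., 2., 2.]]))
+#     -> [[-1., -1., 1., 1.]]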
+
+
+@register
+class BBoxL1Cost:
+    """BBoxL1Cost.
+
+    Args:
+        weight (int | float, optional): loss_weight
+        box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN
+
+    """
+
+    def __init__(self, weight=1.0, box_format="xyxy"):
+        self.weight = weight
+        assert box_format in ["xyxy", "xywh"]
+        self.box_format = box_format
+
+    def __call__(self, bbox_pred, gt_bboxes):
+        """
+        Args:
+            bbox_pred (Tensor): Predicted boxes with normalized coordinates
+                (cx, cy, w, h), which are all in range [0, 1]. Shape
+                (num_query, 4).
+            gt_bboxes (Tensor): Ground truth boxes with normalized
+                coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
+
+        Returns:
+            Tensor: bbox_cost value with weight
+        """
+        if self.box_format == "xywh":
+            gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
+        elif self.box_format == "xyxy":
+            bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
+        bbox_cost = paddle.cdist(bbox_pred, gt_bboxes, p=1)
+        return bbox_cost * self.weight
+
+
+@register
+class IoUCost:
+    """IoUCost.
+
+    Args:
+        iou_mode (str, optional): iou mode such as 'iou' | 'giou'
+        weight (int | float, optional): loss weight
+
+    """
+
+    def __init__(self, iou_mode="giou", weight=1.0):
+        self.weight = weight
+        self.iou_mode = iou_mode
+
+    def __call__(self, bboxes, gt_bboxes):
+        """
+        Args:
+            bboxes (Tensor): Predicted boxes with unnormalized coordinates
+                (x1, y1, x2, y2). Shape (num_query, 4).
+            gt_bboxes (Tensor): Ground truth boxes with unnormalized
+                coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
+
+        Returns:
+            Tensor: iou_cost value with weight
+        """
+        # overlaps: [num_bboxes, num_gt]
+        overlaps = bbox_overlaps(
+            bboxes.detach().numpy(), gt_bboxes.detach().numpy(), mode=self.iou_mode, is_aligned=False
+        )
+        # The 1 is a constant that doesn't change the matching, so omitted.
+        iou_cost = -overlaps
+        return iou_cost * self.weight
diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py
index e2b2dc2da0..23c0136ca2 100644
--- a/ppdet/modeling/heads/__init__.py
+++ b/ppdet/modeling/heads/__init__.py
@@ -41,6 +41,9 @@
 from . import sparse_roi_head
 from . import vitpose_head
 from . import clrnet_head
+from . import co_deformable_detr_head
+from . import co_roi_head
+from . import co_atss_head
 
 from .bbox_head import *
 from .mask_head import *
@@ -72,3 +75,6 @@
 from .petr_head import *
 from .vitpose_head import *
 from .clrnet_head import *
+from .co_deformable_detr_head import *
+from .co_roi_head import *
+from .co_atss_head import *
diff --git a/ppdet/modeling/heads/co_atss_head.py b/ppdet/modeling/heads/co_atss_head.py
new file mode 100644
index 0000000000..afe35af034
--- /dev/null
+++ b/ppdet/modeling/heads/co_atss_head.py
@@ -0,0 +1,665 @@
+from functools import partial
+import paddle
+import paddle.nn as nn
+from ppdet.modeling import bbox_utils
+from ppdet.core.workspace import register
+from ppdet.modeling.assigners import hungarian_assigner
+from ppdet.data.transform.atss_assigner import ATSSAssigner
+
+__all__ = ['CoATSSHead']
+
+class Scale(nn.Layer):
+    """A learnable scale parameter.
+
+    This layer scales the input by a learnable factor. It multiplies a
+    learnable scale parameter of shape (1,) with input of any shape.
+
+    Args:
+        scale (float): Initial value of scale factor. Default: 1.0
+    """
+
+    def __init__(self, scale: float = 1.0):
+        super().__init__()
+        # create_parameter with only a shape leaves the value default-
+        # initialised; use a Constant initializer so the parameter actually
+        # starts at `scale`
+        self.scale = paddle.create_parameter(
+            shape=[1],
+            dtype='float32',
+            default_initializer=nn.initializer.Constant(value=scale))
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        return x * self.scale
+    
+def reduce_mean(tensor):
+    world_size = paddle.distributed.get_world_size()
+    if world_size == 1:
+        return tensor
+    paddle.distributed.all_reduce(tensor)
+    return tensor / world_size
+
+def multi_apply(func, *args, **kwargs):
+    """Apply function to a list of arguments.
+
+    Note:
+        This function applies the ``func`` to multiple inputs and
+        map the multiple outputs of the ``func`` into different
+        list. Each list contains the same type of outputs corresponding
+        to different inputs.
+
+    Args:
+        func (Function): A function that will be applied to a list of
+            arguments
+
+    Returns:
+        tuple(list): A tuple containing multiple list, each list contains \
+            a kind of returned results by the function
+    """
+    pfunc = partial(func, **kwargs) if kwargs else func
+    map_results = map(pfunc, *args)
+    res = tuple(map(list, zip(*map_results)))
+    return res
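+# Example: multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
+# returns ([4, 6], [3, 8]).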
+
+@register
+class CoATSSHead(nn.Layer):
+    """Bridging the Gap Between Anchor-based and Anchor-free Detection via
+    Adaptive Training Sample Selection.
+
+    ATSS head structure is similar with FCOS, however ATSS use anchor boxes
+    and assign label by Adaptive Training Sample Selection instead max-iou.
+
+    https://arxiv.org/abs/1912.02424
+    """
+    __inject__ = ['anchor_generator', 'loss_cls', 'loss_bbox', 'sampler']
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 stacked_convs=4,
+                 feat_channels=256,
+                 anchor_generator=None,
+                 assigner='ATSSAssigner',
+                 sampler='PseudoSampler',
+                 loss_cls=None,
+                 loss_bbox=None,
+                 reg_decoded_bbox=True,
+                 pos_weight=-1
+                 ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.stacked_convs = stacked_convs
+        self.feat_channels = feat_channels
+        self.anchor_generator = anchor_generator
+        self.num_levels = len(self.anchor_generator.strides)
+        self.num_anchors = self.anchor_generator.num_anchors
+        self.use_sigmoid_cls = True
+        self.loss_cls = loss_cls
+        self.loss_bbox = loss_bbox
+        # centerness targets are floats in [0, 1], so use BCE-with-logits;
+        # paddle's nn.CrossEntropyLoss expects integer class labels and does
+        # not accept an avg_factor argument
+        self.loss_centerness = nn.BCEWithLogitsLoss(reduction='sum')
+        # NOTE: the `assigner` config is currently ignored in favour of a
+        # default ATSSAssigner instance
+        self.assigner = ATSSAssigner()
+        self.sampler = sampler
+        self.reg_decoded_bbox = reg_decoded_bbox
+        self.pos_weight = pos_weight
+        if self.use_sigmoid_cls:
+            self.cls_out_channels = num_classes
+        else:
+            self.cls_out_channels = num_classes + 1
+
+        if self.cls_out_channels <= 0:
+            raise ValueError(f'num_classes={num_classes} is too small')
+        self._init_layers()
+        
+    def _init_layers(self):
+        """Initialize layers of the head."""
+        self.cls_convs = nn.LayerList()
+        self.reg_convs = nn.LayerList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            self.cls_convs.append(
+                nn.Sequential(
+                    nn.Conv2D(chn, self.feat_channels, 3, padding=1), 
+                    nn.GroupNorm(32,self.feat_channels),
+                    nn.ReLU()))
+                
+            self.reg_convs.append(
+                nn.Sequential(
+                    nn.Conv2D(chn, self.feat_channels, 3, padding=1), 
+                    nn.GroupNorm(32,self.feat_channels),
+                    nn.ReLU()))
+        self.atss_cls = nn.Conv2D(
+            self.feat_channels,
+            self.num_anchors * self.cls_out_channels,
+            3,
+            padding=1)
+        self.atss_reg = nn.Conv2D(
+            self.feat_channels, self.num_anchors * 4, 3, padding=1)
+        self.atss_centerness = nn.Conv2D(
+            self.feat_channels, self.num_anchors * 1, 3, padding=1)
+        self.scales = nn.LayerList(
+            [Scale(1.0) for _ in self.anchor_generator.strides])
+    
+    def forward(self, feats):
+        """Forward features from the upstream network.
+
+        Args:
+            feats (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+
+        Returns:
+            tuple: Usually a tuple of classification scores and bbox prediction
+                cls_scores (list[Tensor]): Classification scores for all scale
+                    levels, each is a 4D-tensor, the channels number is
+                    num_anchors * num_classes.
+                bbox_preds (list[Tensor]): Box energies / deltas for all scale
+                    levels, each is a 4D-tensor, the channels number is
+                    num_anchors * 4.
+        """
+        return multi_apply(self.forward_single, feats, self.scales)
+
+    def forward_single(self, x, scale):
+        """Forward feature of a single scale level.
+
+        Args:
+            x (Tensor): Features of a single scale level.
+            scale : Learnable scale module to resize the bbox prediction.
+
+        Returns:
+            tuple:
+                cls_score (Tensor): Cls scores for a single scale level
+                    the channels number is num_anchors * num_classes.
+                bbox_pred (Tensor): Box energies / deltas for a single scale
+                    level, the channels number is num_anchors * 4.
+                centerness (Tensor): Centerness for a single scale level, the
+                    channel number is (N, num_anchors * 1, H, W).
+        """
+        cls_feat = x
+        reg_feat = x
+    
+        for cls_conv in self.cls_convs:
+            cls_feat = cls_conv(cls_feat)
+        for reg_conv in self.reg_convs:
+            reg_feat = reg_conv(reg_feat)
+        cls_score = self.atss_cls(cls_feat)
+        # following ATSS, no exp is applied to bbox_pred
+        bbox_pred = scale(self.atss_reg(reg_feat)).astype('float32')
+        centerness = self.atss_centerness(reg_feat)
+        return cls_score, bbox_pred, centerness
+        
+    def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
+                    label_weights, bbox_targets, img_metas, num_total_samples):
+        """Compute loss of a single scale level.
+
+        Args:
+            cls_score (Tensor): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (Tensor): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            anchors (Tensor): Box reference for each scale level with shape
+                (N, num_total_anchors, 4).
+            labels (Tensor): Labels of each anchors with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor
+                weight shape (N, num_total_anchors, 4).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        anchors = anchors.reshape((-1, 4))
+        cls_score = cls_score.transpose((0, 2, 3, 1)).reshape(
+            (-1, self.cls_out_channels))
+        bbox_pred = bbox_pred.transpose((0, 2, 3, 1)).reshape((-1, 4))
+        centerness = centerness.transpose((0, 2, 3, 1)).reshape([-1])
+        bbox_targets = bbox_targets.reshape((-1, 4))
+        labels = labels.reshape([-1]).astype(paddle.int32)
+        label_weights = label_weights.reshape([-1])
+        # classification loss
+        loss_cls = self.loss_cls(
+            cls_score, labels, label_weights, avg_factor=num_total_samples)
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        pos_inds = paddle.nonzero(
+                paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
+                as_tuple=False).squeeze(1)
+        
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox_targets[pos_inds]
+            pos_bbox_pred = bbox_pred[pos_inds]
+            pos_anchors = anchors[pos_inds]
+            pos_centerness = centerness[pos_inds]
+
+            centerness_targets = self.centerness_target(
+                pos_anchors, pos_bbox_targets)
+            pos_decode_bbox_pred = bbox_utils.delta2bbox(
+                pos_anchors, pos_bbox_pred)
+
+            # regression loss
+            loss_bbox = self.loss_bbox(
+                pos_decode_bbox_pred,
+                pos_bbox_targets,
+                weight=centerness_targets,
+                avg_factor=1.0)
+
+            # centerness loss
+            loss_centerness = self.loss_centerness(
+                pos_centerness,
+                centerness_targets,
+                avg_factor=num_total_samples)
+        else:
+            loss_bbox = bbox_pred.sum() * 0
+            loss_centerness = centerness.sum() * 0
+            centerness_targets = paddle.to_tensor(0., dtype=bbox_targets.dtype)
+
+        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
+
+    def centerness_target(self, anchors, gts):
+        # only calculate pos centerness targets, otherwise there may be nan
+        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+        l_ = anchors_cx - gts[:, 0]
+        t_ = anchors_cy - gts[:, 1]
+        r_ = gts[:, 2] - anchors_cx
+        b_ = gts[:, 3] - anchors_cy
+
+        left_right = paddle.stack([l_, r_], axis=1)
+        top_bottom = paddle.stack([t_, b_], axis=1)
+        # paddle's Tensor.min/max take `axis` and return the values directly
+        centerness = paddle.sqrt(
+            (left_right.min(axis=-1) / left_right.max(axis=-1)) *
+            (top_bottom.min(axis=-1) / top_bottom.max(axis=-1)))
+        assert not paddle.isnan(centerness).any()
+        return centerness
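+
+    # A minimal worked example of the centerness formula above (hypothetical
+    # numbers): for an anchor centre at (50, 50) and a GT box [20, 30, 60, 90],
+    # l_, t_, r_, b_ = 30, 20, 10, 40, so
+    # centerness = sqrt((10 / 30) * (20 / 40)) = sqrt(1 / 6) ~= 0.408;
+    # centres near the box centre approach 1.0, centres near a border approach 0.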
+    
+    def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
+        """Get anchors according to feature map sizes.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            img_metas (list[dict]): Image meta info.
+            device (str): Unused; kept for interface compatibility.
+
+        Returns:
+            tuple:
+                anchor_list (list[Tensor]): Anchors of each image.
+                valid_flag_list (list[Tensor]): Valid flags of each image.
+        """
+        num_imgs = len(img_metas)
+
+        # since feature map sizes of all images are the same, we only compute
+        # anchors for one time
+        multi_level_anchors = self.anchor_generator(featmap_sizes)
+        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+        # for each image, we compute valid flags of multi level anchors
+        valid_flag_list = []
+        for img_id, img_meta in enumerate(img_metas):
+            multi_level_flags = self.anchor_generator.valid_flags(
+                featmap_sizes, img_meta['img_shape'])
+            valid_flag_list.append(multi_level_flags)
+
+        return anchor_list, valid_flag_list
+    
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             centernesses,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             gt_bboxes_ignore=None):
+        """Compute losses of the head.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W)
+            centernesses (list[Tensor]): Centerness for each scale
+                level with shape (N, num_anchors * 1, H, W)
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): class indices corresponding to each box
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+                boxes can be ignored when computing the loss.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        featmap_sizes = [featmap.shape[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == self.num_levels
+        # the anchor generator consumes the feature maps themselves (it only
+        # reads their spatial shapes), so cls_scores are passed directly
+        anchor_list, valid_flag_list = self.get_anchors(
+            cls_scores, img_metas)
+
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+        cls_reg_targets = self.get_targets(
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels)
+        if cls_reg_targets is None:
+            return None
+
+        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, num_total_pos, num_total_neg,
+         ori_anchors, ori_labels, ori_bbox_targets) = cls_reg_targets
+        num_total_samples = reduce_mean(
+            paddle.to_tensor(num_total_pos, dtype=paddle.float32)).item()
+        num_total_samples = max(num_total_samples, 1.0)
+        new_img_metas = [img_metas for _ in range(len(anchor_list))]
+        losses_cls, losses_bbox, loss_centerness,\
+            bbox_avg_factor = multi_apply(
+                self.loss_single,
+                anchor_list,
+                cls_scores,
+                bbox_preds,
+                centernesses,
+                labels_list,
+                label_weights_list,
+                bbox_targets_list,
+                new_img_metas,
+                num_total_samples=num_total_samples)
+
+        bbox_avg_factor = sum(bbox_avg_factor)
+        bbox_avg_factor = reduce_mean(bbox_avg_factor).clip_(min=1).item()
+        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+
+        pos_coords = (ori_anchors, ori_labels, ori_bbox_targets, 'atss')
+        return dict(
+            loss_cls=losses_cls,
+            loss_bbox=losses_bbox,
+            loss_centerness=loss_centerness,
+            pos_coords=pos_coords)
+        
+        
+    def images_to_levels(self, target, num_level_anchors):
+        """
+        Convert targets by image to targets by feature level.
+        """
+        target = paddle.stack(target, 0)
+        level_targets = []
+        start = 0
+        for n in num_level_anchors:
+            end = start + n                             
+            level_targets.append(target[:, start:end].squeeze(0))
+            start = end
+        return level_targets
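+
+    # Sketch of the conversion above with hypothetical sizes: for a batch of
+    # 2 images whose concatenated per-image targets each cover 100 + 25
+    # anchors, images_to_levels(target, [100, 25]) stacks the list into a
+    # (2, 125, ...) tensor and returns [target[:, :100], target[:, 100:125]],
+    # i.e. one tensor per FPN level instead of one per image (the trailing
+    # squeeze(0) only takes effect for single-image batches).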
+
+
+    def get_targets(self,
+                    anchor_list,
+                    valid_flag_list,
+                    gt_bboxes_list,
+                    img_metas,
+                    gt_bboxes_ignore_list=None,
+                    gt_labels_list=None,
+                    label_channels=1,
+                    unmap_outputs=True):
+        """Get targets for ATSS head.
+
+        This method is almost the same as `AnchorHead.get_targets()`. Besides
+        returning the targets as the parent method does, it also returns the
+        anchors as the first element of the returned tuple.
+        """
+        num_imgs = len(img_metas)
+        assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        # anchor number of multi levels
+        num_level_anchors = [anchors.shape[0] for anchors in anchor_list[0]]
+        num_level_anchors_list = [num_level_anchors] * num_imgs
+
+        # concat all level anchors and flags to a single tensor
+        for i in range(num_imgs):
+            assert len(anchor_list[i]) == len(valid_flag_list[i])
+            anchor_list[i] = paddle.concat(anchor_list[i])
+            valid_flag_list[i] = paddle.concat(valid_flag_list[i])
+
+        # compute targets for each image
+        if gt_bboxes_ignore_list is None:
+            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+        if gt_labels_list is None:
+            gt_labels_list = [None for _ in range(num_imgs)]
+        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+             self._get_target_single,
+             anchor_list,
+             valid_flag_list,
+             num_level_anchors_list,
+             gt_bboxes_list,
+             gt_bboxes_ignore_list,
+             gt_labels_list,
+             img_metas,
+             label_channels=label_channels,
+             unmap_outputs=unmap_outputs)
+        # no valid anchors
+        if any([labels is None for labels in all_labels]):
+            return None
+        # sampled anchors of all images
+        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+        # split targets to a list w.r.t. multiple levels
+        ori_anchors = all_anchors
+        ori_labels = all_labels
+        ori_bbox_targets = all_bbox_targets
+        anchors_list = self.images_to_levels(all_anchors, num_level_anchors)
+        labels_list = self.images_to_levels(all_labels, num_level_anchors)
+        label_weights_list = self.images_to_levels(all_label_weights,
+                                                   num_level_anchors)
+        bbox_targets_list = self.images_to_levels(all_bbox_targets,
+                                                  num_level_anchors)
+        bbox_weights_list = self.images_to_levels(all_bbox_weights,
+                                                  num_level_anchors)
+        return (anchors_list, labels_list, label_weights_list,
+                bbox_targets_list, bbox_weights_list, num_total_pos,
+                num_total_neg, ori_anchors, ori_labels, ori_bbox_targets)
+        
+    def anchor_inside_flags(self, flat_anchors,
+                            valid_flags,
+                            img_shape,
+                            allowed_border=0):
+        """Check whether the anchors are inside the border.
+
+        Args:
+            flat_anchors (Tensor): Flattened anchors, shape (n, 4).
+            valid_flags (Tensor): Existing valid flags of the anchors.
+            img_shape (tuple[int]): Shape of the current image.
+            allowed_border (int, optional): The border to allow the valid
+                anchor. Defaults to 0.
+
+        Returns:
+            Tensor: Flags indicating whether the anchors are inside a \
+                valid range.
+        """
+        img_h, img_w = img_shape[:2]
+        if allowed_border >= 0:
+            inside_flags = valid_flags & \
+                (flat_anchors[:, 0] >= -allowed_border) & \
+                (flat_anchors[:, 1] >= -allowed_border) & \
+                (flat_anchors[:, 2] < img_w + allowed_border) & \
+                (flat_anchors[:, 3] < img_h + allowed_border)
+        else:
+            inside_flags = valid_flags
+        return inside_flags
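+
+    # Hedged example of the border check above (hypothetical values): with
+    # img_shape (512, 512) and allowed_border=0, an anchor [-2, 10, 30, 40]
+    # fails the x1 >= -allowed_border test and is flagged invalid, while
+    # [0, 10, 30, 40] passes; allowed_border=-1, as used by
+    # _get_target_single below, skips the check entirely and keeps the
+    # generator's valid_flags unchanged.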
+    
+
+    def unmap(self, data, count, inds, fill=0):
+        """Unmap a subset of items (data) back to the original set of items
+        (of size count)."""
+        if data.dim() == 1:
+            # keep the dtype of `data` so integer labels stay integers
+            ret = paddle.full((count, 1), fill, dtype=data.dtype)
+            data = data.unsqueeze(0).transpose((1, 0))
+            ret[inds, :] = data
+            ret = ret.transpose((1, 0)).squeeze()
+        else:
+            new_size = (count, ) + tuple(data.shape[1:])
+            ret = paddle.full(new_size, fill, dtype=data.dtype)
+            ret[inds.astype(paddle.bool), :] = data
+        return ret
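+
+    # Usage sketch for unmap (hypothetical tensors): given labels for the 3
+    # anchors that fell inside the image out of 5 total,
+    #   labels = paddle.to_tensor([1, 0, 2])
+    #   inside = paddle.to_tensor([True, False, True, True, False])
+    #   self.unmap(labels, 5, inside, fill=self.num_classes)
+    # returns [1, num_classes, 0, 2, num_classes], restoring the original
+    # anchor ordering with background filled into the excluded slots.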
+
+
+    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+        split_inside_flags = paddle.split(inside_flags, num_level_anchors)
+        num_level_anchors_inside = [
+            int(flags.sum()) for flags in split_inside_flags
+        ]
+        return num_level_anchors_inside
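+
+    # Sketch (hypothetical counts): with num_level_anchors = [100, 25] and 90
+    # of the first 100 flags set plus all 25 of the rest, the split above
+    # yields num_level_anchors_inside = [90, 25], which the ATSS assigner
+    # uses to pick its top-k candidate anchors per level.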
+    
+    def _get_target_single(self,
+                           flat_anchors,
+                           valid_flags,
+                           num_level_anchors,
+                           gt_bboxes,
+                           gt_bboxes_ignore,
+                           gt_labels,
+                           img_meta,
+                           label_channels=1,
+                           unmap_outputs=True):
+        """Compute regression, classification targets for anchors in a single
+        image.
+
+        Args:
+            flat_anchors (Tensor): Multi-level anchors of the image, which are
+                concatenated into a single tensor of shape (num_anchors ,4)
+            valid_flags (Tensor): Multi level valid flags of the image,
+                which are concatenated into a single tensor of
+                    shape (num_anchors,).
+            num_level_anchors (list[int]): Number of anchors of each scale level.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            img_meta (dict): Meta info of the image.
+            label_channels (int): Channel of label.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple: N is the number of total anchors in the image.
+                labels (Tensor): Labels of all anchors in the image with shape
+                    (N,).
+                label_weights (Tensor): Label weights of all anchor in the
+                    image with shape (N,).
+                bbox_targets (Tensor): BBox targets of all anchors in the
+                    image with shape (N, 4).
+                bbox_weights (Tensor): BBox weights of all anchors in the
+                    image with shape (N, 4)
+                pos_inds (Tensor): Indices of positive anchor with shape
+                    (num_pos,).
+                neg_inds (Tensor): Indices of negative anchor with shape
+                    (num_neg,).
+        """
+        inside_flags = self.anchor_inside_flags(flat_anchors, valid_flags,
+                                           img_meta['img_shape'][:2],
+                                           -1).astype(paddle.bool)
+        if not inside_flags.any():
+            return (None, ) * 7
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+
+        num_level_anchors_inside = self.get_num_level_anchors_inside(
+            num_level_anchors, inside_flags)
+
+
+        assigned_gt_inds, max_overlaps = self.assigner(
+            anchors.cpu().detach().numpy(),
+            num_level_anchors_inside,
+            gt_labels=gt_labels,
+            gt_bboxes=gt_bboxes.cpu().detach().numpy())
+        assigned_gt_inds = paddle.to_tensor(assigned_gt_inds)
+        max_overlaps = paddle.to_tensor(max_overlaps)
+        if gt_labels is not None:
+            assigned_labels = paddle.full(
+                (anchors.shape[0], ), -1, dtype=assigned_gt_inds.dtype)
+            pos_inds = paddle.nonzero(
+                assigned_gt_inds > 0, as_tuple=False).squeeze()
+            if pos_inds.numel() > 0:
+                assigned_labels[pos_inds] = gt_labels[
+                    assigned_gt_inds[pos_inds] - 1]
+        else:
+            assigned_labels = None
+            
+        assign_result = hungarian_assigner.AssignResult(
+            gt_bboxes.shape[0], assigned_gt_inds, max_overlaps,
+            labels=assigned_labels)
+
+        sampling_result = self.sampler.sample(assign_result, anchors,
+                                              gt_bboxes)
+
+        num_valid_anchors = anchors.shape[0]
+        bbox_targets = paddle.zeros_like(anchors)
+        bbox_weights = paddle.zeros_like(anchors)
+        labels = paddle.full(
+            (num_valid_anchors, ), self.num_classes, dtype=paddle.int64)
+        label_weights = paddle.zeros(
+            (num_valid_anchors, ), dtype=paddle.float32)
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            if self.reg_decoded_bbox:
+                pos_bbox_targets = sampling_result.pos_gt_bboxes
+            else:
+                pos_bbox_targets = bbox_utils.bbox2delta(
+                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+            bbox_weights[pos_inds, :] = 1.0
+            if gt_labels is None:
+                # Only RPN gives gt_labels as None; in that case foreground
+                # is labeled as the first class (mmdet >= v2.5.0 convention)
+                labels[pos_inds] = 0
+            else:
+                labels[pos_inds] = gt_labels[
+                    sampling_result.pos_assigned_gt_inds]
+            if self.pos_weight <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = self.pos_weight
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.shape[0]
+            anchors = self.unmap(anchors, num_total_anchors, inside_flags)
+            labels = self.unmap(
+                labels, num_total_anchors, inside_flags, fill=self.num_classes)
+
+            label_weights = self.unmap(label_weights, num_total_anchors,
+                                       inside_flags)
+            bbox_targets = self.unmap(bbox_targets, num_total_anchors,
+                                      inside_flags)
+            bbox_weights = self.unmap(bbox_weights, num_total_anchors,
+                                      inside_flags)
+
+        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+                pos_inds, neg_inds)
+        
+    def forward_train(self,
+                      x,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels=None,
+                      gt_bboxes_ignore=None,
+                      **kwargs):
+        """
+        Args:
+            x (list[Tensor]): Features from FPN.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        outs = self(x)
+        if gt_labels is None:
+            loss_inputs = outs + (gt_bboxes, img_metas)
+        else:
+            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+
+        return losses
diff --git a/ppdet/modeling/heads/co_deformable_detr_head.py b/ppdet/modeling/heads/co_deformable_detr_head.py
new file mode 100644
index 0000000000..8f6918743b
--- /dev/null
+++ b/ppdet/modeling/heads/co_deformable_detr_head.py
@@ -0,0 +1,1300 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/Sense-X/Co-DETR/blob/main/projects/models/co_deformable_detr_head.py
+"""
+import copy
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppdet.core.workspace import register
+import paddle.distributed as dist
+
+from ..transformers.petr_transformer import inverse_sigmoid, masked_fill
+from ..initializer import constant_, normal_
+from ppdet.modeling.transformers.utils import bbox_cxcywh_to_xyxy
+
+__all__ = ["CoDeformDETRHead"]
+
+from functools import partial
+
+
+def bias_init_with_prob(prior_prob: float) -> float:
+    """initialize conv/fc bias value according to a given probability value."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
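+
+# Hedged sanity check for the prior-probability bias above: with the commonly
+# used prior_prob = 0.01, bias_init = -log(0.99 / 0.01) ~= -4.595, so a
+# freshly initialized sigmoid classifier starts out predicting ~1% foreground.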
+
+
+def constant_init(module, val, bias=0):
+    if hasattr(module, "weight") and module.weight is not None:
+        constant_(module.weight, val)
+    if hasattr(module, "bias") and module.bias is not None:
+        constant_(module.bias, bias)
+
+
+def reduce_mean(tensor):
+    """Obtain the mean of a tensor across all GPUs."""
+    world_size = dist.get_world_size()
+    if not (world_size > 1 and dist.is_initialized()):
+        return tensor
+    # divide first, then SUM-reduce across ranks, so the result is the mean
+    tensor = tensor.clone() / world_size
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    return tensor
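+
+# Usage sketch (assumes an initialized paddle.distributed process group):
+# every rank holds its own positive-sample count, and
+#     num_pos = reduce_mean(paddle.to_tensor(num_pos, dtype='float32'))
+# yields the cross-GPU mean so the loss normalizers stay identical on all
+# ranks; in a single, non-initialized process the tensor is returned as-is.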
+
+
+def bbox_xyxy_to_cxcywh(bbox):
+    """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+    Args:
+        bbox (Tensor): Shape (n, 4) for bboxes.
+
+    Returns:
+        Tensor: Converted bboxes.
+    """
+    x1, y1, x2, y2 = paddle.split(bbox, (1, 1, 1, 1), axis=-1)
+    bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
+    return paddle.concat(bbox_new, axis=-1)
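+
+# Quick example of the conversion above: the corner-format box
+# [10., 20., 50., 100.] maps to centre format [30., 60., 40., 80.],
+# i.e. ((10 + 50) / 2, (20 + 100) / 2, 50 - 10, 100 - 20).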
+
+
+def multi_apply(func, *args, **kwargs):
+    """Apply function to a list of arguments.
+
+    Note:
+        This function applies ``func`` to multiple inputs and maps the
+        multiple outputs of ``func`` into different lists. Each list
+        contains the same type of outputs corresponding to different
+        inputs.
+
+    Args:
+        func (Function): A function that will be applied to a list of
+            arguments
+
+    Returns:
+        tuple(list): A tuple containing multiple lists, each of which \
+            holds one kind of result returned by the function
+    """
+    pfunc = partial(func, **kwargs) if kwargs else func
+    map_results = map(pfunc, *args)
+    res = tuple(map(list, zip(*map_results)))
+    return res
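+
+# Minimal illustration of multi_apply with a hypothetical function:
+#     def f(a, b):
+#         return a + b, a * b
+#     multi_apply(f, [1, 2], [3, 4])  # -> ([4, 6], [3, 8])
+# the per-call outputs are regrouped into one list per output, which is how
+# the per-layer losses in this file are collected.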
+
+
+@register
+class CoDeformDETRHead(nn.Layer):
+    __inject__ = [
+        "transformer",
+        "positional_encoding",
+        "loss_cls",
+        "loss_bbox",
+        "loss_iou",
+        "nms",
+        "assigner",
+        "sampler"
+    ]
+
+    def __init__(
+        self,
+        num_classes,
+        in_channels,
+        num_query=300,
+        sync_cls_avg_factor=True,
+        with_box_refine=False,
+        as_two_stage=False,
+        mixed_selection=False,
+        max_pos_coords=300,
+        lambda_1=1,
+        num_reg_fcs=2,
+        transformer=None,
+        positional_encoding="SinePositionalEncoding",
+        loss_cls="FocalLoss",
+        loss_bbox="L1Loss",
+        loss_iou="GIoULoss",
+        assigner="HungarianAssigner",
+        sampler="PseudoSampler",
+        test_cfg=dict(max_per_img=100),
+        nms=None,
+        use_zero_padding=False,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.assigner = assigner
+        self.sampler = sampler
+        self.bg_cls_weight = 0
+        self.num_query = num_query
+        self.sync_cls_avg_factor = sync_cls_avg_factor
+        self.with_box_refine = with_box_refine
+        self.as_two_stage = as_two_stage
+        self.mixed_selection = mixed_selection
+        self.max_pos_coords = max_pos_coords
+        self.lambda_1 = lambda_1
+        self.use_zero_padding = use_zero_padding
+        self.test_cfg = test_cfg
+        self.nms = nms
+        self.transformer = transformer
+        self.num_reg_fcs = num_reg_fcs
+        
+        self.positional_encoding = positional_encoding
+        self.loss_cls = loss_cls
+        self.loss_bbox = loss_bbox
+        self.loss_iou = loss_iou
+        if self.loss_cls.use_sigmoid:
+            self.cls_out_channels = num_classes
+        else:
+            self.cls_out_channels = num_classes + 1
+        self.embed_dims = self.transformer.embed_dims
+
+        num_feats = positional_encoding.num_pos_feats
+        assert num_feats * 2 == self.embed_dims, (
+            "embed_dims should"
+            f" be exactly 2 times of num_feats. Found {self.embed_dims}"
+            f" and {num_feats}."
+        )
+        self._init_layers()
+        self.init_weights()
+
+    def _init_layers(self):
+        """Initialize classification branch and regression branch of head."""
+        self.downsample = nn.Sequential(
+            nn.Conv2D(
+                self.embed_dims, self.embed_dims, kernel_size=3, stride=2, padding=1
+            ),
+            nn.GroupNorm(32, self.embed_dims),
+        )
+
+        fc_cls = nn.Linear(self.embed_dims, self.cls_out_channels)
+
+        reg_branch = []
+        for _ in range(self.num_reg_fcs):
+            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
+            reg_branch.append(nn.ReLU())
+        reg_branch.append(nn.Linear(self.embed_dims, 4))
+        reg_branch = nn.Sequential(*reg_branch)
+
+        def _get_clones(module, N):
+            return nn.LayerList([copy.deepcopy(module) for i in range(N)])
+
+        # last reg_branch is used to generate proposal from
+        # encode feature map when as_two_stage is True.
+        num_pred = (
+            (self.transformer.decoder.num_layers + 1)
+            if self.as_two_stage
+            else self.transformer.decoder.num_layers
+        )
+
+        if self.with_box_refine:
+            self.cls_branches = _get_clones(fc_cls, num_pred)
+            self.reg_branches = _get_clones(reg_branch, num_pred)
+        else:
+            self.cls_branches = nn.LayerList([fc_cls for _ in range(num_pred)])
+            self.reg_branches = nn.LayerList([reg_branch for _ in range(num_pred)])
+
+        if not self.as_two_stage:
+            self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
+        elif self.mixed_selection:
+            self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
+
+    def init_weights(self):
+        """Initialize weights of the DeformDETR head."""
+        self.transformer.init_weights()
+        if self.loss_cls.use_sigmoid:
+            bias_init = bias_init_with_prob(0.01)
+            for m in self.cls_branches:
+                constant_(m.bias, bias_init)
+        for m in self.reg_branches:
+            constant_init(m[-1], 0, bias=0)
+        constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
+        if self.as_two_stage:
+            for m in self.reg_branches:
+                constant_(m[-1].bias.data[2:], 0.0)
+
+    def forward(self, mlvl_feats, img_metas):
+        """Forward function.
+
+        Args:
+            mlvl_feats (tuple[Tensor]): Features from the upstream
+                network, each is a 4D-tensor with shape
+                (N, C, H, W).
+            img_metas (list[dict]): List of image information.
+
+        Returns:
+            all_cls_scores (Tensor): Outputs from the classification head, \
+                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
+                cls_out_channels should include background.
+            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
+                head with normalized coordinate format (cx, cy, w, h). \
+                Shape [nb_dec, bs, num_query, 4].
+            enc_outputs_class (Tensor): The score of each point on the \
+                encoded feature map, has shape (N, h*w, num_classes). Only \
+                returned when as_two_stage is True, otherwise `None`.
+            enc_outputs_coord (Tensor): The proposals generated from the \
+                encoded feature map, has shape (N, h*w, 4). Only returned \
+                when as_two_stage is True, otherwise `None`.
+        """
+        batch_size = mlvl_feats[0].shape[0]
+        input_img_h, input_img_w = img_metas[0]["batch_input_shape"]
+        # start from an all-ones mask (padded pixels = 1) and clear the valid
+        # region below, matching the reference Co-DETR implementation
+        img_masks = paddle.ones(
+            (batch_size, input_img_h, input_img_w), mlvl_feats[0].dtype)
+        for img_id in range(batch_size):
+            img_h, img_w, _ = img_metas[img_id]["img_shape"]
+            img_masks[img_id, :img_h, :img_w] = 0
+        
+        mlvl_masks = []
+        mlvl_positional_encodings = []
+        for feat in mlvl_feats:
+            mlvl_masks.append(
+                F.interpolate(
+                    img_masks[None], size=feat.shape[-2:]).squeeze(0))
+            mlvl_positional_encodings.append(
+                self.positional_encoding(
+                    paddle.logical_not(mlvl_masks[-1]).astype('float32')
+                ).transpose((0, 3, 1, 2)))
+            
+        query_embeds = None
+        if not self.as_two_stage or self.mixed_selection:
+            query_embeds = self.query_embedding.weight
+
+        (
+            hs,
+            init_reference,
+            inter_references,
+            enc_outputs_class,
+            enc_outputs_coord,
+            enc_outputs,
+        ) = self.transformer(
+            mlvl_feats,
+            mlvl_masks,
+            query_embeds,
+            mlvl_positional_encodings,
+            reg_branches=(
+                self.reg_branches if self.with_box_refine else None
+            ),  # noqa:E501
+            cls_branches=self.cls_branches if self.as_two_stage else None,  # noqa:E501
+            return_encoder_output=True,
+        )
+
+        outs = []
+        num_level = len(mlvl_feats)
+        start = 0
+        enc_outputs = enc_outputs.transpose((1,0,2))
+        for lvl in range(num_level):
+            bs, c, h, w = mlvl_feats[lvl].shape
+            end = start + h * w
+            feat = enc_outputs[start:end].transpose((1, 2, 0))
+            start = end
+            outs.append(feat.reshape((bs, c, h, w)))
+        outs.append(self.downsample(outs[-1]))
+
+        outputs_classes = []
+        outputs_coords = []
+
+        for lvl in range(hs.shape[0]):
+            if lvl == 0:
+                reference = init_reference
+            else:
+                reference = inter_references[lvl - 1]
+            reference = inverse_sigmoid(reference)
+            outputs_class = self.cls_branches[lvl](hs[lvl])
+            tmp = self.reg_branches[lvl](hs[lvl])
+            if reference.shape[-1] == 4:
+                tmp += reference
+            else:
+                assert reference.shape[-1] == 2
+                tmp[..., :2] += reference
+            outputs_coord = F.sigmoid(tmp)
+            outputs_classes.append(outputs_class)
+            outputs_coords.append(outputs_coord)
+
+        outputs_classes = paddle.stack(outputs_classes)
+        outputs_coords = paddle.stack(outputs_coords)
+        if self.as_two_stage:
+            return (
+                outputs_classes,
+                outputs_coords,
+                enc_outputs_class,
+                F.sigmoid(enc_outputs_coord),
+                outs,
+            )
+        else:
+            return outputs_classes, outputs_coords, None, None, outs
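+
+    # Shape sketch for the outputs above (hypothetical sizes): with 6 decoder
+    # layers, batch size 2 and num_query 300, outputs_classes is
+    # [6, 2, 300, cls_out_channels] and outputs_coords is [6, 2, 300, 4] in
+    # normalized (cx, cy, w, h); `outs` carries the encoder memory reshaped
+    # back to one (bs, c, h, w) map per level, plus one extra downsampled
+    # level, for the auxiliary heads.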
+
+    def forward_aux(self, mlvl_feats, img_metas, aux_targets, head_idx):
+        """Forward function.
+
+        Args:
+            mlvl_feats (tuple[Tensor]): Features from the upstream
+                network, each is a 4D-tensor with shape
+                (N, C, H, W).
+            img_metas (list[dict]): List of image information.
+
+        Returns:
+            all_cls_scores (Tensor): Outputs from the classification head, \
+                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
+                cls_out_channels should include background.
+            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
+                head with normalized coordinate format (cx, cy, w, h). \
+                Shape [nb_dec, bs, num_query, 4].
+            enc_outputs_class (Tensor): The score of each point on the \
+                encoded feature map, has shape (N, h*w, num_classes). Only \
+                returned when as_two_stage is True, otherwise `None`.
+            enc_outputs_coord (Tensor): The proposals generated from the \
+                encoded feature map, has shape (N, h*w, 4). Only returned \
+                when as_two_stage is True, otherwise `None`.
+        """
+        (
+            aux_coords,
+            aux_labels,
+            aux_targets,
+            aux_label_weights,
+            aux_bbox_weights,
+            aux_feats,
+            attn_masks,
+        ) = aux_targets
+        batch_size = mlvl_feats[0].shape[0]
+        input_img_h, input_img_w = img_metas[0]["batch_input_shape"]
+        # as in forward(), padded pixels are masked with 1 and the valid
+        # region is cleared below
+        img_masks = paddle.ones(
+            (batch_size, input_img_h, input_img_w), mlvl_feats[0].dtype)
+        for img_id in range(batch_size):
+            img_h, img_w, _ = img_metas[img_id]["img_shape"]
+            img_masks[img_id, :img_h, :img_w] = 0
+
+        mlvl_masks = []
+        mlvl_positional_encodings = []
+        for feat in mlvl_feats:
+            mlvl_masks.append(
+                F.interpolate(img_masks[None], size=feat.shape[-2:])
+                .astype(paddle.bool)
+                .squeeze(0)
+            )
+            mlvl_positional_encodings.append(
+                self.positional_encoding(
+                    paddle.logical_not(mlvl_masks[-1]).astype('float32')
+                ).transpose((0, 3, 1, 2)))
+
+        query_embeds = None
+        hs, init_reference, inter_references = self.transformer.forward_aux(
+            mlvl_feats,
+            mlvl_masks,
+            query_embeds,
+            mlvl_positional_encodings,
+            aux_coords,
+            pos_feats=aux_feats,
+            reg_branches=(
+                self.reg_branches if self.with_box_refine else None
+            ),  # noqa:E501
+            cls_branches=self.cls_branches if self.as_two_stage else None,  # noqa:E501
+            return_encoder_output=True,
+            attn_masks=attn_masks,
+            head_idx=head_idx,
+        )
+        if hs is None:
+            return None, None, None, None
+        outputs_classes = []
+        outputs_coords = []
+
+        for lvl in range(hs.shape[0]):
+            if lvl == 0:
+                reference = init_reference
+            else:
+                reference = inter_references[lvl - 1]
+            reference = inverse_sigmoid(reference)
+            outputs_class = self.cls_branches[lvl](hs[lvl])
+            tmp = self.reg_branches[lvl](hs[lvl])
+            if reference.shape[-1] == 4:
+                tmp += reference
+            else:
+                assert reference.shape[-1] == 2
+                tmp[..., :2] += reference
+            outputs_coord = F.sigmoid(tmp)
+            outputs_classes.append(outputs_class)
+            outputs_coords.append(outputs_coord)
+
+        outputs_classes = paddle.stack(outputs_classes)
+        outputs_coords = paddle.stack(outputs_coords)
+
+        return outputs_classes, outputs_coords, None, None
+
+    def loss_single_aux(
+        self,
+        cls_scores,
+        bbox_preds,
+        labels,
+        label_weights,
+        bbox_targets,
+        bbox_weights,
+        img_metas,
+        gt_bboxes_ignore_list=None,
+    ):
+        """ "Loss function for outputs from a single decoder layer of a single
+        feature level.
+
+        Args:
+            cls_scores (Tensor): Box score logits from a single decoder layer
+                for all images. Shape [bs, num_query, cls_out_channels].
+            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+                for all images, with normalized coordinate (cx, cy, w, h) and
+                shape [bs, num_query, 4].
+            labels (Tensor): Auxiliary labels with shape [bs, num_query].
+            label_weights (Tensor): Label weights with shape [bs, num_query].
+            bbox_targets (Tensor): Normalized (cx, cy, w, h) bbox targets
+                with shape [bs, num_query, 4].
+            bbox_weights (Tensor): BBox weights with shape [bs, num_query, 4].
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components for outputs from
+                a single decoder layer.
+        """
+        num_imgs = cls_scores.shape[0]
+        num_q = cls_scores.shape[1]
+        try:
+            labels = labels.reshape((num_imgs * num_q, ))
+            label_weights = label_weights.reshape((num_imgs * num_q, ))
+            bbox_targets = bbox_targets.reshape((num_imgs * num_q, 4))
+            bbox_weights = bbox_weights.reshape((num_imgs * num_q, 4))
+        except Exception:
+            # mismatched aux targets: return zeroed losses to skip this layer
+            return (cls_scores.mean() * 0, cls_scores.mean() * 0,
+                    cls_scores.mean() * 0)
+
+        bg_class_ind = self.num_classes
+        num_total_pos = len(
+            ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
+        )
+        num_total_neg = num_imgs * num_q - num_total_pos
+
+        # classification loss
+        cls_scores = cls_scores.reshape((-1, self.cls_out_channels))
+        # construct weighted avg_factor to match with the official DETR repo
+        cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
+        if self.sync_cls_avg_factor:
+            cls_avg_factor = reduce_mean(
+                paddle.to_tensor([cls_avg_factor], dtype=cls_scores.dtype)
+            )
+        cls_avg_factor = max(cls_avg_factor, 1)
+        loss_cls = self.loss_cls(
+            cls_scores, labels, label_weights, avg_factor=cls_avg_factor
+        )
+
+        # Compute the average number of gt boxes across all gpus, for
+        # normalization purposes
+        # paddle tensors have no `new_tensor`; build the tensor explicitly
+        num_total_pos = paddle.to_tensor([num_total_pos], dtype=loss_cls.dtype)
+        num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item()
+
+        # construct factors used for rescale bboxes
+        factors = []
+        for img_meta, bbox_pred in zip(img_metas, bbox_preds):
+            img_h, img_w, _ = img_meta["img_shape"]
+            factor = (
+                paddle.to_tensor([img_w, img_h, img_w, img_h], dtype=bbox_pred.dtype)
+                .unsqueeze(0)
+                .tile((bbox_pred.shape[0], 1))
+            )
+            factors.append(factor)
+        factors = paddle.concat(factors, 0)
+
+        # DETR regress the relative position of boxes (cxcywh) in the image,
+        # thus the learning target is normalized by the image size. So here
+        # we need to re-scale them for calculating IoU loss
+        bbox_preds = bbox_preds.reshape((-1, 4))
+        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+        # regression IoU loss, GIoU loss by default
+        loss_iou = self.loss_iou(
+            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos
+        )
+
+        # regression L1 loss
+        loss_bbox = self.loss_bbox(
+            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos
+        )
+        return (
+            loss_cls * self.lambda_1,
+            loss_bbox * self.lambda_1,
+            loss_iou * self.lambda_1,
+        )
+
+    def get_aux_targets(self, pos_coords, img_metas, mlvl_feats, head_idx):
+        coords, labels, targets = pos_coords[:3]
+        head_name = pos_coords[-1]
+        bs, c = len(coords), mlvl_feats[0].shape[1]
+        max_num_coords = 0
+        all_feats = []
+        for i in range(bs):
+            label = labels[i]
+            feats = [feat[i].reshape((c, -1)).transpose((1, 0)) for feat in mlvl_feats]
+            feats = paddle.concat(feats, axis=0)
+            bg_class_ind = self.num_classes
+            pos_inds = paddle.logical_and((label >= 0),
+                             (label < bg_class_ind)).nonzero().squeeze(1)
+            max_num_coords = max(max_num_coords, len(pos_inds))
+            all_feats.append(feats)
+        max_num_coords = min(self.max_pos_coords, max_num_coords)
+        max_num_coords = max(9, max_num_coords)
+
+        if self.use_zero_padding:
+            attn_masks = []
+            label_weights = paddle.zeros([bs, max_num_coords], coords[0].dtype)
+        else:
+            attn_masks = None
+            # without zero padding every slot is supervised, so the weights
+            # start at 1 (the reference implementation uses new_ones here)
+            label_weights = paddle.ones([bs, max_num_coords], coords[0].dtype)
+        bbox_weights = paddle.zeros([bs, max_num_coords, 4], coords[0].dtype)
+
+        aux_coords, aux_labels, aux_targets, aux_feats = [], [], [], []
+        for i in range(bs):
+            coord, label, target = coords[i], labels[i], targets[i]
+            feats = all_feats[i]
+            if "rcnn" in head_name:
+                feats = pos_coords[-2][i]
+                num_coords_per_point = 1
+            else:
+                num_coords_per_point = coord.shape[0] // feats.shape[0]
+            feats = feats.unsqueeze(1).tile((1, num_coords_per_point, 1))
+            feats = feats.reshape(
+                (feats.shape[0] * num_coords_per_point, feats.shape[-1])
+            )
+            img_meta = img_metas[i]
+            img_h, img_w, _ = img_meta["img_shape"]
+            factor = paddle.to_tensor(
+                [img_w, img_h, img_w, img_h], dtype="float32").unsqueeze(0)
+            bg_class_ind = self.num_classes
+            pos_inds = paddle.logical_and((label >= 0),
+                             (label < bg_class_ind)).nonzero().squeeze(1)
+            neg_inds = (label == bg_class_ind).nonzero().squeeze(1)
+            if pos_inds.shape[0] > max_num_coords:
+                indices = paddle.randperm(pos_inds.shape[0])[:max_num_coords]
+                pos_inds = pos_inds[indices]
+
+            if pos_inds.shape[0] == 0:
+                return (None, ) * 7
+
+            coord = bbox_xyxy_to_cxcywh(coord[pos_inds] / factor)
+            label = label[pos_inds]
+            target = bbox_xyxy_to_cxcywh(target[pos_inds] / factor)
+            feat = feats[pos_inds]
+
+            if self.use_zero_padding:
+                label_weights[i][: len(label)] = 1
+                bbox_weights[i][: len(label)] = 1
+                # `paddle.bool` is a dtype object, not a callable
+                attn_mask = paddle.zeros(
+                    [max_num_coords, max_num_coords]).astype(paddle.bool)
+            else:
+                bbox_weights[i][: len(label)] = 1
+
+            if coord.shape[0] < max_num_coords:
+                padding_shape = max_num_coords - coord.shape[0]
+                if self.use_zero_padding:
+                    padding_coord = paddle.zeros([padding_shape, 4])
+                    # padded slots must be labeled background (num_classes);
+                    # zeros * num_classes would mislabel them as class 0
+                    padding_label = paddle.ones(
+                        [padding_shape], dtype=label.dtype) * self.num_classes
+                    padding_target = paddle.zeros([padding_shape, 4])
+                    padding_feat = paddle.zeros([padding_shape, c])
+                    attn_mask[coord.shape[0]:, 0:coord.shape[0]] = True
+                    attn_mask[:, coord.shape[0]:] = True
+                else:
+                    indices = paddle.randperm(neg_inds.shape[0])[:padding_shape]
+                    neg_inds = neg_inds[indices]
+                    padding_coord = bbox_xyxy_to_cxcywh(coords[i][neg_inds] / factor)
+                    padding_label = labels[i][neg_inds]
+                    padding_target = bbox_xyxy_to_cxcywh(targets[i][neg_inds] / factor)
+                    padding_feat = feats[neg_inds]
+                coord = paddle.concat((coord, padding_coord), axis=0)
+                label = paddle.concat((label, padding_label), axis=0)
+                target = paddle.concat((target, padding_target), axis=0)
+                feat = paddle.concat((feat, padding_feat), axis=0)
+            if self.use_zero_padding:
+                attn_masks.append(attn_mask.unsqueeze(0))
+            aux_coords.append(coord.unsqueeze(0))
+            aux_labels.append(label.unsqueeze(0))
+            aux_targets.append(target.unsqueeze(0))
+            aux_feats.append(feat.unsqueeze(0))
+
+        if self.use_zero_padding:
+            attn_masks = (
+                paddle.concat(attn_masks, axis=0).unsqueeze(1).tile((1, 8, 1, 1))
+            )
+            attn_masks = attn_masks.reshape((bs * 8, max_num_coords, max_num_coords))
+        else:
+            attn_masks = None
+
+        aux_coords = paddle.concat(aux_coords, axis=0)
+        aux_labels = paddle.concat(aux_labels, axis=0)
+        aux_targets = paddle.concat(aux_targets, axis=0)
+        aux_feats = paddle.concat(aux_feats, axis=0)
+        aux_label_weights = label_weights
+        aux_bbox_weights = bbox_weights
+        return (
+            aux_coords,
+            aux_labels,
+            aux_targets,
+            aux_label_weights,
+            aux_bbox_weights,
+            aux_feats,
+            attn_masks,
+        )
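+
+    # Shape sketch for the aux targets above (hypothetical sizes): for a
+    # batch of 2 images with max_num_coords = 128, aux_coords and aux_targets
+    # are [2, 128, 4] in normalized (cx, cy, w, h), aux_labels is [2, 128],
+    # and aux_feats is [2, 128, c]; positive slots come from the auxiliary
+    # head's assignments and the remainder is padded with sampled negatives
+    # (or zeros when use_zero_padding is enabled).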
+
+    # overridden because img_metas are needed as inputs for bbox_head
+    def forward_train_aux(
+        self,
+        x,
+        img_metas,
+        gt_bboxes,
+        gt_labels=None,
+        gt_bboxes_ignore=None,
+        pos_coords=None,
+        head_idx=0,
+        **kwargs,
+    ):
+        """Forward function for training mode.
+
+        Args:
+            x (list[Tensor]): Features from backbone.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+            pos_coords (tuple): Positive coordinates and targets collected
+                from an auxiliary head, used to build the aux queries.
+            head_idx (int): Index of the auxiliary head.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        aux_targets = self.get_aux_targets(pos_coords, img_metas, x, head_idx)
+        if aux_targets[0] is None:
+            return None
+
+        outs = self.forward_aux(x[:-1], img_metas, aux_targets, head_idx)
+        outs = outs + aux_targets
+        if gt_labels is None:
+            loss_inputs = outs + (gt_bboxes, img_metas)
+        else:
+            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+        losses = self.loss_aux(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+        return losses
+
+    def loss_aux(
+        self,
+        all_cls_scores,
+        all_bbox_preds,
+        enc_cls_scores,
+        enc_bbox_preds,
+        aux_coords,
+        aux_labels,
+        aux_targets,
+        aux_label_weights,
+        aux_bbox_weights,
+        aux_feats,
+        attn_masks,
+        gt_bboxes_list,
+        gt_labels_list,
+        img_metas,
+        gt_bboxes_ignore=None,
+    ):
+        """ "Loss function.
+
+        Args:
+            all_cls_scores (Tensor): Classification score of all
+                decoder layers, has shape
+                [nb_dec, bs, num_query, cls_out_channels].
+            all_bbox_preds (Tensor): Sigmoid regression
+                outputs of all decode layers. Each is a 4D-tensor with
+                normalized coordinate format (cx, cy, w, h) and shape
+                [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of points on the
+                encoded feature map, has shape (N, h*w, num_classes). Only
+                passed when as_two_stage is True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point on the
+                encoded feature map, has shape (N, h*w, 4). Only passed when
+                as_two_stage is True, otherwise None.
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+                which can be ignored for each image. Default None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        num_dec_layers = len(all_cls_scores)
+        all_labels = [aux_labels for _ in range(num_dec_layers)]
+        all_label_weights = [aux_label_weights for _ in range(num_dec_layers)]
+        all_bbox_targets = [aux_targets for _ in range(num_dec_layers)]
+        all_bbox_weights = [aux_bbox_weights for _ in range(num_dec_layers)]
+        img_metas_list = [img_metas for _ in range(num_dec_layers)]
+        all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]
+
+        losses_cls, losses_bbox, losses_iou = multi_apply(
+            self.loss_single_aux,
+            all_cls_scores,
+            all_bbox_preds,
+            all_labels,
+            all_label_weights,
+            all_bbox_targets,
+            all_bbox_weights,
+            img_metas_list,
+            all_gt_bboxes_ignore_list,
+        )
+
+        loss_dict = dict()
+        # loss of proposal generated from encode feature map.
+        # loss from the last decoder layer
+        loss_dict["loss_cls_aux"] = losses_cls[-1]
+        loss_dict["loss_bbox_aux"] = losses_bbox[-1]
+        loss_dict["loss_iou_aux"] = losses_iou[-1]
+        # loss from other decoder layers
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(
+            losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]
+        ):
+            loss_dict[f"d{num_dec_layer}.loss_cls_aux"] = loss_cls_i
+            loss_dict[f"d{num_dec_layer}.loss_bbox_aux"] = loss_bbox_i
+            loss_dict[f"d{num_dec_layer}.loss_iou_aux"] = loss_iou_i
+            num_dec_layer += 1
+        return loss_dict
+
+    # overridden because img_metas are needed as inputs for bbox_head
+    def forward_train(
+        self,
+        x,
+        img_metas,
+        gt_bboxes,
+        gt_labels=None,
+        gt_bboxes_ignore=None,
+        proposal_cfg=None,
+        **kwargs,
+    ):
+        """Forward function for training mode.
+
+        Args:
+            x (list[Tensor]): Features from backbone.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+            proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert proposal_cfg is None, '"proposal_cfg" must be None'
+        outs = self(x, img_metas)
+        if gt_labels is None:
+            loss_inputs = outs + (gt_bboxes, img_metas)
+        else:
+            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+        enc_outputs = outs[-1]
+        return losses, enc_outputs
+
+    def loss(
+        self,
+        all_cls_scores,
+        all_bbox_preds,
+        enc_cls_scores,
+        enc_bbox_preds,
+        enc_outputs,
+        gt_bboxes_list,
+        gt_labels_list,
+        img_metas,
+        gt_bboxes_ignore=None,
+    ):
+        """ "Loss function.
+
+        Args:
+            all_cls_scores (Tensor): Classification score of all
+                decoder layers, has shape
+                [nb_dec, bs, num_query, cls_out_channels].
+            all_bbox_preds (Tensor): Sigmoid regression
+                outputs of all decode layers. Each is a 4D-tensor with
+                normalized coordinate format (cx, cy, w, h) and shape
+                [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of points on the
+                encoded feature map, has shape (N, h*w, num_classes). Only
+                passed when as_two_stage is True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point on the
+                encoded feature map, has shape (N, h*w, 4). Only passed when
+                as_two_stage is True, otherwise None.
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+                which can be ignored for each image. Default None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+
+        num_dec_layers = len(all_cls_scores)
+        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+        all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]
+        img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+        losses_cls, losses_bbox, losses_iou = multi_apply(
+            self.loss_single,
+            all_cls_scores,
+            all_bbox_preds,
+            all_gt_bboxes_list,
+            all_gt_labels_list,
+            img_metas_list,
+            all_gt_bboxes_ignore_list,
+        )
+
+        loss_dict = dict()
+        # loss of proposal generated from encode feature map.
+        if enc_cls_scores is not None:
+            binary_labels_list = [
+                paddle.zeros_like(gt_labels_list[i]) for i in range(len(img_metas))
+            ]
+            enc_loss_cls, enc_losses_bbox, enc_losses_iou = self.loss_single(
+                enc_cls_scores,
+                enc_bbox_preds,
+                gt_bboxes_list,
+                binary_labels_list,
+                img_metas,
+                gt_bboxes_ignore,
+            )
+            loss_dict["enc_loss_cls"] = enc_loss_cls
+            loss_dict["enc_loss_bbox"] = enc_losses_bbox
+            loss_dict["enc_loss_iou"] = enc_losses_iou
+
+        # loss from the last decoder layer
+        loss_dict["loss_cls"] = losses_cls[-1]
+        loss_dict["loss_bbox"] = losses_bbox[-1]
+        loss_dict["loss_iou"] = losses_iou[-1]
+        # loss from other decoder layers
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(
+            losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]
+        ):
+            loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
+            loss_dict[f"d{num_dec_layer}.loss_bbox"] = loss_bbox_i
+            loss_dict[f"d{num_dec_layer}.loss_iou"] = loss_iou_i
+            num_dec_layer += 1
+        return loss_dict
+
+    def get_bboxes(
+        self,
+        all_cls_scores,
+        all_bbox_preds,
+        enc_cls_scores,
+        enc_bbox_preds,
+        enc_outputs,
+        img_metas,
+        rescale=False,
+    ):
+        """Transform network outputs for a batch into bbox predictions.
+
+        Args:
+            all_cls_scores (Tensor): Classification score of all
+                decoder layers, has shape
+                [nb_dec, bs, num_query, cls_out_channels].
+            all_bbox_preds (Tensor): Sigmoid regression
+                outputs of all decode layers. Each is a 4D-tensor with
+                normalized coordinate format (cx, cy, w, h) and shape
+                [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of
+                points on the encoder feature map, has shape
+                (N, h*w, num_classes). Only passed when as_two_stage is
+                True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point
+                on the encoder feature map, has shape (N, h*w, 4). Only
+                passed when as_two_stage is True, otherwise None.
+            img_metas (list[dict]): Meta information of each image.
+            rescale (bool, optional): If True, return boxes in original
+                image space. Default False.
+
+        Returns:
+            list[list[Tensor, Tensor]]: Each item in result_list is a 2-tuple. \
+                The first item is an (n, 5) tensor, where the first 4 columns \
+                are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+                5-th column is a score between 0 and 1. The second item is a \
+                (n,) tensor where each item is the predicted class label of \
+                the corresponding box.
+        """
+        cls_scores = all_cls_scores[-1]
+        bbox_preds = all_bbox_preds[-1]
+        result_list = []
+        for img_id in range(len(img_metas)):
+            cls_score = cls_scores[img_id]
+            bbox_pred = bbox_preds[img_id]
+            img_shape = img_metas[img_id]["img_shape"]
+            scale_factor = img_metas[img_id]["scale_factor"]
+            proposals = self._get_bboxes_single(
+                cls_score, bbox_pred, img_shape, scale_factor, rescale
+            )
+            result_list.append(proposals)
+        
+        return result_list
+
+    def _get_bboxes_single(self,
+                           cls_score,
+                           bbox_pred,
+                           img_shape,
+                           scale_factor,
+                           rescale=False,
+                           ):
+        """Transform outputs from the last decoder layer into bbox predictions
+        for each image.
+
+        Args:
+            cls_score (Tensor): Box score logits from the last decoder layer
+                for each image. Shape [num_query, cls_out_channels].
+            bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
+                for each image, with coordinate format (cx, cy, w, h) and
+                shape [num_query, 4].
+            img_shape (tuple[int]): Shape of input image, (height, width, 3).
+            scale_factor (ndarray, optional): Scale factor of the image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+            rescale (bool, optional): If True, return boxes in original image
+                space. Default False.
+
+        Returns:
+            tuple[Tensor]: Results of detected bboxes and labels.
+
+                - det_bboxes: Predicted bboxes with shape [num_query, 5], \
+                    where the first 4 columns are bounding box positions \
+                    (tl_x, tl_y, br_x, br_y) and the 5-th column holds scores \
+                    between 0 and 1.
+                - det_labels: Predicted labels of the corresponding box with \
+                    shape [num_query].
+        """
+        assert len(cls_score) == len(bbox_pred)
+        max_per_img = self.test_cfg.get('max_per_img', self.num_query)
+        score_thr = self.test_cfg.get('score_thr', 0)
+
+        # exclude background
+        if self.loss_cls.use_sigmoid:
+            cls_score = F.sigmoid(cls_score)
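+            # Flattened top-k trick: scores over all (query, class) pairs are
+            # ranked jointly, so a single query may yield several detections.
+            # A flat index decomposes as query = idx // num_classes and
+            # class = idx % num_classes; e.g. with num_classes == 80, flat
+            # index 163 maps to query 2, class 3.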
+            scores, indexes = cls_score.reshape([-1]).topk(max_per_img)
+            det_labels = indexes % self.num_classes
+            bbox_index = indexes // self.num_classes
+            bbox_pred = bbox_pred[bbox_index]
+        else:
+            # paddle's Tensor.max returns values only (unlike torch), so the
+            # class indices are taken with a separate argmax.
+            cls_prob = F.softmax(cls_score, axis=-1)[..., :-1]
+            scores = cls_prob.max(-1)
+            det_labels = cls_prob.argmax(-1)
+            scores, bbox_index = scores.topk(max_per_img)
+            bbox_pred = bbox_pred[bbox_index]
+            det_labels = det_labels[bbox_index]
+
+        valid_mask = scores > score_thr
+        scores = scores[valid_mask]
+        bbox_pred = bbox_pred[valid_mask]
+        det_labels = det_labels[valid_mask]
+
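+        # Denormalize: predictions are (cx, cy, w, h) in [0, 1]; convert to
+        # corner format and scale by the input size. E.g. for img_shape
+        # (800, 1333, 3), a normalized box (0.5, 0.5, 0.2, 0.4) becomes
+        # (0.4, 0.3, 0.6, 0.7) in xyxy and then (533.2, 240.0, 799.8, 560.0).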
+        det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
+        det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
+        det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
+        # clip returns a new tensor in paddle, so assign the result back
+        det_bboxes[:, 0::2] = det_bboxes[:, 0::2].clip(min=0, max=img_shape[1])
+        det_bboxes[:, 1::2] = det_bboxes[:, 1::2].clip(min=0, max=img_shape[0])
+
+        if rescale:
+            det_bboxes /= paddle.concat([scale_factor[::-1], scale_factor[::-1]])
+        det_bboxes = paddle.concat((scores.unsqueeze(1), det_bboxes.astype('float32')), -1)
+        proposals = paddle.concat((det_labels.unsqueeze(1).astype('float32'), det_bboxes), -1)
+        return proposals
+
+    def loss_single(
+        self,
+        cls_scores,
+        bbox_preds,
+        gt_bboxes_list,
+        gt_labels_list,
+        img_metas,
+        gt_bboxes_ignore_list=None,
+    ):
+        """ "Loss function for outputs from a single decoder layer of a single
+        feature level.
+
+        Args:
+            cls_scores (Tensor): Box score logits from a single decoder layer
+                for all images. Shape [bs, num_query, cls_out_channels].
+            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+                for all images, with normalized coordinate (cx, cy, w, h) and
+                shape [bs, num_query, 4].
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+
+        Returns:
+            tuple[Tensor]: A tuple of loss_cls, loss_bbox and loss_iou for
+                outputs from a single decoder layer.
+        """
+        num_imgs = cls_scores.shape[0]
+        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+        cls_reg_targets = self.get_targets(
+            cls_scores_list,
+            bbox_preds_list,
+            gt_bboxes_list,
+            gt_labels_list,
+            img_metas,
+            gt_bboxes_ignore_list,
+        )
+        (
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            num_total_pos,
+            num_total_neg,
+        ) = cls_reg_targets
+        labels = paddle.concat(labels_list, 0)
+        label_weights = paddle.concat(label_weights_list, 0)
+        bbox_targets = paddle.concat(bbox_targets_list, 0)
+        bbox_weights = paddle.concat(bbox_weights_list, 0)
+
+        # classification loss
+        cls_scores = cls_scores.reshape((-1, self.cls_out_channels))
+        # construct weighted avg_factor to match with the official DETR repo
+        cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
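+        # E.g. with 300 queries, 5 positives and bg_cls_weight == 0.1, the
+        # classification loss is averaged over 5 + 295 * 0.1 = 34.5.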
+        if self.sync_cls_avg_factor:
+            cls_avg_factor = reduce_mean(
+                paddle.to_tensor([cls_avg_factor], dtype=cls_scores.dtype)
+            )
+        cls_avg_factor = max(cls_avg_factor, 1)
+        loss_cls = self.loss_cls(
+            cls_scores, labels, label_weights, avg_factor=cls_avg_factor
+        )
+
+        # Compute the average number of gt boxes across all gpus, for
+        # normalization purposes
+        num_total_pos = paddle.to_tensor([num_total_pos], dtype=loss_cls.dtype)
+        num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item()
+
+        # construct factors used to rescale bboxes
+        factors = []
+        for img_meta, bbox_pred in zip(img_metas, bbox_preds):
+            img_h, img_w, _ = img_meta["img_shape"]
+            factor = (
+                paddle.to_tensor([img_w, img_h, img_w, img_h], dtype=bbox_pred.dtype)
+                .unsqueeze(0)
+                .tile((bbox_pred.shape[0], 1))
+            )
+            factors.append(factor)
+        factors = paddle.concat(factors, 0)
+
+        # DETR regresses the relative positions of boxes (cxcywh) in the
+        # image, so the learning target is normalized by the image size.
+        # Here we rescale the boxes back to absolute coordinates for the
+        # IoU loss.
+        bbox_preds = bbox_preds.reshape((-1, 4))
+        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+        # regression IoU loss, GIoU loss by default
+        loss_iou = self.loss_iou(
+            bboxes, bboxes_gt, bbox_weights
+        ).mean()
+
+        # regression L1 loss
+        loss_bbox = self.loss_bbox(
+            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos
+        )
+
+        return loss_cls, loss_bbox, loss_iou
+
+    def get_targets(
+        self,
+        cls_scores_list,
+        bbox_preds_list,
+        gt_bboxes_list,
+        gt_labels_list,
+        img_metas,
+        gt_bboxes_ignore_list=None,
+    ):
+        """"Compute regression and classification targets for a batch image.
+
+        Outputs from a single decoder layer of a single feature level are used.
+
+        Args:
+            cls_scores_list (list[Tensor]): Box score logits from a single
+                decoder layer for each image with shape [num_query,
+                cls_out_channels].
+            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
+                decoder layer for each image, with normalized coordinate
+                (cx, cy, w, h) and shape [num_query, 4].
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+
+        Returns:
+            tuple: a tuple containing the following targets.
+
+                - labels_list (list[Tensor]): Labels for all images.
+                - label_weights_list (list[Tensor]): Label weights for all \
+                    images.
+                - bbox_targets_list (list[Tensor]): BBox targets for all \
+                    images.
+                - bbox_weights_list (list[Tensor]): BBox weights for all \
+                    images.
+                - num_total_pos (int): Number of positive samples in all \
+                    images.
+                - num_total_neg (int): Number of negative samples in all \
+                    images.
+        """
+        num_imgs = len(cls_scores_list)
+        if gt_bboxes_ignore_list is None:
+            gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]
+
+        (
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            pos_inds_list,
+            neg_inds_list,
+        ) = multi_apply(
+            self._get_target_single,
+            cls_scores_list,
+            bbox_preds_list,
+            gt_bboxes_list,
+            gt_labels_list,
+            img_metas,
+            gt_bboxes_ignore_list,
+        )
+        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+        return (
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            num_total_pos,
+            num_total_neg,
+        )
+
+    def _get_target_single(
+        self,
+        cls_score,
+        bbox_pred,
+        gt_bboxes,
+        gt_labels,
+        img_meta,
+        gt_bboxes_ignore=None,
+    ):
+        """ "Compute regression and classification targets for one image.
+
+        Outputs from a single decoder layer of a single feature level are used.
+
+        Args:
+            cls_score (Tensor): Box score logits from a single decoder layer
+                for one image. Shape [num_query, cls_out_channels].
+            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+                for one image, with normalized coordinate (cx, cy, w, h) and
+                shape [num_query, 4].
+            gt_bboxes (Tensor): Ground truth bboxes for one image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (Tensor): Ground truth class indices for one image
+                with shape (num_gts, ).
+            img_meta (dict): Meta information for one image.
+            gt_bboxes_ignore (Tensor, optional): Bounding boxes
+                which can be ignored. Default None.
+
+        Returns:
+            tuple[Tensor]: a tuple containing the following for one image.
+
+                - labels (Tensor): Labels of each image.
+                - label_weights (Tensor): Label weights of each image.
+                - bbox_targets (Tensor): BBox targets of each image.
+                - bbox_weights (Tensor): BBox weights of each image.
+                - pos_inds (Tensor): Sampled positive indices for each image.
+                - neg_inds (Tensor): Sampled negative indices for each image.
+        """
+        num_bboxes = bbox_pred.shape[0]
+        # assigner and sampler
+        assign_result = self.assigner.assign(
+            bbox_pred, cls_score, gt_bboxes, gt_labels, img_meta, gt_bboxes_ignore
+        )
+        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
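+        # The Hungarian-style assigner matches queries and GT boxes
+        # one-to-one, so (with the PseudoSampler typically used in DETR-style
+        # heads) pos_inds lists the num_gts matched queries and every other
+        # query is a negative.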
+        # label targets
+        labels = paddle.full((num_bboxes,), self.num_classes, dtype="int64")
+        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds][..., 0].astype("int64")
+        label_weights = paddle.ones((num_bboxes,), dtype=gt_bboxes.dtype)
+        # bbox targets
+        bbox_targets = paddle.zeros_like(bbox_pred)
+        bbox_weights = paddle.zeros_like(bbox_pred)
+        bbox_weights[pos_inds] = 1.0
+        img_h, img_w, _ = img_meta["img_shape"]
+
+        
+        # DETR regresses the relative positions of boxes (cxcywh) in the
+        # image, so the learning target is normalized by the image size and
+        # the box format is converted from the default x1y1x2y2 to cxcywh.
+        factor = paddle.to_tensor(
+            [img_w, img_h, img_w, img_h], dtype=bbox_pred.dtype
+        ).unsqueeze(0)
+        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
+        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+        bbox_targets[pos_inds] = pos_gt_bboxes_targets
+        
+        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)
+
+    def simple_test(self, feats, img_metas, rescale=False):
+        """Test det bboxes without test-time augmentation.
+
+        Args:
+            feats (tuple[Tensor]): Multi-level features from the
+                upstream network, each is a 4D-tensor.
+            img_metas (list[dict]): List of image information.
+            rescale (bool, optional): Whether to rescale the results.
+                Defaults to False.
+
+        Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
+                The first item is ``bboxes`` with shape (n, 5),
+                where 5 represents (tl_x, tl_y, br_x, br_y, score).
+                The second item is ``labels`` with shape (n,).
+        """
+        # forward of this head requires img_metas
+        outs = self.forward(feats, img_metas)
+        result_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
+        return result_list
+    
\ No newline at end of file
diff --git a/ppdet/modeling/heads/co_roi_head.py b/ppdet/modeling/heads/co_roi_head.py
new file mode 100644
index 0000000000..73c6059f64
--- /dev/null
+++ b/ppdet/modeling/heads/co_roi_head.py
@@ -0,0 +1,129 @@
+
+import paddle
+import paddle.nn.functional as F
+
+from ppdet.core.workspace import register
+from ppdet.modeling.heads.bbox_head import BBoxHead
+from .roi_extractor import RoIAlign
+from ..cls_utils import _get_class_default_kwargs
+
+__all__ = ['Co_RoiHead']
+
+@register
+class Co_RoiHead(BBoxHead):
+    __shared__ = ['num_classes', 'use_cot']
+    __inject__ = ['bbox_assigner', 'bbox_loss', 'loss_cot']
+    """
+    RCNN bbox head, used as the auxiliary RoI head in Co-DETR.
+
+    Args:
+        head (nn.Layer): Extract feature in bbox head
+        in_channel (int): Input channel after RoI extractor
+        roi_extractor (object): The module of RoI Extractor
+        bbox_assigner (object): The module of Box Assigner, label and sample the 
+            box.
+        with_pool (bool): Whether to use pooling for the RoI feature.
+        num_classes (int): The number of classes
+        bbox_weight (List[float]): The weight to get the decode box
+        cot_classes (int): The number of base classes
+        loss_cot (object): The module of Label-cotuning
+        use_cot (bool): Whether to use label co-tuning
+    """
+
+    def __init__(self,
+                 head,
+                 in_channel,
+                 roi_extractor=_get_class_default_kwargs(RoIAlign),
+                 bbox_assigner='BboxAssigner',
+                 with_pool=False,
+                 num_classes=80,
+                 bbox_weight=[10., 10., 5., 5.],
+                 bbox_loss=None,
+                 loss_normalize_pos=False,
+                 cot_classes=None,
+                 loss_cot='COTLoss',
+                 use_cot=False):
+        super(Co_RoiHead, self).__init__(
+            head=head,
+            in_channel=in_channel,
+            roi_extractor=roi_extractor,
+            bbox_assigner=bbox_assigner,
+            with_pool=with_pool,
+            num_classes=num_classes,
+            bbox_weight=bbox_weight,
+            bbox_loss=bbox_loss,
+            loss_normalize_pos=loss_normalize_pos,
+            cot_classes=cot_classes,
+            loss_cot=loss_cot,
+            use_cot=use_cot
+            )
+        self.head = head
+    
+    def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None, cot=False):
+        """
+        body_feats (list[Tensor]): Feature maps from backbone
+        rois (list[Tensor]): RoIs generated from RPN module
+        rois_num (Tensor): The number of RoIs in each image
+        inputs (dict{Tensor}): The ground-truth of image
+        """
+        if self.training:
+            rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
+            self.assigned_rois = (rois, rois_num)
+            self.assigned_targets = targets
+
+        rois_feat = self.roi_extractor(body_feats, rois, rois_num)
+        bbox_feat = self.head(rois_feat)
+        if self.with_pool:
+            feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
+            feat = paddle.squeeze(feat, axis=[2, 3])
+        else:
+            feat = bbox_feat
+        if self.use_cot:
+            scores = self.cot_bbox_score(feat)
+            cot_scores = self.bbox_score(feat)
+        else:
+            scores = self.bbox_score(feat)
+        deltas = self.bbox_delta(feat)
+
+        if self.training:
+            loss = self.get_loss(
+                scores,
+                deltas,
+                targets,
+                rois,
+                self.bbox_weight,
+                loss_normalize_pos=self.loss_normalize_pos)
+            
+            if self.cot_relation is not None:
+                loss_cot = self.loss_cot(cot_scores, targets, self.cot_relation)
+                loss.update(loss_cot)
+                
+            target_labels, target_bboxes, _ = targets
+            max_proposal = target_labels[0].shape[0]
+            # get pos_coords
+            ori_proposals, ori_labels, ori_bbox_targets, ori_bbox_feats = [], [], [], []
+            for i in range(len(rois)):
+                ori_proposal = rois[i].unsqueeze(0)
+                ori_label = target_labels[i].unsqueeze(0)
+                ori_bbox_target = target_bboxes[i].unsqueeze(0)
+                
+                ori_bbox_feat = rois_feat[i*max_proposal:(i+1)*max_proposal].mean(-1).mean(-1)
+                ori_bbox_feat = ori_bbox_feat.unsqueeze(0)
+                ori_proposals.append(ori_proposal) 
+                ori_labels.append(ori_label)
+                ori_bbox_targets.append(ori_bbox_target)
+                ori_bbox_feats.append(ori_bbox_feat)
+                
+            ori_coords = paddle.concat(ori_proposals, axis=0)
+            ori_labels = paddle.concat(ori_labels, axis=0)
+            ori_bbox_targets = paddle.concat(ori_bbox_targets, axis=0)
+            ori_bbox_feats = paddle.concat(ori_bbox_feats, axis=0)
+            pos_coords = (ori_coords, ori_labels, ori_bbox_targets, ori_bbox_feats, 'rcnn')
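+            # pos_coords packs, per image, the assigned RoIs, their labels,
+            # regression targets and pooled RoI features, tagged 'rcnn'; in
+            # Co-DETR these positive coordinates are consumed by the query
+            # head to build auxiliary positive queries.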
+            loss.update(pos_coords=pos_coords)
+            return loss, bbox_feat
+        else:
+            if cot:
+                pred = self.get_prediction(cot_scores, deltas)
+            else:
+                pred = self.get_prediction(scores, deltas)
+            return pred, self.head
diff --git a/ppdet/modeling/heads/petr_head.py b/ppdet/modeling/heads/petr_head.py
index 90760c6651..5888d77a5e 100644
--- a/ppdet/modeling/heads/petr_head.py
+++ b/ppdet/modeling/heads/petr_head.py
@@ -186,7 +186,7 @@ def __init__(self,
                  loss_oks='OKSLoss',
                  loss_hm='CenterFocalLoss',
                  with_kpt_refine=True,
-                 assigner='PoseHungarianAssigner',
+                 assigner='HungarianAssigner',
                  sampler='PseudoSampler',
                  loss_kpt_rpn='L1Loss',
                  loss_kpt_refine='L1Loss',
diff --git a/ppdet/modeling/proposal_generator/anchor_generator.py b/ppdet/modeling/proposal_generator/anchor_generator.py
index d189f784a2..f1ae6ff9e7 100644
--- a/ppdet/modeling/proposal_generator/anchor_generator.py
+++ b/ppdet/modeling/proposal_generator/anchor_generator.py
@@ -23,7 +23,7 @@
 
 from ppdet.core.workspace import register
 
-__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator']
+__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator', 'CoAnchorGenerator']
 
 
 @register
@@ -51,15 +51,17 @@ def __init__(self,
                  aspect_ratios=[0.5, 1.0, 2.0],
                  strides=[16.0],
                  variance=[1.0, 1.0, 1.0, 1.0],
-                 offset=0.):
+                 offset=0.,
+                 ):
         super(AnchorGenerator, self).__init__()
         self.anchor_sizes = anchor_sizes
         self.aspect_ratios = aspect_ratios
         self.strides = strides
+        self.num_levels = len(self.strides)
         self.variance = variance
         self.cell_anchors = self._calculate_anchors(len(strides))
         self.offset = offset
-
+        
     def _broadcast_params(self, params, num_features):
         if not isinstance(params[0], (list, tuple)):  # list[float]
             return [params] * num_features
@@ -121,6 +123,7 @@ def forward(self, input):
         anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)
         return anchors_over_all_feature_maps
 
+
     @property
     def num_anchors(self):
         """
@@ -155,7 +158,100 @@ def __init__(self,
             variance=variance,
             offset=offset)
 
+    
+@register
+class CoAnchorGenerator(AnchorGenerator):
+    def __init__(self,
+                 octave_base_scale=4,
+                 scales_per_octave=3,
+                 aspect_ratios=[0.5, 1.0, 2.0],
+                 strides=[8.0, 16.0, 32.0, 64.0, 128.0],
+                 variance=[1.0, 1.0, 1.0, 1.0],
+                 offset=0.0):
+        anchor_sizes = []
+        for s in strides:
+            anchor_sizes.append([
+                s * octave_base_scale * 2**(i/scales_per_octave) \
+                for i in range(scales_per_octave)])
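+        # Each stride contributes scales_per_octave sizes spaced by
+        # 2**(1 / scales_per_octave), starting at stride * octave_base_scale.
+        # E.g. with the defaults (octave_base_scale=4, scales_per_octave=3),
+        # stride 8 yields anchor sizes of roughly [32.0, 40.3, 50.8].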
+        super(CoAnchorGenerator, self).__init__(
+            anchor_sizes=anchor_sizes,
+            aspect_ratios=aspect_ratios,
+            strides=strides,
+            variance=variance,
+            offset=offset)
+        
+    def _meshgrid(self, x, y, row_major=True):
+        yy, xx = paddle.meshgrid(y, x)
+        yy = yy.reshape([-1])
+        xx = xx.reshape([-1])
+        if row_major:
+            return xx, yy
+        else:
+            return yy, xx
+        
+    def valid_flags(self, featmap_sizes, pad_shape):
+        """Generate valid flags of anchors in multiple feature levels.
+
+        Args:
+            featmap_sizes (list(tuple)): List of feature map sizes in
+                multiple feature levels.
+            pad_shape (tuple): The padded shape of the image.
+            device (str): Device where the anchors will be put on.
 
+        Return:
+            list(torch.Tensor): Valid flags of anchors in multiple levels.
+        """
+        featmap_sizes = [feature_map.shape[-2:] for feature_map in featmap_sizes]
+        num_base_anchors = [base_anchors.shape[0] for base_anchors in self.cell_anchors]
+        assert self.num_levels == len(featmap_sizes)
+        multi_level_flags = []
+        for i in range(self.num_levels):
+            anchor_stride = self.strides[i]
+            feat_h, feat_w = featmap_sizes[i]
+            h, w = pad_shape[:2]
+            valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
+            valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
+            flags = self.single_level_valid_flags((feat_h, feat_w),
+                                                  (valid_feat_h, valid_feat_w),
+                                                   num_base_anchors[i],
+                                                  )
+            multi_level_flags.append(flags)
+        return multi_level_flags
+    
+    def single_level_valid_flags(self,
+                                 featmap_size,
+                                 valid_size,
+                                 num_base_anchors,
+                                 ):
+        """Generate the valid flags of anchor in a single feature map.
+
+        Args:
+            featmap_size (tuple[int]): The size of feature maps, arrange
+                as (h, w).
+            valid_size (tuple[int]): The valid size of the feature maps.
+            num_base_anchors (int): The number of base anchors.
+
+        Returns:
+            paddle.Tensor: The valid flags of each anchor in a single level \
+                feature map.
+        """
+        feat_h, feat_w = featmap_size
+        valid_h, valid_w = valid_size
+        assert valid_h <= feat_h and valid_w <= feat_w
+        # paddle.zeros expects a shape list, not a bare int
+        valid_x = paddle.zeros([feat_w], dtype='int32')
+        valid_y = paddle.zeros([feat_h], dtype='int32')
+        valid_x[:valid_w] = 1
+        valid_y[:valid_h] = 1
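+        # E.g. a 4x4 feature map with valid_size (2, 3) gives
+        # valid_x = [1, 1, 1, 0] and valid_y = [1, 1, 0, 0]; the meshgrid AND
+        # below marks only the top-left 2x3 cells as valid, and each flag is
+        # then repeated for every base anchor at that cell.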
+        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+        valid = valid_xx & valid_yy
+        valid = paddle.reshape(valid, [-1, 1])
+        valid = paddle.expand(valid, [-1, num_base_anchors]).reshape([-1])
+        
+        return valid
+    
+    
 @register
 class S2ANetAnchorGenerator(nn.Layer):
     """
@@ -202,15 +298,6 @@ def gen_base_anchors(self):
         base_anchors = paddle.round(base_anchors)
         return base_anchors
 
-    def _meshgrid(self, x, y, row_major=True):
-        yy, xx = paddle.meshgrid(y, x)
-        yy = yy.reshape([-1])
-        xx = xx.reshape([-1])
-        if row_major:
-            return xx, yy
-        else:
-            return yy, xx
-
     def forward(self, featmap_size, stride=16):
         # featmap_size*stride project it to original area
 
@@ -227,6 +314,14 @@ def forward(self, featmap_size, stride=16):
         all_anchors = self.rect2rbox(all_anchors)
         return all_anchors
 
+    def _meshgrid(self, x, y, row_major=True):
+        yy, xx = paddle.meshgrid(y, x)
+        yy = yy.reshape([-1])
+        xx = xx.reshape([-1])
+        if row_major:
+            return xx, yy
+        else:
+            return yy, xx
+
     def valid_flags(self, featmap_size, valid_size):
         feat_h, feat_w = featmap_size
         valid_h, valid_w = valid_size
diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py
index 5eac4f110d..27148e09ab 100644
--- a/ppdet/modeling/transformers/__init__.py
+++ b/ppdet/modeling/transformers/__init__.py
@@ -23,6 +23,7 @@
 from . import rtdetr_transformer
 from . import hybrid_encoder
 from . import mask_rtdetr_transformer
+from . import co_deformable_detr_transformer
 
 from .detr_transformer import *
 from .utils import *
@@ -36,3 +37,4 @@
 from .rtdetr_transformer import *
 from .hybrid_encoder import *
 from .mask_rtdetr_transformer import *
+from .co_deformable_detr_transformer import *
diff --git a/ppdet/modeling/transformers/co_deformable_detr_transformer.py b/ppdet/modeling/transformers/co_deformable_detr_transformer.py
new file mode 100644
index 0000000000..b9a546b0a6
--- /dev/null
+++ b/ppdet/modeling/transformers/co_deformable_detr_transformer.py
@@ -0,0 +1,639 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/Sense-X/Co-DETR/blob/main/projects/models/transformer.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+from ppdet.core.workspace import register
+from ..layers import MultiHeadAttention, _convert_attention_mask
+from .utils import _get_clones
+from ..initializer import linear_init_, normal_, constant_, xavier_uniform_
+from ..shape_spec import ShapeSpec
+
+from .petr_transformer import (
+    PETR_TransformerDecoder,
+    MSDeformableAttention,
+    TransformerEncoder,
+    inverse_sigmoid,
+)
+
+__all__ = [
+    "CoDeformableDetrTransformerDecoder",
+    "CoDeformableDetrTransformer",
+    "CoTransformerEncoder",
+]
+
+@register
+class CoTransformerEncoder(TransformerEncoder):
+    def __init__(self, encoder_layer, num_layers, norm=None, out_channel=256,
+                 spatial_scales=[1/8, 1/16, 1/32, 1/64, 1/128]):
+        super().__init__(encoder_layer, num_layers, norm)
+        self.out_channel = out_channel
+        self.spatial_scales = spatial_scales
+
+    @property
+    def out_shape(self):
+        return [
+            ShapeSpec(
+                channels=self.out_channel, stride=1. / s)
+            for s in self.spatial_scales
+        ]
+
+@register
+class CoDeformableDetrTransformerDecoder(PETR_TransformerDecoder):
+    __inject__ = ["decoder_layer"]
+
+    def __init__(
+        self,
+        decoder_layer,
+        num_layers,
+        norm=None,
+        return_intermediate=False,
+        look_forward_twice=False,
+        **kwargs
+    ):
+        super().__init__(decoder_layer, num_layers, norm, return_intermediate, **kwargs)
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.return_intermediate = return_intermediate
+        self.look_forward_twice = look_forward_twice
+
+    def forward(
+        self,
+        query,
+        *args,
+        reference_points=None,
+        valid_ratios=None,
+        reg_branches=None,
+        **kwargs
+    ):
+        """Forward function for `TransformerDecoder`.
+
+        Args:
+            query (Tensor): Input query with shape (num_query, bs, embed_dims).
+            reference_points (Tensor): The reference points of offset,
+                has shape (bs, num_query, K*2).
+            valid_ratios (Tensor): The ratios of valid points on the feature
+                map, has shape (bs, num_levels, 2).
+            reg_branches (obj:`nn.LayerList`): Used for refining the
+                regression results. Only passed when with_box_refine is True,
+                otherwise None.
+
+        Returns:
+            Tensor: Results with shape [1, num_query, bs, embed_dims] when
+                return_intermediate is `False`, otherwise it has shape
+                [num_layers, num_query, bs, embed_dims].
+        """
+
+        output = query
+        intermediate = []
+        intermediate_reference_points = []
+        for lid, layer in enumerate(self.layers):
+            if reference_points.shape[-1] == 4:
+                reference_points_input = (
+                    reference_points[:, :, None]
+                    * paddle.concat([valid_ratios, valid_ratios], -1)[:, None]
+                )
+            else:
+                assert reference_points.shape[-1] == 2
+                reference_points_input = (
+                    reference_points[:, :, None] * valid_ratios[:, None]
+                )
+            output = layer(
+                output, *args, reference_points=reference_points_input, **kwargs
+            )
+
+            if reg_branches is not None:
+                tmp = reg_branches[lid](output)
+                if reference_points.shape[-1] == 4:
+                    new_reference_points = tmp + inverse_sigmoid(reference_points)
+                    new_reference_points = F.sigmoid(new_reference_points)
+                else:
+                    assert reference_points.shape[-1] == 2
+                    new_reference_points = tmp
+                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
+                        reference_points
+                    )
+                    new_reference_points = F.sigmoid(new_reference_points)
+                reference_points = new_reference_points.detach()
+
+            if self.return_intermediate:
+                intermediate.append(output)
+                intermediate_reference_points.append(
+                    new_reference_points
+                    if self.look_forward_twice
+                    else reference_points
+                )
+
+        if self.return_intermediate:
+            return paddle.stack(intermediate), paddle.stack(
+                intermediate_reference_points
+            )
+
+        return output, reference_points
+
+
+@register
+class CoDeformableDetrTransformer(nn.Layer):
+    __inject__ = ["encoder", "decoder"]
+
+    def __init__(
+        self,
+        encoder="",
+        decoder="",
+        mixed_selection=True,
+        with_pos_coord=True,
+        with_coord_feat=True,
+        num_co_heads=1,
+        as_two_stage=False,
+        two_stage_num_proposals=300,
+        num_feature_levels=4,
+        **kwargs
+    ):
+        super(CoDeformableDetrTransformer, self).__init__(**kwargs)
+
+        self.as_two_stage = as_two_stage
+        self.two_stage_num_proposals = two_stage_num_proposals
+        self.encoder = encoder
+        self.decoder = decoder
+        self.embed_dims = self.encoder.embed_dims
+        self.mixed_selection = mixed_selection
+        self.with_pos_coord = with_pos_coord
+        self.with_coord_feat = with_coord_feat
+        self.num_co_heads = num_co_heads
+        self.num_feature_levels = num_feature_levels
+        self.init_layers()
+
+    def init_layers(self):
+        """Initialize layers of the DeformableDetrTransformer."""
+        if self.with_pos_coord:
+            if self.num_co_heads > 0:
+                # bug: this should be 'self.head_pos_embed = nn.Embedding(self.num_co_heads, self.embed_dims)';
+                # the original Co-DETR repo keeps the bug to reproduce its ResNet-50 results.
+                # You can fix it when reproducing results with Swin Transformer.
+                self.head_pos_embed = nn.Embedding(
+                    self.num_co_heads, 1, 1, self.embed_dims
+                )
+                self.aux_pos_trans = nn.LayerList()
+                self.aux_pos_trans_norm = nn.LayerList()
+                self.pos_feats_trans = nn.LayerList()
+                self.pos_feats_norm = nn.LayerList()
+                for i in range(self.num_co_heads):
+                    self.aux_pos_trans.append(
+                        nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
+                    )
+                    self.aux_pos_trans_norm.append(nn.LayerNorm(self.embed_dims * 2))
+                    if self.with_coord_feat:
+                        self.pos_feats_trans.append(
+                            nn.Linear(self.embed_dims, self.embed_dims)
+                        )
+                        self.pos_feats_norm.append(nn.LayerNorm(self.embed_dims))
+
+        self.level_embeds = paddle.create_parameter(
+            (self.num_feature_levels, self.embed_dims), dtype="float32"
+        )
+
+        if self.as_two_stage:
+            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
+            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
+            self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
+            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
+        else:
+            self.reference_points = nn.Linear(self.embed_dims, 2)
+
+    def init_weights(self):
+        """Initialize the transformer weights."""
+        for p in self.parameters():
+            if p.rank() > 1:
+                xavier_uniform_(p)
+                if hasattr(p, "bias") and p.bias is not None:
+                    constant_(p.bias)
+        for m in self.sublayers():
+            if isinstance(m, MSDeformableAttention):
+                m._reset_parameters()
+        if not self.as_two_stage:
+            xavier_uniform_(self.reference_points.weight)
+            constant_(self.reference_points.bias)
+        normal_(self.level_embeds)
+
+    def get_valid_ratio(self, mask):
+        """Get the valid radios of feature maps of all level."""
+        _, H, W = mask.shape
+        valid_H = paddle.sum(paddle.logical_not(mask[:, :, 0]).astype("float"), 1)
+        valid_W = paddle.sum(paddle.logical_not(mask[:, 0, :]).astype("float"), 1)
+        valid_ratio_h = valid_H.astype("float") / H
+        valid_ratio_w = valid_W.astype("float") / W
+        valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1)
+        return valid_ratio
+
+    def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
+        """Get the position embedding of proposal."""
+        scale = 2 * math.pi
+        dim_t = paddle.arange(num_pos_feats, dtype="float32")
+        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        # N, L, 4
+        proposals = F.sigmoid(proposals) * scale
+        # N, L, 4, 128
+        pos = proposals[:, :, :, None] / dim_t
+        # N, L, 4, 64, 2
+        pos = paddle.stack(
+            (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), axis=4
+        ).flatten(2)
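+        # Each of the 4 box coordinates gets num_pos_feats sinusoidal
+        # channels, so the flattened output has shape (N, L, 4 * num_pos_feats);
+        # with the default of 128 this is 512, i.e. embed_dims * 2 for
+        # embed_dims == 256, matching the pos_trans / aux_pos_trans linears.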
+        return pos
+
+    @staticmethod
+    def get_reference_points(spatial_shapes, valid_ratios):
+        """Get the reference points used in decoder.
+
+        Args:
+            spatial_shapes (Tensor): The shape of all feature maps,
+                has shape (num_level, 2).
+            valid_ratios (Tensor): The ratios of valid points on the
+                feature map, has shape (bs, num_levels, 2).
+
+        Returns:
+            Tensor: reference points used in decoder, has \
+                shape (bs, num_keys, num_levels, 2).
+        """
+        reference_points_list = []
+        for lvl, (H, W) in enumerate(spatial_shapes):
+            ref_y, ref_x = paddle.meshgrid(
+                paddle.linspace(0.5, H - 0.5, H, dtype="float32"),
+                paddle.linspace(0.5, W - 0.5, W, dtype="float32"),
+            )
+            ref_y = ref_y.reshape((-1,))[None] / (valid_ratios[:, None, lvl, 1] * H)
+            ref_x = ref_x.reshape((-1,))[None] / (valid_ratios[:, None, lvl, 0] * W)
+            ref = paddle.stack((ref_x, ref_y), -1)
+            reference_points_list.append(ref)
+        reference_points = paddle.concat(reference_points_list, 1)
+        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
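+        # Each point is the (x + 0.5, y + 0.5) cell center normalized by the
+        # valid extent of its own level; broadcasting against the valid
+        # ratios of all levels gives shape (bs, sum(H_l * W_l), num_levels, 2).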
+        return reference_points
+
+    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
+        """Generate proposals from encoded memory.
+
+        Args:
+            memory (Tensor): The output of the encoder, has shape
+                (bs, num_key, embed_dim), where num_key equals the number of
+                points on the feature maps from all levels.
+            memory_padding_mask (Tensor): Padding mask for memory,
+                has shape (bs, num_key).
+            spatial_shapes (Tensor): The shape of all feature maps,
+                has shape (num_level, 2).
+
+        Returns:
+            tuple: A tuple of feature map and bbox prediction.
+
+                - output_memory (Tensor): The input of the decoder, has shape
+                    (bs, num_key, embed_dim), where num_key equals the number
+                    of points on the feature maps from all levels.
+                - output_proposals (Tensor): The normalized proposals
+                    after an inverse sigmoid, has shape (bs, num_keys, 4).
+        """
+
+        N, S, C = memory.shape
+        proposals = []
+        _cur = 0
+        
+        for lvl, (H, W) in enumerate(spatial_shapes):
+            mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].reshape(
+                [N, H, W, 1]
+            )
+
+            valid_H = paddle.sum(paddle.logical_not(mask_flatten_[:, :, 0, 0]).astype("float"), 1)
+            valid_W = paddle.sum(paddle.logical_not(mask_flatten_[:, 0, :, 0]).astype("float"), 1)
+
+            grid_y, grid_x = paddle.meshgrid(
+                paddle.linspace(0, H - 1, H, dtype="float32"),
+                paddle.linspace(0, W - 1, W, dtype="float32"),
+            )
+            grid = paddle.concat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+            scale = paddle.concat(
+                [valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1
+            ).reshape([N, 1, 1, 2])
+            grid = (grid.unsqueeze(0).expand((N, -1, -1, -1)) + 0.5) / scale
+            wh = paddle.ones_like(grid) * 0.05 * (2.0**lvl)
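+            # The initial proposal size doubles per level: 0.05 of the image
+            # at level 0, then 0.1, 0.2, 0.4, so coarser maps propose
+            # proportionally larger boxes.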
+            proposal = paddle.concat((grid, wh), -1).reshape([N, -1, 4])
+            proposals.append(proposal)
+            _cur += H * W
+        output_proposals = paddle.concat(proposals, 1)
+        output_proposals_valid = (
+            ((output_proposals > 0.01) & (output_proposals < 0.99))
+            .all(-1, keepdim=True)
+        )
+        output_proposals = paddle.log(output_proposals / (1 - output_proposals))
+        output_proposals = output_proposals.masked_fill(
+            memory_padding_mask.unsqueeze(-1),
+            float("inf"),
+        )
+        output_proposals = output_proposals.masked_fill(
+            ~output_proposals_valid, float("inf")
+        )
+
+        output_memory = memory
+        output_memory = output_memory.masked_fill(
+            memory_padding_mask.unsqueeze(-1), float(0)
+        )
+        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+        output_memory = self.enc_output_norm(self.enc_output(output_memory))
+        return output_memory, output_proposals
+
+    def forward(
+        self,
+        mlvl_feats,
+        mlvl_masks,
+        query_embed,
+        mlvl_pos_embeds,
+        reg_branches=None,
+        cls_branches=None,
+        return_encoder_output=False,
+        attn_masks=None,
+        **kwargs
+    ):
+        """Forward function for `Transformer`.
+
+        Args:
+            mlvl_feats (list(Tensor)): Input features from
+                different levels. Each element has shape
+                [bs, embed_dims, h, w].
+            mlvl_masks (list(Tensor)): The key_padding_mask from
+                different levels used for encoder and decoder,
+                each element has shape [bs, h, w].
+            query_embed (Tensor): The query embedding for decoder,
+                with shape [num_query, c].
+            mlvl_pos_embeds (list(Tensor)): The positional encoding
+                of feats from different levels, has the shape
+                [bs, embed_dims, h, w].
+            reg_branches (obj:`nn.LayerList`): Regression heads for
+                feature maps from each decoder layer. Only passed
+                when `with_box_refine` is True. Default None.
+            cls_branches (obj:`nn.LayerList`): Classification heads
+                for feature maps from each decoder layer. Only passed
+                when `as_two_stage` is True. Default None.
+
+        Returns:
+            tuple[Tensor]: results of decoder containing the following tensor.
+
+                - inter_states: Outputs from decoder. If
+                    return_intermediate_dec is True output has shape \
+                      (num_dec_layers, bs, num_query, embed_dims), else has \
+                      shape (1, bs, num_query, embed_dims).
+                - init_reference_out: The initial value of reference \
+                    points, has shape (bs, num_queries, 4).
+                - inter_references_out: The internal value of reference \
+                    points in decoder, has shape \
+                    (num_dec_layers, bs, num_query, embed_dims).
+                - enc_outputs_class: The classification score of \
+                    proposals generated from \
+                    encoder's feature maps, has shape \
+                    (batch, h*w, num_classes). \
+                    Only would be returned when `as_two_stage` is True, \
+                    otherwise None.
+                - enc_outputs_coord_unact: The regression results \
+                    generated from encoder's feature maps, has shape \
+                    (batch, h*w, 4). Only would \
+                    be returned when `as_two_stage` is True, \
+                    otherwise None.
+        """
+        assert self.as_two_stage or query_embed is not None
+
+        feat_flatten = []
+        mask_flatten = []
+        lvl_pos_embed_flatten = []
+        spatial_shapes = []
+        for lvl, (feat, mask, pos_embed) in enumerate(
+            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)
+        ):
+            bs, c, h, w = feat.shape
+            spatial_shape = (h, w)
+            spatial_shapes.append(spatial_shape)
+            feat = feat.flatten(2).transpose((0, 2, 1))
+            mask = mask.flatten(1)
+            pos_embed = pos_embed.flatten(2).transpose((0, 2, 1))
+            lvl_pos_embed = pos_embed + self.level_embeds[lvl].reshape((1, 1, -1))            
+            lvl_pos_embed_flatten.append(lvl_pos_embed)
+            feat_flatten.append(feat)
+            mask_flatten.append(mask)
+
+        feat_flatten = paddle.concat(feat_flatten, 1)
+        mask_flatten = paddle.concat(mask_flatten, 1)
+        lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1)
+
+        spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64")
+        # [l], start index of each level in the flattened feature sequence
+        level_start_index = paddle.concat(
+            [paddle.zeros([1], dtype="int64"), spatial_shapes.prod(1).cumsum(0)[:-1]]
+        )
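+        # E.g. spatial_shapes [[100, 150], [50, 75]] flattens to lengths
+        # [15000, 3750], so level_start_index is [0, 15000].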
+
+        valid_ratios = paddle.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1)
+        reference_points = self.get_reference_points(spatial_shapes, valid_ratios)
+        
+        memory = self.encoder(
+            src=feat_flatten,
+            pos_embed=lvl_pos_embed_flatten,
+            src_mask=mask_flatten,
+            value_spatial_shapes=spatial_shapes,
+            reference_points=reference_points,
+            value_level_start_index=level_start_index,
+            valid_ratios=valid_ratios,
+        )
+        
+        bs, _, c = memory.shape
+        if self.as_two_stage:
+            output_memory, output_proposals = self.gen_encoder_output_proposals(
+                memory, mask_flatten, spatial_shapes
+            )
+            enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory)
+            enc_outputs_coord_unact = (
+                reg_branches[self.decoder.num_layers](output_memory) + output_proposals
+            )
+            topk = self.two_stage_num_proposals
+            # We only use the first channel in enc_outputs_class as foreground,
+            # the other (num_classes - 1) channels are actually not used.
+            # Its targets are set to be 0s, which indicates the first
+            # class (foreground) because we use [0, num_classes - 1] to
+            # indicate class labels, background class is indicated by
+            # num_classes (similar convention in RPN).
+            topk_proposals = paddle.topk(enc_outputs_class[..., 0], topk, axis=1)[1]
+            # paddle.take_along_axis corresponds to torch.gather
+            topk_coords_unact = paddle.take_along_axis(
+                enc_outputs_coord_unact, topk_proposals.unsqueeze(-1).tile([1, 1, 4]), axis=1
+            )
+            topk_coords_unact = topk_coords_unact.detach()
+            reference_points = F.sigmoid(topk_coords_unact)
+            init_reference_out = reference_points
+            pos_trans_out = self.pos_trans_norm(
+                self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact.astype('float32')))
+            )
+            if not self.mixed_selection:
+                query_pos, query = paddle.split(pos_trans_out, pos_trans_out.shape[2]//c, axis=2)
+            else:
+                # query_embed here is the content embed for deformable DETR
+                query = query_embed.unsqueeze(0).expand([bs, -1, -1])
+                query_pos, _ = paddle.split(pos_trans_out, pos_trans_out.shape[2]//c, axis=2)
+        else:
+            query_pos, query = paddle.split(query_embed, query_embed.shape[1]//c, axis=1)
+            query_pos = query_pos.unsqueeze(0).expand([bs, -1, -1])
+            query = query.unsqueeze(0).expand([bs, -1, -1])
+            reference_points = F.sigmoid(self.reference_points(query_pos))
+            init_reference_out = reference_points
+
+        # decoder
+        inter_states, inter_references = self.decoder(
+            query=query,
+            memory=memory,
+            query_pos_embed=query_pos,
+            memory_mask=mask_flatten,
+            reference_points=reference_points,
+            value_spatial_shapes=spatial_shapes,
+            value_level_start_index=level_start_index,
+            valid_ratios=valid_ratios,
+            reg_branches=reg_branches,
+            attn_masks=attn_masks,
+            **kwargs
+        )
+        inter_references_out = inter_references
+        if self.as_two_stage:
+            if return_encoder_output:
+                return (
+                    inter_states,
+                    init_reference_out,
+                    inter_references_out,
+                    enc_outputs_class,
+                    enc_outputs_coord_unact,
+                    memory,
+                )
+            return (
+                inter_states,
+                init_reference_out,
+                inter_references_out,
+                enc_outputs_class,
+                enc_outputs_coord_unact,
+            )
+        if return_encoder_output:
+            return (
+                inter_states,
+                init_reference_out,
+                inter_references_out,
+                None,
+                None,
+                memory,
+            )
+        return inter_states, init_reference_out, inter_references_out, None, None
+
+    def forward_aux(
+        self,
+        mlvl_feats,
+        mlvl_masks,
+        query_embed,
+        mlvl_pos_embeds,
+        pos_anchors,
+        pos_feats=None,
+        reg_branches=None,
+        cls_branches=None,
+        return_encoder_output=False,
+        attn_masks=None,
+        head_idx=0,
+        **kwargs
+    ):
+        feat_flatten = []
+        mask_flatten = []
+        spatial_shapes = []
+        for lvl, (feat, mask, pos_embed) in enumerate(
+            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)
+        ):
+            bs, c, h, w = feat.shape
+            spatial_shape = (h, w)
+            spatial_shapes.append(spatial_shape)
+            feat = feat.flatten(2).transpose((0, 2, 1))
+            mask = mask.flatten(1)
+            feat_flatten.append(feat)
+            mask_flatten.append(mask)
+
+        feat_flatten = paddle.concat(feat_flatten, 1)
+        mask_flatten = paddle.concat(mask_flatten, 1)
+        spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64")
+        # [l], start index of each level in the flattened feature sequence
+        level_start_index = paddle.concat(
+            [paddle.zeros([1], dtype="int64"), spatial_shapes.prod(1).cumsum(0)[:-1]]
+        )
+        valid_ratios = paddle.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1)
+
+        memory = feat_flatten
+        bs, _, c = memory.shape
+        topk = pos_anchors.shape[1]
+        topk_coords_unact = inverse_sigmoid(pos_anchors)
+        reference_points = pos_anchors
+        init_reference_out = reference_points
+        if self.num_co_heads > 0:
+            pos_trans_out = self.aux_pos_trans_norm[head_idx](
+                self.aux_pos_trans[head_idx](
+                    self.get_proposal_pos_embed(topk_coords_unact)
+                )
+            )            
+            query_pos, query = paddle.split(pos_trans_out, pos_trans_out.shape[2]//c, axis=2)
+            if self.with_coord_feat:
+                query = query + self.pos_feats_norm[head_idx](
+                    self.pos_feats_trans[head_idx](pos_feats)
+                )
+                query_pos = query_pos + self.head_pos_embed.weight[head_idx]
+
+        # decoder
+        inter_states, inter_references = self.decoder(
+            query=query,
+            memory=memory,
+            query_pos_embed=query_pos,
+            memory_mask=mask_flatten,
+            reference_points=reference_points,
+            value_spatial_shapes=spatial_shapes,
+            value_level_start_index=level_start_index,
+            valid_ratios=valid_ratios,
+            reg_branches=reg_branches,
+            attn_masks=attn_masks,
+            **kwargs
+        )
+
+        inter_references_out = inter_references
+        return inter_states, init_reference_out, inter_references_out