diff --git a/projects/BEVFusion/bevfusion/transforms_3d.py b/projects/BEVFusion/bevfusion/transforms_3d.py index b5fbcfdb..5441c1d1 100644 --- a/projects/BEVFusion/bevfusion/transforms_3d.py +++ b/projects/BEVFusion/bevfusion/transforms_3d.py @@ -24,7 +24,12 @@ def sample_augmentation(self, results): H, W = results["ori_shape"] fH, fW = self.final_dim if self.is_train: - resize = np.random.uniform(*self.resize_lim) + if isinstance(self.resize_lim, (int, float)): + aspect_ratio = min(fH / H, fW / W) + resize = np.random.uniform(aspect_ratio - self.resize_lim, aspect_ratio + self.resize_lim) + else: + resize = np.random.uniform(*self.resize_lim) + resize_dims = (int(W * resize), int(H * resize)) newW, newH = resize_dims crop_h = int((1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH @@ -35,7 +40,7 @@ def sample_augmentation(self, results): flip = True rotate = np.random.uniform(*self.rot_lim) else: - resize = np.mean(self.resize_lim) + resize = min(fH / H, fW / W) resize_dims = (int(W * resize), int(H * resize)) newW, newH = resize_dims crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH @@ -52,7 +57,7 @@ def img_transform(self, img, rotation, translation, resize, resize_dims, crop, f img = img.crop(crop) if flip: img = img.transpose(method=Image.FLIP_LEFT_RIGHT) - img = img.rotate(rotate) + img = img.rotate(rotate, resample=Image.BICUBIC) # Default rotation introduces artifacts. # post-homography transformation rotation *= resize diff --git a/projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_voxel_second_secfpn_2xb2_t4offline_no_intensity.py b/projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_offline_voxel_second_secfpn_4xb8_base.py similarity index 80% rename from projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_voxel_second_secfpn_2xb2_t4offline_no_intensity.py rename to projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_offline_voxel_second_secfpn_4xb8_base.py index 4813b5d8..130a4a89 100644 --- a/projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_voxel_second_secfpn_2xb2_t4offline_no_intensity.py +++ b/projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_offline_voxel_second_secfpn_4xb8_base.py @@ -8,9 +8,10 @@ # user setting data_root = "data/t4dataset/" -info_directory_path = "info/user_name/" -train_gpu_size = 2 +info_directory_path = "info/username/" +train_gpu_size = 4 train_batch_size = 2 +test_batch_size = 2 val_interval = 5 max_epochs = 30 backend_args = None @@ -19,43 +20,45 @@ point_cloud_range = [-122.4, -122.4, -3.0, 122.4, 122.4, 5.0] voxel_size = [0.075, 0.075, 0.2] grid_size = [3264, 3264, 41] + eval_class_range = { - "car": 121, - "truck": 121, - "bus": 121, - "bicycle": 121, - "pedestrian": 121, + "car": 120, + "truck": 120, + "bus": 120, + "bicycle": 120, + "pedestrian": 120, } # model parameter input_modality = dict(use_lidar=True, use_camera=True) point_load_dim = 5 # x, y, z, intensity, ring_id -point_use_dim = 5 -point_intensity_dim = 3 +sweeps_num = 1 max_num_points = 10 max_voxels = [120000, 160000] num_proposals = 500 -image_size = [256, 704] -lidar_sweep_dims = [0, 1, 2, 4] -num_workers = 1 -sweeps_num = 1 +image_size = [576, 864] # height, width +num_workers = 32 +lidar_sweep_dims = [0, 1, 2, 4] # x, y, z, time_lag +lidar_feature_dims = 4 model = dict( type="BEVFusion", data_preprocessor=dict( + type="Det3DDataPreprocessor", + pad_size_divisor=32, voxelize_cfg=dict( 
max_num_points=max_num_points, - point_cloud_range=point_cloud_range, voxel_size=voxel_size, + point_cloud_range=point_cloud_range, max_voxels=max_voxels, + deterministic=True, + voxelize_reduce=True, ), - type="Det3DDataPreprocessor", mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], bgr_to_rgb=False, ), - pts_voxel_encoder=dict(type="HardSimpleVFE", num_features=4), - pts_middle_encoder=dict(in_channels=4, sparse_shape=grid_size), + pts_middle_encoder=dict(sparse_shape=grid_size, in_channels=lidar_feature_dims), img_backbone=dict( type="mmdet.SwinTransformer", embed_dims=96, @@ -92,16 +95,11 @@ in_channels=256, out_channels=80, image_size=image_size, - feature_size=[32, 88], - # xbound=[-54.0, 54.0, 0.3], - # ybound=[-54.0, 54.0, 0.3], - # xbound=[-122.4, 122.4, 0.68], - # ybound=[-122.4, 122.4, 0.68], + feature_size=[72, 108], xbound=[-122.4, 122.4, 0.3], ybound=[-122.4, 122.4, 0.3], zbound=[-10.0, 10.0, 20.0], - # dbound=[1.0, 60.0, 0.5], - dbound=[1.0, 166.2, 1.4], + dbound=[1.0, 134, 1.4], downsample=2, ), fusion_layer=dict(type="ConvFuser", in_channels=[80, 256], out_channels=256), @@ -115,6 +113,7 @@ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2], ), test_cfg=dict( + dataset="t4datasets", grid_size=grid_size, voxel_size=voxel_size[0:2], pc_range=point_cloud_range[0:2], @@ -124,49 +123,13 @@ voxel_size=voxel_size[0:2], ), ), + # Lidar pipeline + pts_voxel_encoder=dict(num_features=lidar_feature_dims), ) -# TODO: support object sample -# db_sampler = dict( -# data_root=data_root, -# info_path=data_root +'nuscenes_dbinfos_train.pkl', -# rate=1.0, -# prepare=dict( -# filter_by_difficulty=[-1], -# filter_by_min_points=dict( -# car=5, -# truck=5, -# bus=5, -# trailer=5, -# construction_vehicle=5, -# traffic_cone=5, -# barrier=5, -# motorcycle=5, -# bicycle=5, -# pedestrian=5)), -# classes=class_names, -# sample_groups=dict( -# car=2, -# truck=3, -# construction_vehicle=7, -# bus=4, -# trailer=6, -# barrier=2, -# motorcycle=6, -# bicycle=6, -# pedestrian=2, -# traffic_cone=2), -# points_loader=dict( -# type='LoadPointsFromFile', -# coord_type='LIDAR', -# load_dim=5, -# use_dim=[0, 1, 2, 3, 4], -# backend_args=backend_args)) - train_pipeline = [ dict( type="BEVLoadMultiViewImageFromFiles", - data_root=data_root, to_float32=True, color_type="color", backend_args=backend_args, @@ -175,39 +138,31 @@ type="LoadPointsFromFile", coord_type="LIDAR", load_dim=point_load_dim, - use_dim=point_use_dim, + use_dim=point_load_dim, backend_args=backend_args, ), - # TODO: add feature - # dict( - # type="IntensityNorm", - # alpha=10.0, - # intensity_dim=point_intensity_dim, - # div_factor=255.0, - # ), dict( type="LoadPointsFromMultiSweeps", sweeps_num=sweeps_num, - load_dim=5, + load_dim=point_load_dim, use_dim=lidar_sweep_dims, pad_empty_sweeps=True, remove_close=True, backend_args=backend_args, + test_mode=False, ), dict(type="LoadAnnotations3D", with_bbox_3d=True, with_label_3d=True, with_attr_label=False), - # TODO: support object sample - # dict(type='ObjectSample', db_sampler=db_sampler), dict( type="ImageAug3D", final_dim=image_size, - resize_lim=[0.38, 0.55], + resize_lim=0.02, bot_pct_lim=[0.0, 0.0], rot_lim=[-5.4, 5.4], rand_flip=True, is_train=True, ), dict( - type="GlobalRotScaleTrans", + type="BEVFusionGlobalRotScaleTrans", rot_range=[-1.571, 1.571], scale_ratio_range=[0.8, 1.2], translation_std=[1.0, 1.0, 0.2], @@ -259,23 +214,21 @@ test_pipeline = [ dict( type="BEVLoadMultiViewImageFromFiles", - data_root=data_root, to_float32=True, color_type="color", 
backend_args=backend_args, - test_mode=True, ), dict( type="LoadPointsFromFile", coord_type="LIDAR", - load_dim=5, - use_dim=5, + load_dim=point_load_dim, + use_dim=point_load_dim, backend_args=backend_args, ), dict( type="LoadPointsFromMultiSweeps", sweeps_num=sweeps_num, - load_dim=5, + load_dim=point_load_dim, use_dim=lidar_sweep_dims, pad_empty_sweeps=True, remove_close=True, @@ -285,7 +238,7 @@ dict( type="ImageAug3D", final_dim=image_size, - resize_lim=[0.48, 0.48], + resize_lim=0.02, bot_pct_lim=[0.0, 0.0], rot_lim=[0.0, 0.0], rand_flip=False, @@ -313,6 +266,8 @@ ), ] +filter_cfg = dict(filter_frames_with_missing_image=True) + train_dataloader = dict( batch_size=train_batch_size, num_workers=num_workers, @@ -320,20 +275,22 @@ sampler=dict(type="DefaultSampler", shuffle=True), dataset=dict( type=_base_.dataset_type, + pipeline=train_pipeline, + modality=input_modality, + backend_args=backend_args, data_root=data_root, ann_file=info_directory_path + _base_.info_train_file_name, - pipeline=train_pipeline, metainfo=_base_.metainfo, class_names=_base_.class_names, - modality=input_modality, - data_prefix=_base_.data_prefix, test_mode=False, + data_prefix=_base_.data_prefix, box_type_3d="LiDAR", - backend_args=backend_args, + filter_cfg=filter_cfg, ), ) + val_dataloader = dict( - batch_size=2, + batch_size=test_batch_size, num_workers=num_workers, persistent_workers=True, sampler=dict(type="DefaultSampler", shuffle=False), @@ -351,8 +308,9 @@ backend_args=backend_args, ), ) + test_dataloader = dict( - batch_size=2, + batch_size=test_batch_size, num_workers=num_workers, persistent_workers=True, sampler=dict(type="DefaultSampler", shuffle=False), @@ -392,6 +350,7 @@ name_mapping=_base_.name_mapping, eval_class_range=eval_class_range, filter_attributes=_base_.filter_attributes, + save_csv=True, ) # learning rate @@ -456,14 +415,8 @@ clip_grad=dict(max_norm=35, norm_type=2), ) -# Default setting for scaling LR automatically -# - `enable` means enable scaling LR automatically -# or not by default. -# - `base_batch_size` = (8 GPUs) x (4 samples per GPU). 
-# auto_scale_lr = dict(enable=False, base_batch_size=32) -auto_scale_lr = dict(enable=False, base_batch_size=train_gpu_size * train_batch_size) +auto_scale_lr = dict(enable=True, base_batch_size=4) +# Only set if the number of train_gpu_size more than 1 if train_gpu_size > 1: sync_bn = "torch" - -randomness = dict(seed=0, diff_rank_seed=False, deterministic=True) diff --git a/projects/BEVFusion/configs/t4dataset/BEVFusion-CL/bevfusion_camera_lidar_voxel_second_secfpn_4xb8_base.py b/projects/BEVFusion/configs/t4dataset/BEVFusion-CL/bevfusion_camera_lidar_voxel_second_secfpn_4xb8_base.py index 7fd9c26f..dfce02c4 100644 --- a/projects/BEVFusion/configs/t4dataset/BEVFusion-CL/bevfusion_camera_lidar_voxel_second_secfpn_4xb8_base.py +++ b/projects/BEVFusion/configs/t4dataset/BEVFusion-CL/bevfusion_camera_lidar_voxel_second_secfpn_4xb8_base.py @@ -35,7 +35,7 @@ max_num_points = 10 max_voxels = [120000, 160000] num_proposals = 500 -image_size = [256, 704] +image_size = [384, 576] # height, width num_workers = 32 lidar_sweep_dims = [0, 1, 2, 4] # x, y, z, time_lag lidar_feature_dims = 4 @@ -94,14 +94,11 @@ in_channels=256, out_channels=80, image_size=image_size, - feature_size=[32, 88], - # xbound=[-54.0, 54.0, 0.3], - # ybound=[-54.0, 54.0, 0.3], + feature_size=[48, 72], xbound=[-122.4, 122.4, 0.68], ybound=[-122.4, 122.4, 0.68], zbound=[-10.0, 10.0, 20.0], - # dbound=[1.0, 60.0, 0.5], - dbound=[1.0, 166.2, 1.4], + dbound=[1.0, 134, 1.4], downsample=2, ), fusion_layer=dict(type="ConvFuser", in_channels=[80, 256], out_channels=256), @@ -157,14 +154,14 @@ dict( type="ImageAug3D", final_dim=image_size, - resize_lim=[0.38, 0.55], + resize_lim=0.02, bot_pct_lim=[0.0, 0.0], rot_lim=[-5.4, 5.4], rand_flip=True, is_train=True, ), dict( - type="GlobalRotScaleTrans", + type="BEVFusionGlobalRotScaleTrans", rot_range=[-1.571, 1.571], scale_ratio_range=[0.8, 1.2], translation_std=[1.0, 1.0, 0.2], @@ -240,7 +237,7 @@ dict( type="ImageAug3D", final_dim=image_size, - resize_lim=[0.48, 0.48], + resize_lim=0.02, bot_pct_lim=[0.0, 0.0], rot_lim=[0.0, 0.0], rand_flip=False, @@ -421,8 +418,7 @@ # - `enable` means enable scaling LR automatically # or not by default. # - `base_batch_size` = (8 GPUs) x (4 samples per GPU). -# auto_scale_lr = dict(enable=False, base_batch_size=32) -auto_scale_lr = dict(enable=False, base_batch_size=train_gpu_size * train_batch_size) +auto_scale_lr = dict(enable=True, base_batch_size=32) # Only set if the number of train_gpu_size more than 1 if train_gpu_size > 1: diff --git a/projects/BEVFusion/docs/BEVFusion-CL-offline/v2/base.md b/projects/BEVFusion/docs/BEVFusion-CL-offline/v2/base.md new file mode 100644 index 00000000..6a68ca89 --- /dev/null +++ b/projects/BEVFusion/docs/BEVFusion-CL-offline/v2/base.md @@ -0,0 +1,41 @@ +# Deployed model for BEVFusion-CL base/2.X +## Summary + +### Overview + +| Eval range: 120m | mAP | car | truck | bus | bicycle | pedestrian | +| --------------------------------| ---- | ---- | ----- | ---- | ------- | ---------- | +| BEVFusion-CL-offline base/2.0.0 | 77.8 | 87.30 | 61.60 | 85.90 | 73.20 | 80.90 | +| BEVFusion-CL base/2.0.0 | 76.3 | 80.50 | 61.90 | 85.90 | 74.70 | 78.70 | + + +## Release + +### BEVFusion-CL-offline base/2.0.0 + +
+<details>
+<summary> The link of data and evaluation result </summary>
+
+- Model
+  - Training dataset: DB JPNTAXI v1.0 + DB JPNTAXI v2.0 + DB JPNTAXI v4.0 + DB GSM8 v1.0 + DB J6 v1.0 + DB J6 v2.0 + DB J6 v3.0 + DB J6 v5.0 + DB J6 Gen2 v1.0 + DB J6 Gen2 v2.0 + DB J6 Gen2 v4.0 + DB LargeBus v1.0 (total frames: 71,633)
+  - [Config file path](https://github.com/tier4/AWML/blob/50f35a8ae52c4892351be0c7aa5d260c1b310b7e/projects/BEVFusion/configs/t4dataset/BEVFusion-CL-offline/bevfusion_camera_lidar_offline_voxel_second_secfpn_4xb8_base.py)
+  - Training results [model-zoo]
+    - [logs.zip](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl-offline/t4base/v2.0.0/logs.zip)
+    - [checkpoint_best.pth](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl-offline/t4base/v2.0.0/best_NuScenes_metric_T4Metric_mAP_epoch_30.pth)
+    - [config.py](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl-offline/t4base/v2.0.0/bevfusion_camera_lidar_voxel_second_secfpn_2xb2_t4offline_no_intensity.py)
+  - Train time: NVIDIA H100 80GB * 4 * 50 epochs = 3 days and 20 hours
+  - Batch size: 4*5 = 20
+
+- Evaluation
+  - db_jpntaxi_v1 + db_jpntaxi_v2 + db_jpntaxi_v4 + db_gsm8_v1 + db_j6_v1 + db_j6_v2 + db_j6_v3 + db_j6_v5 + db_j6gen2_v1 + db_j6gen2_v2 + db_j6gen2_v4 + db_largebus_v1 (total frames: 5,703):
+  - Total mAP (eval range = 120m): 0.778
+
+| class_name | Count | mAP | AP@0.5m | AP@1.0m | AP@2.0m | AP@4.0m |
+| ---- | ------- | ---- | ---- | ---- | ---- | ---- |
+| car | 144,001 | 87.3 | 77.5 | 87.8 | 91.6 | 92.2 |
+| truck | 20,823 | 61.6 | 41.0 | 61.3 | 69.0 | 74.9 |
+| bus | 5,691 | 85.9 | 75.6 | 85.6 | 90.3 | 92.2 |
+| bicycle | 5,007 | 73.2 | 71.4 | 73.5 | 73.7 | 74.1 |
+| pedestrian | 42,034 | 80.9 | 79.5 | 80.5 | 81.3 | 82.3 |
+
+</details>
diff --git a/projects/BEVFusion/docs/BEVFusion-CL/v2/base.md b/projects/BEVFusion/docs/BEVFusion-CL/v2/base.md
index 0ef9f423..eaf2f9cb 100644
--- a/projects/BEVFusion/docs/BEVFusion-CL/v2/base.md
+++ b/projects/BEVFusion/docs/BEVFusion-CL/v2/base.md
@@ -7,7 +7,8 @@
 | Eval range: 120m | mAP | car | truck | bus | bicycle | pedestrian |
 | --------------------------------| ---- | ---- | ----- | ---- | ------- | ---------- |
 | BEVFusion-CL base/2.0.0 (A) | 70.72 | 81.04 | **62.06** | 82.52 | **70.82** | 57.14 |
-| BEVFusion-CL base/2.0.0 (B) | **75.03** | 79.62 | 61.20 | **86.67** | 69.99 | **77.62** |
+| BEVFusion-CL base/2.0.0 (B) | **75.03** | 79.62 | 61.20 | **86.67** | 69.99 | 77.62 |
+| BEVFusion-CL base/2.0.0 (C) | **76.3** | 80.50 | 61.90 | 85.90 | **74.70** | **78.70** |

 ### Datasets
 #### base
@@ -40,12 +41,47 @@
 | BEVFusion-CL base/2.0.0 (A) | 69.99 | 79.41 | 64.64 | 83.58 | 67.03 | 55.28 |
 | BEVFusion-CL base/2.0.0 (B) | 74.48 | 77.28 | 62.67 | 87.92 | 66.58 | 77.98 |

+- BEVFusion-CL base/2.0.0 (A): Without intensity, and with pedestrian heatmaps pooled (downsampled) during training
+- BEVFusion-CL base/2.0.0 (B): Same as `BEVFusion-CL base/2.0.0 (A)`, but without pooling the pedestrian heatmaps
+- BEVFusion-CL base/2.0.0 (C): Same as `BEVFusion-CL base/2.0.0 (B)`, with improved image ROI cropping and augmentation parameter fixes.

 ## Release

-### BEVFusion-CL base/2.0.0
-- BEVFusion-CL base/2.0.0 (A): Without intensity and training pedestrians with pooling pedestrians
-- BEVFusion-CL base/2.0.0 (B): Same as `BEVFusion-CL base/2.0.0 (A)` without pooling pedestrians
+### BEVFusion-CL base/2.0.0 (C)
+
+<details>
+<summary> The link of data and evaluation result </summary>
+
+- Model
+  - Training dataset: DB JPNTAXI v1.0 + DB JPNTAXI v2.0 + DB JPNTAXI v4.0 + DB GSM8 v1.0 + DB J6 v1.0 + DB J6 v2.0 + DB J6 v3.0 + DB J6 v5.0 + DB J6 Gen2 v1.0 + DB J6 Gen2 v2.0 + DB J6 Gen2 v4.0 + DB LargeBus v1.0 (total frames: 71,633)
+  - [Config file path](https://github.com/tier4/AWML/blob/50f35a8ae52c4892351be0c7aa5d260c1b310b7e/projects/BEVFusion/configs/t4dataset/BEVFusion-CL/bevfusion_camera_lidar_voxel_second_secfpn_4xb8_base.py)
+  - Deployed onnx model and ROS parameter files [[WebAuto (for internal)]](WIP)
+  - Deployed onnx and ROS parameter files [[model-zoo]]
+    - [image_backbone.onnx](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl/t4base/v2.0.0/image_backbone.onnx)
+    - [main_body.onnx](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl/t4base/v2.0.0/main_body.onnx)
+  - Training results [model-zoo]
+    - [logs.zip](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl/t4base/v2.0.0/log.zip)
+    - [checkpoint_best.pth](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl/t4base/v2.0.0/best_NuScenes_metric_T4Metric_mAP_epoch_48.pth)
+    - [config.py](https://download.autoware-ml-model-zoo.tier4.jp/autoware-ml/models/bevfusion/bevfusion-cl/t4base/v2.0.0/bevfusion_camera_lidar_voxel_second_secfpn_4xb8_base.py)
+  - [PR](https://github.com/tier4/AWML/pull/88)
+  - Train time: NVIDIA H100 80GB * 4 * 50 epochs = 3 days and 20 hours
+  - Batch size: 4*8 = 32
+
+- Evaluation
+  - db_jpntaxi_v1 + db_jpntaxi_v2 + db_jpntaxi_v4 + db_gsm8_v1 + db_j6_v1 + db_j6_v2 + db_j6_v3 + db_j6_v5 + db_j6gen2_v1 + db_j6gen2_v2 + db_j6gen2_v4 + db_largebus_v1 (total frames: 5,703):
+  - Total mAP (eval range = 120m): 0.763
+
+| class_name | Count | mAP | AP@0.5m | AP@1.0m | AP@2.0m | AP@4.0m |
+| ---- | ------- | ---- | ---- | ---- | ---- | ---- |
+| car | 144,001 | 80.5 | 69.2 | 80.5 | 85.1 | 87.2 |
+| truck | 20,823 | 61.9 | 37.7 | 60.9 | 71.1 | 78.1 |
+| bus | 5,691 | 85.9 | 71.9 | 86.0 | 92.1 | 93.5 |
+| bicycle | 5,007 | 74.7 | 71.2 | 75.4 | 75.9 | 76.4 |
+| pedestrian | 42,034 | 78.7 | 76.1 | 78.4 | 79.5 | 80.6 |
+
+</details>
+ +### BEVFusion-CL base/2.0.0 (B) - We report only `BEVFusion-CL base/2.0.0 (B)` since the performance is much better than `BEVFusion-CL base/2.0.0 (A)`, and it is mainly due to it doesn't downsample the dense heatmaps for pedestrians, and thus it has more queries
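Reviewer note (not part of the patch): the behavioural core of this change is the new interpretation of `resize_lim` in `ImageAug3D.sample_augmentation`. When `resize_lim` is a scalar, it acts as a ± jitter around the aspect-ratio-preserving scale `min(fH / H, fW / W)` rather than an explicit `[low, high]` range, which is why both configs can replace `[0.38, 0.55]` / `[0.48, 0.48]` with `resize_lim=0.02`. The sketch below restates that logic as a standalone helper; `sample_resize` and the example camera resolution are hypothetical, only the branching mirrors the diff.

```python
# Editorial sketch (not part of the patch). `sample_resize` is a hypothetical
# standalone helper mirroring the resize-sampling branch added to
# ImageAug3D.sample_augmentation in projects/BEVFusion/bevfusion/transforms_3d.py.
import numpy as np


def sample_resize(ori_shape, final_dim, resize_lim, is_train):
    """Return the resize factor applied to the image before cropping to `final_dim`."""
    H, W = ori_shape
    fH, fW = final_dim
    if is_train:
        if isinstance(resize_lim, (int, float)):
            # Scalar resize_lim: jitter around the scale that fits the source
            # image into final_dim while preserving its aspect ratio.
            base = min(fH / H, fW / W)
            return np.random.uniform(base - resize_lim, base + resize_lim)
        # Sequence resize_lim: keep the original [low, high] sampling.
        return np.random.uniform(*resize_lim)
    # Evaluation: deterministic aspect-ratio-preserving scale, replacing the
    # old np.mean(resize_lim).
    return min(fH / H, fW / W)


# With a (hypothetical) 1080x1920 camera image and the offline config values
# final_dim=(576, 864), resize_lim=0.02: train-time factors fall roughly in
# [0.43, 0.47]; test time always returns min(576/1080, 864/1920) = 0.45.
print(sample_resize((1080, 1920), (576, 864), 0.02, is_train=False))  # 0.45
```

Relatedly, the updated `feature_size` values are `image_size / 8` ([576, 864] → [72, 108] and [384, 576] → [48, 72]), matching the 1/8 feature stride implied by the previous [256, 704] → [32, 88] setting.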