projects/BEVFusion/bevfusion/transforms_3d.py — 11 changes: 8 additions & 3 deletions
@@ -24,7 +24,12 @@ def sample_augmentation(self, results):
         H, W = results["ori_shape"]
         fH, fW = self.final_dim
         if self.is_train:
-            resize = np.random.uniform(*self.resize_lim)
+            if isinstance(self.resize_lim, (int, float)):
+                aspect_ratio = min(fH / H, fW / W)
+                resize = np.random.uniform(aspect_ratio - self.resize_lim, aspect_ratio + self.resize_lim)
+            else:
+                resize = np.random.uniform(*self.resize_lim)
+
             resize_dims = (int(W * resize), int(H * resize))
             newW, newH = resize_dims
             crop_h = int((1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH
@@ -35,7 +40,7 @@ def sample_augmentation(self, results):
                 flip = True
             rotate = np.random.uniform(*self.rot_lim)
         else:
-            resize = np.mean(self.resize_lim)
+            resize = min(fH / H, fW / W)
             resize_dims = (int(W * resize), int(H * resize))
             newW, newH = resize_dims
             crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH
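
For intuition, here is a rough sketch of what the scalar `resize_lim` form does in both branches (the 1080×1920 source resolution is an assumed example; `final_dim = [576, 864]` and `resize_lim = 0.02` mirror the config below):

```python
import numpy as np

# Assumed example resolution; final_dim and resize_lim mirror the config below.
H, W = 1080, 1920                   # hypothetical ori_shape (height, width)
fH, fW = 576, 864                   # final_dim
resize_lim = 0.02

aspect_ratio = min(fH / H, fW / W)  # min(0.533, 0.45) = 0.45, the exact-fit scale

# Train branch: jitter around the exact-fit scale rather than a fixed absolute range.
resize_train = np.random.uniform(aspect_ratio - resize_lim, aspect_ratio + resize_lim)

# Test branch: the exact-fit scale itself; resize_lim is never consulted.
resize_test = min(fH / H, fW / W)   # 0.45
```

The augmentation range now tracks whatever `final_dim` is configured, instead of hard-coding an absolute range tuned to one input resolution.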
@@ -52,7 +57,7 @@ def img_transform(self, img, rotation, translation, resize, resize_dims, crop, f
         img = img.crop(crop)
         if flip:
             img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
-        img = img.rotate(rotate)
+        img = img.rotate(rotate, resample=Image.BICUBIC)  # Default rotation introduces artifacts.
Collaborator: It would be nice if you could share examples showing the artifacts introduced by the default rotation.
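
A minimal way to produce such a comparison, assuming a hypothetical input image:

```python
from PIL import Image

img = Image.open("cam_front_sample.jpg")  # hypothetical input frame

nearest = img.rotate(5.4)                          # PIL's default resample is NEAREST
bicubic = img.rotate(5.4, resample=Image.BICUBIC)  # the variant this PR switches to

# Zooming into high-contrast edges of the two outputs should show the
# stair-stepping that nearest-neighbor resampling introduces.
nearest.save("rotate_nearest.png")
bicubic.save("rotate_bicubic.png")
```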
         # post-homography transformation
         rotation *= resize
(Second file: the BEVFusion t4dataset training config.)
@@ -8,9 +8,10 @@
 
 # user setting
 data_root = "data/t4dataset/"
-info_directory_path = "info/user_name/"
-train_gpu_size = 2
+info_directory_path = "info/username/"
+train_gpu_size = 4
 train_batch_size = 2
+test_batch_size = 2
 val_interval = 5
 max_epochs = 30
 backend_args = None
@@ -19,43 +20,45 @@
 point_cloud_range = [-122.4, -122.4, -3.0, 122.4, 122.4, 5.0]
 voxel_size = [0.075, 0.075, 0.2]
 grid_size = [3264, 3264, 41]
+
 eval_class_range = {
-    "car": 121,
-    "truck": 121,
-    "bus": 121,
-    "bicycle": 121,
-    "pedestrian": 121,
+    "car": 120,
+    "truck": 120,
+    "bus": 120,
+    "bicycle": 120,
+    "pedestrian": 120,
 }
+
 # model parameter
 input_modality = dict(use_lidar=True, use_camera=True)
 point_load_dim = 5  # x, y, z, intensity, ring_id
-point_use_dim = 5
-point_intensity_dim = 3
-sweeps_num = 1
 max_num_points = 10
 max_voxels = [120000, 160000]
 num_proposals = 500
-image_size = [256, 704]
-lidar_sweep_dims = [0, 1, 2, 4]
-num_workers = 1
+sweeps_num = 1
+image_size = [576, 864]  # height, width
+num_workers = 32
+lidar_sweep_dims = [0, 1, 2, 4]  # x, y, z, time_lag
+lidar_feature_dims = 4
+
 model = dict(
     type="BEVFusion",
     data_preprocessor=dict(
+        type="Det3DDataPreprocessor",
         pad_size_divisor=32,
         voxelize_cfg=dict(
             max_num_points=max_num_points,
-            point_cloud_range=point_cloud_range,
             voxel_size=voxel_size,
+            point_cloud_range=point_cloud_range,
             max_voxels=max_voxels,
+            deterministic=True,
             voxelize_reduce=True,
         ),
-        type="Det3DDataPreprocessor",
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         bgr_to_rgb=False,
     ),
-    pts_voxel_encoder=dict(type="HardSimpleVFE", num_features=4),
-    pts_middle_encoder=dict(in_channels=4, sparse_shape=grid_size),
+    pts_middle_encoder=dict(sparse_shape=grid_size, in_channels=lidar_feature_dims),
     img_backbone=dict(
         type="mmdet.SwinTransformer",
         embed_dims=96,
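
As an informal consistency check on the ranges above: `grid_size` should equal range divided by `voxel_size` per axis, and the 120 m `eval_class_range` sits just inside the 122.4 m point cloud range:

```python
pcr = [-122.4, -122.4, -3.0, 122.4, 122.4, 5.0]  # point_cloud_range
vox = [0.075, 0.075, 0.2]                        # voxel_size

print((pcr[3] - pcr[0]) / vox[0])  # 3264.0 -> grid_size x
print((pcr[4] - pcr[1]) / vox[1])  # 3264.0 -> grid_size y
print((pcr[5] - pcr[2]) / vox[2])  # 40.0   -> grid_size z is 41, i.e. 40 + 1
                                   #          (the usual sparse-encoder extra slice)
```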
@@ -92,16 +95,11 @@
         in_channels=256,
         out_channels=80,
         image_size=image_size,
-        feature_size=[32, 88],
-        # xbound=[-54.0, 54.0, 0.3],
-        # ybound=[-54.0, 54.0, 0.3],
-        # xbound=[-122.4, 122.4, 0.68],
-        # ybound=[-122.4, 122.4, 0.68],
+        feature_size=[72, 108],
         xbound=[-122.4, 122.4, 0.3],
         ybound=[-122.4, 122.4, 0.3],
         zbound=[-10.0, 10.0, 20.0],
-        # dbound=[1.0, 60.0, 0.5],
-        dbound=[1.0, 166.2, 1.4],
+        dbound=[1.0, 134, 1.4],
Collaborator: Any reason we changed it to 134? I am thinking we should make the depth range and bin size even smaller, and make sure the range divides evenly by the bin size.

Collaborator: bin size: 1.4 could be a little too large.
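For what it's worth, both the old and new bounds divide evenly by the 1.4 m step; the change mainly shortens the maximum modeled depth:

```python
# Depth bins implied by dbound = [near, far, step]:
print((166.2 - 1.0) / 1.4)  # 118.0 bins (old bound)
print((134.0 - 1.0) / 1.4)  # 95.0 bins (new bound) -> both divide evenly
```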
         downsample=2,
     ),
     fusion_layer=dict(type="ConvFuser", in_channels=[80, 256], out_channels=256),
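
The new `feature_size` stays consistent with the new `image_size` under the usual 1/8 feature stride (informal check; the old pair [256, 704] → [32, 88] followed the same rule):

```python
image_size = [576, 864]              # height, width
print([s // 8 for s in image_size])  # [72, 108] == feature_size
```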
@@ -115,6 +113,7 @@
         code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],
     ),
     test_cfg=dict(
+        dataset="t4datasets",
         grid_size=grid_size,
         voxel_size=voxel_size[0:2],
         pc_range=point_cloud_range[0:2],
@@ -124,49 +123,13 @@
             voxel_size=voxel_size[0:2],
         ),
     ),
+    # Lidar pipeline
+    pts_voxel_encoder=dict(num_features=lidar_feature_dims),
 )
 
-# TODO: support object sample
-# db_sampler = dict(
-#     data_root=data_root,
-#     info_path=data_root +'nuscenes_dbinfos_train.pkl',
-#     rate=1.0,
-#     prepare=dict(
-#         filter_by_difficulty=[-1],
-#         filter_by_min_points=dict(
-#             car=5,
-#             truck=5,
-#             bus=5,
-#             trailer=5,
-#             construction_vehicle=5,
-#             traffic_cone=5,
-#             barrier=5,
-#             motorcycle=5,
-#             bicycle=5,
-#             pedestrian=5)),
-#     classes=class_names,
-#     sample_groups=dict(
-#         car=2,
-#         truck=3,
-#         construction_vehicle=7,
-#         bus=4,
-#         trailer=6,
-#         barrier=2,
-#         motorcycle=6,
-#         bicycle=6,
-#         pedestrian=2,
-#         traffic_cone=2),
-#     points_loader=dict(
-#         type='LoadPointsFromFile',
-#         coord_type='LIDAR',
-#         load_dim=5,
-#         use_dim=[0, 1, 2, 3, 4],
-#         backend_args=backend_args))
-
 train_pipeline = [
     dict(
         type="BEVLoadMultiViewImageFromFiles",
-        data_root=data_root,
         to_float32=True,
         color_type="color",
         backend_args=backend_args,
@@ -175,39 +138,31 @@
         type="LoadPointsFromFile",
         coord_type="LIDAR",
         load_dim=point_load_dim,
-        use_dim=point_use_dim,
+        use_dim=point_load_dim,
         backend_args=backend_args,
     ),
-    # TODO: add feature
-    # dict(
-    #     type="IntensityNorm",
-    #     alpha=10.0,
-    #     intensity_dim=point_intensity_dim,
-    #     div_factor=255.0,
-    # ),
     dict(
         type="LoadPointsFromMultiSweeps",
         sweeps_num=sweeps_num,
-        load_dim=5,
+        load_dim=point_load_dim,
         use_dim=lidar_sweep_dims,
         pad_empty_sweeps=True,
         remove_close=True,
         backend_args=backend_args,
+        test_mode=False,
     ),
     dict(type="LoadAnnotations3D", with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
-    # TODO: support object sample
-    # dict(type='ObjectSample', db_sampler=db_sampler),
     dict(
         type="ImageAug3D",
         final_dim=image_size,
-        resize_lim=[0.38, 0.55],
+        resize_lim=0.02,
         bot_pct_lim=[0.0, 0.0],
         rot_lim=[-5.4, 5.4],
         rand_flip=True,
         is_train=True,
     ),
     dict(
-        type="GlobalRotScaleTrans",
+        type="BEVFusionGlobalRotScaleTrans",
        rot_range=[-1.571, 1.571],
        scale_ratio_range=[0.8, 1.2],
        translation_std=[1.0, 1.0, 0.2],
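
Side note on the ranges above: `rot_range` is in radians, so [-1.571, 1.571] is roughly ±90°:

```python
import math
print(math.degrees(1.571))  # ~90.0 -> rot_range [-1.571, 1.571] is about ±π/2
```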
@@ -259,23 +214,21 @@
 test_pipeline = [
     dict(
         type="BEVLoadMultiViewImageFromFiles",
-        data_root=data_root,
         to_float32=True,
         color_type="color",
         backend_args=backend_args,
-        test_mode=True,
     ),
     dict(
         type="LoadPointsFromFile",
         coord_type="LIDAR",
-        load_dim=5,
-        use_dim=5,
+        load_dim=point_load_dim,
+        use_dim=point_load_dim,
         backend_args=backend_args,
     ),
     dict(
         type="LoadPointsFromMultiSweeps",
         sweeps_num=sweeps_num,
-        load_dim=5,
+        load_dim=point_load_dim,
         use_dim=lidar_sweep_dims,
         pad_empty_sweeps=True,
         remove_close=True,
@@ -285,7 +238,7 @@
     dict(
         type="ImageAug3D",
         final_dim=image_size,
-        resize_lim=[0.48, 0.48],
+        resize_lim=0.02,
         bot_pct_lim=[0.0, 0.0],
         rot_lim=[0.0, 0.0],
         rand_flip=False,
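
Note that with the patched `sample_augmentation` above, `resize_lim` is ignored when `is_train=False`: the test branch now computes the exact-fit scale `min(fH / H, fW / W)` directly, so the 0.02 here is inert and presumably kept only for symmetry with the train pipeline.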
@@ -313,27 +266,31 @@
     ),
 ]
+
+filter_cfg = dict(filter_frames_with_missing_image=True)
+
 train_dataloader = dict(
     batch_size=train_batch_size,
     num_workers=num_workers,
     persistent_workers=True,
     sampler=dict(type="DefaultSampler", shuffle=True),
     dataset=dict(
         type=_base_.dataset_type,
-        pipeline=train_pipeline,
-        modality=input_modality,
-        backend_args=backend_args,
         data_root=data_root,
         ann_file=info_directory_path + _base_.info_train_file_name,
+        pipeline=train_pipeline,
         metainfo=_base_.metainfo,
         class_names=_base_.class_names,
+        modality=input_modality,
-        data_prefix=_base_.data_prefix,
         test_mode=False,
+        data_prefix=_base_.data_prefix,
         box_type_3d="LiDAR",
+        backend_args=backend_args,
+        filter_cfg=filter_cfg,
     ),
 )
 
 val_dataloader = dict(
-    batch_size=2,
+    batch_size=test_batch_size,
     num_workers=num_workers,
     persistent_workers=True,
     sampler=dict(type="DefaultSampler", shuffle=False),
@@ -351,8 +308,9 @@
         backend_args=backend_args,
     ),
 )
+
 test_dataloader = dict(
-    batch_size=2,
+    batch_size=test_batch_size,
     num_workers=num_workers,
     persistent_workers=True,
     sampler=dict(type="DefaultSampler", shuffle=False),
@@ -392,6 +350,7 @@
     name_mapping=_base_.name_mapping,
     eval_class_range=eval_class_range,
     filter_attributes=_base_.filter_attributes,
+    save_csv=True,
 )
 
 # learning rate
@@ -456,14 +415,8 @@
     clip_grad=dict(max_norm=35, norm_type=2),
 )
 
-# Default setting for scaling LR automatically
-#   - `enable` means enable scaling LR automatically
-#       or not by default.
-#   - `base_batch_size` = (8 GPUs) x (4 samples per GPU).
-# auto_scale_lr = dict(enable=False, base_batch_size=32)
-auto_scale_lr = dict(enable=False, base_batch_size=train_gpu_size * train_batch_size)
+auto_scale_lr = dict(enable=True, base_batch_size=4)
Collaborator: Please keep the comment. Also, any reason we set enable to True? Does it show any significant improvement or stability gain in training?
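For context on the question above, a sketch of the linear scaling rule MMEngine applies when auto_scale_lr is enabled (the base LR value is hypothetical; the real one lives in the optimizer config):

```python
# MMEngine multiplies the optimizer LR by real_batch / base_batch_size when
# auto_scale_lr is enabled.
base_lr = 2.0e-4                      # hypothetical; the real value is in the optimizer config
real_batch = 4 * 2                    # train_gpu_size * train_batch_size
scaled_lr = base_lr * real_batch / 4  # base_batch_size = 4 -> LR is doubled here
```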
+
 # Only set if train_gpu_size is greater than 1
 if train_gpu_size > 1:
     sync_bn = "torch"
-
-randomness = dict(seed=0, diff_rank_seed=False, deterministic=True)
Collaborator: Any reason we deleted it? I believe we need to keep it for reproducibility.