Skip to content

Commit

Permalink
Add small and tiny configs
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhuoz004 committed Jan 3, 2025
1 parent b0d43f3 commit a461a76
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 14 deletions.
4 changes: 3 additions & 1 deletion tripy/examples/segment-anything-model-v2/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
checkpoints/
checkpoints/*.pt
saved_engines/
output/
*.jpg
bedroom/
2 changes: 1 addition & 1 deletion tripy/examples/segment-anything-model-v2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ This is an implementation of SAM2 model ([original repository](https://github.co

```bash
python3 download_test_data.py
mkdir checkpoints && cd checkpoints && wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
cd checkpoints && sh download_ckpt.sh
```

### Image pipeline
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/sh
# Download the SAM 2.1 Hiera checkpoints (tiny, small and large) into the
# current directory, using wget when available and curl otherwise.
#
# This script is written for POSIX sh because the README runs it as
# `sh download_ckpt.sh`. The bash-only `&>` redirect must not be used here:
# a POSIX shell parses `cmd &> /dev/null` as "run cmd in the background, then
# truncate /dev/null", which always reports success, so wget would be selected
# even when it is not installed.

if command -v wget > /dev/null 2>&1; then
    CMD="wget"
elif command -v curl > /dev/null 2>&1; then
    # -f makes curl exit non-zero on HTTP errors instead of saving the
    # server's error page under the checkpoint's file name.
    CMD="curl -f -L -O"
else
    echo "Please install wget or curl to download the checkpoints."
    exit 1
fi

SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"

for size in tiny small large; do
    ckpt="sam2.1_hiera_${size}.pt"
    url="${SAM2p1_BASE_URL}/${ckpt}"

    echo "Downloading ${ckpt} checkpoint..."
    # $CMD is intentionally unquoted so "curl -f -L -O" splits into words.
    $CMD "$url" || { echo "Failed to download checkpoint from $url"; exit 1; }
done
116 changes: 116 additions & 0 deletions tripy/examples/segment-anything-model-v2/configs/sam2_hiera_s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# @package _global_

# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
# Hiera trunk settings for the "small" SAM2 variant (sam2_hiera_s).
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
# Number of blocks in each stage; the trunk depth is sum(stages) = 16 blocks.
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
# Per-block padding, one entry per trunk block (16 entries, indexed by block number).
# Entry i is handed to block i as pad_size; when it is > 0, window_partition pads
# the end of the H and W axes by that many elements before splitting into windows.
# 0 means no padding. The entries at the global_att_blocks indices (7, 10, 13) are 0.
# NOTE(review): the values presumably round H/W up to a multiple of each block's
# window size -- confirm against the Hiera window_spec.
block_pad_size: [0, 0, 0 ,0, 6, 6, 6, 0, 6, 6, 0, 6, 6, 0, 6, 3]
# Per-block crop size, one entry per trunk block. Entry i is handed to block i as
# unpad_size; when it is > 0, window_unpartition crops H and W back to that size
# after the windows are merged. 0 means no crop.
block_unpad_size: [0, 0, 0, 0, 64, 64, 64, 0, 64, 64, 0, 64, 64, 0, 32, 32]
dtype: float16
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
dtype: float16

memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
dtype: float16
num_layers: 4
# memory attention layer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
# self rope attention
sa_rope_theta: 10000.0
sa_feat_sizes: [64, 64]
sa_embedding_dim: 256
sa_num_heads: 1
sa_downsample_rate: 1
sa_dropout: 0.1
# cross rope attention
ca_rope_theta: 10000.0
ca_feat_sizes: [64, 64]
ca_rope_k_repeat: True
ca_embedding_dim: 256
ca_num_heads: 1
ca_downsample_rate: 1
ca_dropout: 0.1
ca_kv_in_dim: 64

memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2

num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
model_precision: float16
116 changes: 116 additions & 0 deletions tripy/examples/segment-anything-model-v2/configs/sam2_hiera_t.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# @package _global_

# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
# Hiera trunk settings for the "tiny" SAM2 variant (sam2_hiera_t).
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
# Number of blocks in each stage; the trunk depth is sum(stages) = 12 blocks.
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
# Per-block padding, one entry per trunk block (12 entries, indexed by block number).
# Entry i is handed to block i as pad_size; when it is > 0, window_partition pads
# the end of the H and W axes by that many elements before splitting into windows.
# 0 means no padding. The entries at the global_att_blocks indices (5, 7, 9) are 0.
# NOTE(review): the values presumably round H/W up to a multiple of each block's
# window size -- confirm against the Hiera window_spec.
block_pad_size: [0, 0, 0 ,0, 6, 0, 6, 0, 6, 0, 6, 3]
# Per-block crop size, one entry per trunk block. Entry i is handed to block i as
# unpad_size; when it is > 0, window_unpartition crops H and W back to that size
# after the windows are merged. 0 means no crop.
block_unpad_size: [0, 0, 0, 0, 64, 0, 64, 0, 64, 0, 32, 32]
dtype: float16
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
dtype: float16

memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
dtype: float16
num_layers: 4
# memory attention layer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
# self rope attention
sa_rope_theta: 10000.0
sa_feat_sizes: [64, 64]
sa_embedding_dim: 256
sa_num_heads: 1
sa_downsample_rate: 1
sa_dropout: 0.1
# cross rope attention
ca_rope_theta: 10000.0
ca_feat_sizes: [64, 64]
ca_rope_k_repeat: True
ca_embedding_dim: 256
ca_num_heads: 1
ca_downsample_rate: 1
ca_dropout: 0.1
ca_kv_in_dim: 64

memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2

num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
model_precision: float16
6 changes: 4 additions & 2 deletions tripy/examples/segment-anything-model-v2/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch", type=int, default=2, help="batch size of the input images, between [1, 4]")
parser.add_argument("-t", "--type", type=str, default="large", choices=["large", "small", "tiny"], help="type of the sam2 model")


def process_predictions(
Expand Down Expand Up @@ -105,8 +106,9 @@ def main(image_path: str, save_path: Optional[str] = None):
image_list = [image] * args.batch

# Initialize SAM2 model
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_type = args.type
sam2_checkpoint = f"./checkpoints/sam2.1_hiera_{sam2_type}.pt"
model_cfg = f"sam2_hiera_{sam2_type[0]}.yaml"
device = torch.device("cuda")
sam2_model = build_sam2(
model_cfg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(
q_stride: Tuple[int, int] = None,
act_layer: Callable = tp.gelu,
window_size: int = 0,
pad_size: int = 0,
unpad_size: int = 0,
dtype: tp.dtype = tp.float32,
):
super().__init__()
Expand All @@ -118,6 +120,8 @@ def __init__(
self.norm1 = norm_layer(dim)

self.window_size = window_size
self.pad_size = (pad_size, pad_size)
self.unpad_size = (unpad_size, unpad_size)

self.pool, self.q_stride = None, q_stride
if self.q_stride:
Expand Down Expand Up @@ -163,7 +167,7 @@ def call_norm(x, norm):
window_size = self.window_size
if window_size > 0:
H, W = x.shape[1:3]
x, pad_hw = window_partition(x, window_size)
x, pad_hw = window_partition(x, window_size, self.pad_size)

# Window Attention + Q Pooling (if stage change)
x = self.attn(x)
Expand All @@ -181,7 +185,7 @@ def mod_int(x, y):

# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, window_size, pad_hw, (H, W))
x = window_unpartition(x, window_size, pad_hw, self.unpad_size)

x = shortcut + x
# MLP
Expand Down Expand Up @@ -215,6 +219,8 @@ def __init__(
16,
20,
),
block_pad_size: List[int] = [],
block_unpad_size: List[int] = [],
return_interm_layers=True, # return feats from every stage
dtype: str = "float32",
):
Expand All @@ -226,6 +232,8 @@ def __init__(
self.window_spec = window_spec

depth = sum(stages)
self.block_pad_size = block_pad_size if block_pad_size else [0] * depth
self.block_unpad_size = block_unpad_size if block_unpad_size else [0] * depth
self.q_stride = q_stride
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
assert 0 <= q_pool <= len(self.stage_ends[:-1])
Expand Down Expand Up @@ -266,6 +274,8 @@ def __init__(
num_heads=num_heads,
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
pad_size=self.block_pad_size[i],
unpad_size=self.block_unpad_size[i],
dtype=tp_dtype,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import nvtripy as tp


def window_partition(x, window_size):
def window_partition(x, window_size, pad_size):
"""
Partition into non-overlapping windows with padding if needed.
Args:
Expand All @@ -38,8 +38,12 @@ def window_partition(x, window_size):
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
# padding is not triggered
pad_h, pad_w = pad_size
Hp, Wp = H, W
if pad_h > 0 or pad_w > 0:
x = tp.pad(x, pad=((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
Hp, Wp = H + pad_h, W + pad_w

x = tp.reshape(x, (B, Hp // window_size, window_size, Wp // window_size, window_size, C))
x = tp.permute(x, (0, 1, 3, 2, 4, 5))
windows = tp.reshape(x, (-1, window_size, window_size, C))
Expand All @@ -58,11 +62,16 @@ def window_unpartition(windows, window_size, pad_hw, hw):
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)

x = tp.reshape(windows, (B, Hp // window_size, Wp // window_size, window_size, window_size, -1))
x = tp.permute(x, (0, 1, 3, 2, 4, 5)) # [B, Hp//window_size, window_size, Wp//window_size, window_size, C]
x = tp.reshape(x, (B, Hp, Wp, -1)) # [B, Hp, Wp, C]

if H > 0 or W > 0:
x = x[:, :H, :W, :]

return x


Expand Down
Loading

0 comments on commit a461a76

Please sign in to comment.