diff --git a/ethology/reid/__init__.py b/ethology/reid/__init__.py new file mode 100644 index 00000000..d2f6d334 --- /dev/null +++ b/ethology/reid/__init__.py @@ -0,0 +1 @@ +# ReID module for ethology diff --git a/ethology/reid/backbones/__init__.py b/ethology/reid/backbones/__init__.py new file mode 100644 index 00000000..8e23892b --- /dev/null +++ b/ethology/reid/backbones/__init__.py @@ -0,0 +1 @@ +# Backbones for ReID models diff --git a/ethology/reid/backbones/hacnn.py b/ethology/reid/backbones/hacnn.py new file mode 100644 index 00000000..27771cf7 --- /dev/null +++ b/ethology/reid/backbones/hacnn.py @@ -0,0 +1,337 @@ +"""HACNN backbone for person re-identification.""" + +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ["HACNN"] + + +class ConvBlock(nn.Module): + def __init__(self, in_c, out_c, k, s=1, p=0): + """Convolutional block with batch norm and ReLU.""" + super().__init__() + self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) + self.bn = nn.BatchNorm2d(out_c) + + def forward(self, x): + return F.relu(self.bn(self.conv(x))) + + +class InceptionA(nn.Module): + def __init__(self, in_channels, out_channels): + """InceptionA block.""" + super().__init__() + mid_channels = out_channels // 4 + self.stream1 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream2 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream3 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream4 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1), + ConvBlock(in_channels, mid_channels, 1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + s4 = self.stream4(x) + y = torch.cat([s1, s2, s3, s4], dim=1) + return y + + +class InceptionB(nn.Module): + def __init__(self, in_channels, out_channels): + """InceptionB block.""" + super().__init__() + mid_channels = out_channels // 4 + self.stream1 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), + ) + self.stream2 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), + ) + self.stream3 = nn.Sequential( + nn.MaxPool2d(3, stride=2, padding=1), + ConvBlock(in_channels, mid_channels * 2, 1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + y = torch.cat([s1, s2, s3], dim=1) + return y + + +class SpatialAttn(nn.Module): + def __init__(self): + """Spatial attention block.""" + super().__init__() + self.conv1 = ConvBlock(1, 1, 3, s=2, p=1) + self.conv2 = ConvBlock(1, 1, 1) + + def forward(self, x): + x = x.mean(1, keepdim=True) + x = self.conv1(x) + x = F.interpolate( + x, + (x.size(2) * 2, x.size(3) * 2), + mode="bilinear", + align_corners=True, + ) + x = self.conv2(x) + return x + + +class ChannelAttn(nn.Module): + def __init__(self, in_channels, reduction_rate=16): + """Channel attention block.""" + super().__init__() + assert in_channels % reduction_rate == 0 + self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1) + self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1) + + def forward(self, x): + x = F.avg_pool2d(x, x.size()[2:]) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class SoftAttn(nn.Module): + def __init__(self, in_channels): + """Soft attention block.""" + super().__init__() + self.spatial_attn = SpatialAttn() + self.channel_attn = ChannelAttn(in_channels) + self.conv = ConvBlock(in_channels, in_channels, 1) + + def forward(self, x): + y_spatial = self.spatial_attn(x) + y_channel = self.channel_attn(x) + y = y_spatial * y_channel + y = torch.sigmoid(self.conv(y)) + return y + + +class HardAttn(nn.Module): + def __init__(self, in_channels): + """Hard attention block.""" + super().__init__() + self.fc = nn.Linear(in_channels, 4 * 2) + self.init_params() + + def init_params(self): + self.fc.weight.data.zero_() + self.fc.bias.data.copy_( + torch.tensor( + [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float + ) + ) + + def forward(self, x): + x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1)) + theta = torch.tanh(self.fc(x)) + theta = theta.view(-1, 4, 2) + return theta + + +class HarmAttn(nn.Module): + def __init__(self, in_channels): + """Harmonious attention block.""" + super().__init__() + self.soft_attn = SoftAttn(in_channels) + self.hard_attn = HardAttn(in_channels) + + def forward(self, x): + y_soft_attn = self.soft_attn(x) + theta = self.hard_attn(x) + return y_soft_attn, theta + + +class HACNN(nn.Module): + def __init__( + self, + num_classes, + loss="softmax", + nchannels=None, + feat_dim=512, + learn_region=True, + use_gpu=True, + **kwargs, + ): + """Harmonious Attention Convolutional Neural Network (HACNN) for person re-identification.""" + super().__init__() + if nchannels is None: + nchannels = [128, 256, 384] + self.loss = loss + self.learn_region = learn_region + self.use_gpu = use_gpu + self.conv = ConvBlock(3, 32, 3, s=2, p=1) + self.inception1 = nn.Sequential( + InceptionA(32, nchannels[0]), + InceptionB(nchannels[0], nchannels[0]), + ) + self.ha1 = HarmAttn(nchannels[0]) + self.inception2 = nn.Sequential( + InceptionA(nchannels[0], nchannels[1]), + InceptionB(nchannels[1], nchannels[1]), + ) + self.ha2 = HarmAttn(nchannels[1]) + self.inception3 = nn.Sequential( + InceptionA(nchannels[1], nchannels[2]), + InceptionB(nchannels[2], nchannels[2]), + ) + self.ha3 = HarmAttn(nchannels[2]) + self.fc_global = nn.Sequential( + nn.Linear(nchannels[2], feat_dim), + nn.BatchNorm1d(feat_dim), + nn.ReLU(), + ) + self.classifier_global = nn.Linear(feat_dim, num_classes) + if self.learn_region: + self.init_scale_factors() + self.local_conv1 = InceptionB(32, nchannels[0]) + self.local_conv2 = InceptionB(nchannels[0], nchannels[1]) + self.local_conv3 = InceptionB(nchannels[1], nchannels[2]) + self.fc_local = nn.Sequential( + nn.Linear(nchannels[2] * 4, feat_dim), + nn.BatchNorm1d(feat_dim), + nn.ReLU(), + ) + self.classifier_local = nn.Linear(feat_dim, num_classes) + self.feat_dim = feat_dim * 2 + else: + self.feat_dim = feat_dim + + def init_scale_factors(self): + """Initialize scale factors for STN.""" + self.scale_factors = [] + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + + def stn(self, x, theta): + """Spatial transformer network.""" + grid = F.affine_grid(theta, x.size()) + x = F.grid_sample(x, grid) + return x + + def transform_theta(self, theta_i, region_idx): + """Transform theta for a given region.""" + scale_factors = self.scale_factors[region_idx] + theta = torch.zeros(theta_i.size(0), 2, 3) + theta[:, :, :2] = scale_factors + theta[:, :, -1] = theta_i + if self.use_gpu: + theta = theta.to(next(self.parameters()).device) + return theta + + def forward(self, x): + """Forward pass.""" + assert x.size(2) == 160 and x.size(3) == 64, ( + f"Input size does not match, expected (160, 64) but got ({x.size(2)}, {x.size(3)})" + ) + x = self.conv(x) + x1 = self.inception1(x) + x1_attn, x1_theta = self.ha1(x1) + x1_out = x1 * x1_attn + if self.learn_region: + x1_local_list = [] + for region_idx in range(4): + x1_theta_i = x1_theta[:, region_idx, :] + x1_theta_i = self.transform_theta(x1_theta_i, region_idx) + x1_trans_i = self.stn(x, x1_theta_i) + x1_trans_i = F.interpolate( + x1_trans_i, (24, 28), mode="bilinear", align_corners=True + ) + x1_local_i = self.local_conv1(x1_trans_i) + x1_local_list.append(x1_local_i) + x2 = self.inception2(x1_out) + x2_attn, x2_theta = self.ha2(x2) + x2_out = x2 * x2_attn + if self.learn_region: + x2_local_list = [] + for region_idx in range(4): + x2_theta_i = x2_theta[:, region_idx, :] + x2_theta_i = self.transform_theta(x2_theta_i, region_idx) + x2_trans_i = self.stn(x1_out, x2_theta_i) + x2_trans_i = F.interpolate( + x2_trans_i, (12, 14), mode="bilinear", align_corners=True + ) + x2_local_i = x2_trans_i + x1_local_list[region_idx] + x2_local_i = self.local_conv2(x2_local_i) + x2_local_list.append(x2_local_i) + x3 = self.inception3(x2_out) + x3_attn, x3_theta = self.ha3(x3) + x3_out = x3 * x3_attn + if self.learn_region: + x3_local_list = [] + for region_idx in range(4): + x3_theta_i = x3_theta[:, region_idx, :] + x3_theta_i = self.transform_theta(x3_theta_i, region_idx) + x3_trans_i = self.stn(x2_out, x3_theta_i) + x3_trans_i = F.interpolate( + x3_trans_i, (6, 7), mode="bilinear", align_corners=True + ) + x3_local_i = x3_trans_i + x2_local_list[region_idx] + x3_local_i = self.local_conv3(x3_local_i) + x3_local_list.append(x3_local_i) + x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view( + x3_out.size(0), x3_out.size(1) + ) + x_global = self.fc_global(x_global) + if self.learn_region: + x_local_list = [] + for region_idx in range(4): + x_local_i = x3_local_list[region_idx] + x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view( + x_local_i.size(0), -1 + ) + x_local_list.append(x_local_i) + x_local = torch.cat(x_local_list, 1) + x_local = self.fc_local(x_local) + if not self.training: + if self.learn_region: + x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True) + x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True) + return torch.cat([x_global, x_local], 1) + else: + return x_global + prelogits_global = self.classifier_global(x_global) + if self.learn_region: + prelogits_local = self.classifier_local(x_local) + if self.loss == "softmax": + if self.learn_region: + return (prelogits_global, prelogits_local) + else: + return prelogits_global + elif self.loss == "triplet": + if self.learn_region: + return (prelogits_global, prelogits_local), (x_global, x_local) + else: + return prelogits_global, x_global + else: + raise KeyError(f"Unsupported loss: {self.loss}") diff --git a/ethology/reid/backbones/mlfn.py b/ethology/reid/backbones/mlfn.py new file mode 100644 index 00000000..c638a5d5 --- /dev/null +++ b/ethology/reid/backbones/mlfn.py @@ -0,0 +1,272 @@ +"""MLFN backbone for person re-identification.""" + +import torch +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ["mlfn"] +model_urls = { + # training epoch = 5, top1 = 51.6 + "imagenet": "https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk", +} + + +class MLFNBlock(nn.Module): + def __init__( + self, in_channels, out_channels, stride, fsm_channels, groups=32 + ): + super().__init__() + self.groups = groups + mid_channels = out_channels // 2 + + # Factor Modules + super().__init__() + self.fm_bn1 = nn.BatchNorm2d(mid_channels) + self.fm_conv2 = nn.Conv2d( + mid_channels, + mid_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=self.groups, + ) + self.fm_bn2 = nn.BatchNorm2d(mid_channels) + self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False) + self.fm_bn3 = nn.BatchNorm2d(out_channels) + + # Factor Selection Module + self.fsm = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, fsm_channels[0], 1), + nn.BatchNorm2d(fsm_channels[0]), + nn.ReLU(inplace=True), + nn.Conv2d(fsm_channels[0], fsm_channels[1], 1), + nn.BatchNorm2d(fsm_channels[1]), + nn.ReLU(inplace=True), + nn.Conv2d(fsm_channels[1], self.groups, 1), + nn.BatchNorm2d(self.groups), + nn.Sigmoid(), + ) + + self.downsample = None + if in_channels != out_channels or stride > 1: + self.downsample = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, 1, stride=stride, bias=False + ), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + residual = x + s = self.fsm(x) + + # reduce dimension + x = self.fm_conv1(x) + x = self.fm_bn1(x) + x = F.relu(x, inplace=True) + + # group convolution + x = self.fm_conv2(x) + x = self.fm_bn2(x) + x = F.relu(x, inplace=True) + + # factor selection + b, c = x.size(0), x.size(1) + n = c // self.groups + ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1) + ss = ss.view(b, n, self.groups, 1, 1) + ss = ss.permute(0, 2, 1, 3, 4).contiguous() + ss = ss.view(b, c, 1, 1) + x = ss * x + + # recover dimension + x = self.fm_conv3(x) + x = self.fm_bn3(x) + x = F.relu(x, inplace=True) + + if self.downsample is not None: + residual = self.downsample(residual) + + return F.relu(residual + x, inplace=True), s + + +class MLFN(nn.Module): + """Multi-Level Factorisation Net. + + Reference: + Chang et al. Multi-Level Factorisation Net for + Person Re-Identification. CVPR 2018. + + Public keys: + - ``mlfn``: MLFN (Multi-Level Factorisation Net). + """ + + def __init__( + self, + num_classes, + loss="softmax", + groups=32, + channels=None, + embed_dim=1024, + **kwargs, + ): + super().__init__() + if channels is None: + channels = [64, 256, 512, 1024, 2048] + channels = (None,) + self.groups = groups + + # first convolutional layer + self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(channels[0]) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + + # main body + self.feature = nn.ModuleList( + [ + # layer 1-3 + MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups), + MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), + MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), + # layer 4-7 + MLFNBlock( + channels[1], channels[2], 2, [256, 128], self.groups + ), + MLFNBlock( + channels[2], channels[2], 1, [256, 128], self.groups + ), + MLFNBlock( + channels[2], channels[2], 1, [256, 128], self.groups + ), + MLFNBlock( + channels[2], channels[2], 1, [256, 128], self.groups + ), + # layer 8-13 + MLFNBlock( + channels[2], channels[3], 2, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + # layer 14-16 + MLFNBlock( + channels[3], channels[4], 2, [512, 128], self.groups + ), + MLFNBlock( + channels[4], channels[4], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[4], channels[4], 1, [512, 128], self.groups + ), + ] + ) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + + # projection functions + self.fc_x = nn.Sequential( + nn.Conv2d(channels[4], embed_dim, 1, bias=False), + nn.BatchNorm2d(embed_dim), + nn.ReLU(inplace=True), + ) + self.fc_s = nn.Sequential( + nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False), + nn.BatchNorm2d(embed_dim), + nn.ReLU(inplace=True), + ) + + self.classifier = nn.Linear(embed_dim, num_classes) + + self.init_params() + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x, inplace=True) + x = self.maxpool(x) + + s_hat = [] + for block in self.feature: + x, s = block(x) + s_hat.append(s) + s_hat = torch.cat(s_hat, 1) + + x = self.global_avgpool(x) + x = self.fc_x(x) + s_hat = self.fc_s(s_hat) + + v = (x + s_hat) * 0.5 + v = v.view(v.size(0), -1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError(f"Unsupported loss: {self.loss}") + + +def init_pretrained_weights(model, model_url): + """Initialize model with pretrained weights. + + Keep layers unchanged if they don't match pretrained layers in name or size. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def mlfn(num_classes, loss="softmax", pretrained=True, **kwargs): + model = MLFN(num_classes, loss, **kwargs) + if pretrained: + # init_pretrained_weights(model, model_urls['imagenet']) + import warnings + + warnings.warn( + "The imagenet pretrained weights need to be manually downloaded from {}".format( + model_urls["imagenet"] + ), + stacklevel=2, + ) + return model diff --git a/ethology/reid/backbones/mobilenetv2.py b/ethology/reid/backbones/mobilenetv2.py new file mode 100644 index 00000000..d0f5becd --- /dev/null +++ b/ethology/reid/backbones/mobilenetv2.py @@ -0,0 +1,280 @@ + +""" +MobileNetV2 backbone for person re-identification. +""" + + +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ["mobilenetv2_x1_0", "mobilenetv2_x1_4"] + +model_urls = { + # 1.0: top-1 71.3 + "mobilenetv2_x1_0": "https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c", + # 1.4: top-1 73.9 + "mobilenetv2_x1_4": "https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk", +} + + +class ConvBlock(nn.Module): + """Basic convolutional block. + + convolution (bias discarded) + batch normalization + relu6. + + Args: + in_c (int): number of input channels. + out_c (int): number of output channels. + k (int or tuple): kernel size. + s (int or tuple): stride. + p (int or tuple): padding. + g (int): number of blocked connections from input channels + to output channels (default: 1). + + def __init__(self, in_c, out_c, k, s=1, p=0, g=1): + super().__init__() + self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g) + self.bn = nn.BatchNorm2d(out_c) + + # Only keep the correct __init__ + + def forward(self, x): + return F.relu6(self.bn(self.conv(x))) + + +class Bottleneck(nn.Module): + def __init__(self, in_channels, out_channels, expansion_factor, stride=1): + super().__init__() + mid_channels = in_channels * expansion_factor + self.use_residual = stride == 1 and in_channels == out_channels + self.conv1 = ConvBlock(in_channels, mid_channels, 1) + self.dwconv2 = ConvBlock( + mid_channels, mid_channels, 3, stride, 1, g=mid_channels + ) + self.conv3 = nn.Sequential( + nn.Conv2d(mid_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + m = self.conv1(x) + m = self.dwconv2(m) + m = self.conv3(m) + if self.use_residual: + return x + m + else: + return m + + +class MobileNetV2(nn.Module): + """ + MobileNetV2 backbone for person re-identification. + + Reference: + Sandler et al. MobileNetV2: Inverted Residuals and Linear Bottlenecks. CVPR 2018. + + Public keys: + - mobilenetv2_x1_0: MobileNetV2 x1.0. + - mobilenetv2_x1_4: MobileNetV2 x1.4. + """ + + def __init__( + self, + num_classes, + width_mult=1, + loss="softmax", + fc_dims=None, + dropout_p=None, + **kwargs, + ): + super().__init__() + self.loss = loss + self.in_channels = int(32 * width_mult) + self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280 + + # construct layers + self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1) + self.conv2 = self._make_layer( + Bottleneck, 1, int(16 * width_mult), 1, 1 + ) + self.conv3 = self._make_layer( + Bottleneck, 6, int(24 * width_mult), 2, 2 + ) + self.conv4 = self._make_layer( + Bottleneck, 6, int(32 * width_mult), 3, 2 + ) + self.conv5 = self._make_layer( + Bottleneck, 6, int(64 * width_mult), 4, 2 + ) + self.conv6 = self._make_layer( + Bottleneck, 6, int(96 * width_mult), 3, 1 + ) + self.conv7 = self._make_layer( + Bottleneck, 6, int(160 * width_mult), 3, 2 + ) + self.conv8 = self._make_layer( + Bottleneck, 6, int(320 * width_mult), 1, 1 + ) + self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer( + fc_dims, self.feature_dim, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer(self, block, t, c, n, s): + # t: expansion factor + # c: output channels + # n: number of blocks + # s: stride for first layer + layers = [] + layers.append(block(self.in_channels, c, t, s)) + self.in_channels = c + for _ in range(1, n): + layers.append(block(self.in_channels, c, t)) + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer. + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance(fc_dims, (list, tuple)), ( + f"fc_dims must be either list or tuple, but got {type(fc_dims)}" + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d + ): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + x = self.conv6(x) + x = self.conv7(x) + x = self.conv8(x) + x = self.conv9(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError(f"Unsupported loss: {self.loss}") + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs): + model = MobileNetV2( + num_classes, + loss=loss, + width_mult=1, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0']) + import warnings + + warnings.warn( + "The imagenet pretrained weights need to be manually downloaded from {}".format( + model_urls["mobilenetv2_x1_0"] + ) + ) + return model + + +def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs): + model = MobileNetV2( + num_classes, + loss=loss, + width_mult=1.4, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4']) + import warnings + + warnings.warn( + "The imagenet pretrained weights need to be manually downloaded from {}".format( + model_urls["mobilenetv2_x1_4"] + ) + ) + return model + + +# Copied from boxmot/boxmot/reid/backbones/mobilenetv2.py diff --git a/ethology/reid/backbones/osnet.py b/ethology/reid/backbones/osnet.py new file mode 100644 index 00000000..c13dd5b7 --- /dev/null +++ b/ethology/reid/backbones/osnet.py @@ -0,0 +1,535 @@ +# Mikel Broström 🔥 BoxMOT 🧾 AGPL-3.0 license + + +import warnings + +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = [ + "osnet_x1_0", + "osnet_x0_75", + "osnet_x0_5", + "osnet_x0_25", + "osnet_ibn_x1_0", +] + +pretrained_urls = { + "osnet_x1_0": "https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY", + "osnet_x0_75": "https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq", + "osnet_x0_5": "https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i", + "osnet_x0_25": "https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs", + "osnet_ibn_x1_0": "https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l", +} + +# ...existing code for ConvLayer, Conv1x1, Conv1x1Linear, Conv3x3, LightConv3x3, ChannelGate, OSBlock... + + +class ConvLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + groups=1, + IN=False, + ): + super(ConvLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + groups=groups, + ) + if IN: + self.bn = nn.InstanceNorm2d(out_channels, affine=True) + else: + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Conv1x1(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv1x1, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=stride, + padding=0, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Conv1x1Linear(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(Conv1x1Linear, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=stride, padding=0, bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class Conv3x3(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv3x3, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class LightConv3x3(nn.Module): + def __init__(self, in_channels, out_channels): + super(LightConv3x3, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + groups=out_channels, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class ChannelGate(nn.Module): + def __init__( + self, + in_channels, + num_gates=None, + return_gates=False, + gate_activation="sigmoid", + reduction=16, + layer_norm=False, + ): + super(ChannelGate, self).__init__() + if num_gates is None: + num_gates = in_channels + self.return_gates = return_gates + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + in_channels, + in_channels // reduction, + kernel_size=1, + bias=True, + padding=0, + ) + self.norm1 = None + if layer_norm: + self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d( + in_channels // reduction, + num_gates, + kernel_size=1, + bias=True, + padding=0, + ) + if gate_activation == "sigmoid": + self.gate_activation = nn.Sigmoid() + elif gate_activation == "relu": + self.gate_activation = nn.ReLU(inplace=True) + elif gate_activation == "linear": + self.gate_activation = None + else: + raise RuntimeError(f"Unknown gate activation: {gate_activation}") + + def forward(self, x): + input = x + x = self.global_avgpool(x) + x = self.fc1(x) + if self.norm1 is not None: + x = self.norm1(x) + x = self.relu(x) + x = self.fc2(x) + if self.gate_activation is not None: + x = self.gate_activation(x) + if self.return_gates: + return x + return input * x + + +class OSBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + IN=False, + bottleneck_reduction=4, + **kwargs, + ): + super(OSBlock, self).__init__() + mid_channels = out_channels // bottleneck_reduction + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2a = LightConv3x3(mid_channels, mid_channels) + self.conv2b = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.conv2c = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.conv2d = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + self.IN = None + if IN: + self.IN = nn.InstanceNorm2d(out_channels, affine=True) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2a = self.conv2a(x1) + x2b = self.conv2b(x1) + x2c = self.conv2c(x1) + x2d = self.conv2d(x1) + x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d) + x3 = self.conv3(x2) + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + if self.IN is not None: + out = self.IN(out) + return F.relu(out) + + +class OSNet(nn.Module): + def __init__( + self, + num_classes, + blocks, + layers, + channels, + feature_dim=512, + loss="softmax", + IN=False, + **kwargs, + ): + super(OSNet, self).__init__() + num_blocks = len(blocks) + assert num_blocks == len(layers) + assert num_blocks == len(channels) - 1 + self.loss = loss + self.feature_dim = feature_dim + self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.conv2 = self._make_layer( + blocks[0], + layers[0], + channels[0], + channels[1], + reduce_spatial_size=True, + IN=IN, + ) + self.conv3 = self._make_layer( + blocks[1], + layers[1], + channels[1], + channels[2], + reduce_spatial_size=True, + ) + self.conv4 = self._make_layer( + blocks[2], + layers[2], + channels[2], + channels[3], + reduce_spatial_size=False, + ) + self.conv5 = Conv1x1(channels[3], channels[3]) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer( + self.feature_dim, channels[3], dropout_p=None + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + self._init_params() + + def _make_layer( + self, + block, + layer, + in_channels, + out_channels, + reduce_spatial_size, + IN=False, + ): + layers = [] + layers.append(block(in_channels, out_channels, IN=IN)) + for i in range(1, layer): + layers.append(block(out_channels, out_channels, IN=IN)) + if reduce_spatial_size: + layers.append( + nn.Sequential( + Conv1x1(out_channels, out_channels), + nn.AvgPool2d(2, stride=2), + ) + ) + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None or fc_dims < 0: + self.feature_dim = input_dim + return None + if isinstance(fc_dims, int): + fc_dims = [fc_dims] + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + self.feature_dim = fc_dims[-1] + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d + ): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + return x + + def forward(self, x, return_featuremaps=False): + x = self.featuremaps(x) + if return_featuremaps: + return x + v = self.global_avgpool(x) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError(f"Unsupported loss: {self.loss}") + + +def init_pretrained_weights(model, key=""): + import os + from collections import OrderedDict + + import gdown + + def _get_torch_home(): + ENV_TORCH_HOME = "TORCH_HOME" + ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" + DEFAULT_CACHE_DIR = "~/.cache" + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch" + ), + ) + ) + return torch_home + + filename = key + "_imagenet.pth" + # Try ethology/models/ directory first + ethology_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../") + ) + models_dir = os.path.join(ethology_root, "models") + os.makedirs(models_dir, exist_ok=True) + local_file = os.path.join(models_dir, filename) + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, "checkpoints") + os.makedirs(model_dir, exist_ok=True) + cached_file = os.path.join(model_dir, filename) + # Prefer ethology/models/ directory file if present + if os.path.exists(local_file): + print(f"[OSNet] Loading model weights from {local_file}") + cached_file = local_file + elif os.path.exists(cached_file): + print(f"[OSNet] Loading model weights from {cached_file}") + else: + print(f"[OSNet] Downloading model weights to {cached_file}") + gdown.download(pretrained_urls[key], cached_file, quiet=False) + state_dict = torch.load(cached_file) + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + for k, v in state_dict.items(): + if k.startswith("module."): + k = k[7:] + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + if len(matched_layers) == 0: + warnings.warn( + f'The pretrained weights from "{cached_file}" cannot be loaded, ' + "please check the key names manually " + "(** ignored and continue **)" + ) + else: + print( + f'Successfully loaded imagenet pretrained weights from "{cached_file}"' + ) + if len(discarded_layers) > 0: + print( + "** The following layers are discarded " + f"due to unmatched keys or layer size: {discarded_layers}" + ) + + +def osnet_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x1_0") + return model + + +def osnet_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[48, 192, 288, 384], + loss=loss, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x0_75") + return model + + +def osnet_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[32, 128, 192, 256], + loss=loss, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x0_5") + return model + + +def osnet_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[16, 64, 96, 128], + loss=loss, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x0_25") + return model + + +def osnet_ibn_x1_0( + num_classes=1000, pretrained=True, loss="softmax", **kwargs +): + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + IN=True, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_ibn_x1_0") + return model diff --git a/ethology/reid/backbones/osnet_ain.py b/ethology/reid/backbones/osnet_ain.py new file mode 100644 index 00000000..2ef3da25 --- /dev/null +++ b/ethology/reid/backbones/osnet_ain.py @@ -0,0 +1,547 @@ +# Mikel Broström 🔥 BoxMOT 🧾 AGPL-3.0 license + + +import warnings + +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = [ + "osnet_ain_x1_0", + "osnet_ain_x0_75", + "osnet_ain_x0_5", + "osnet_ain_x0_25", +] + +pretrained_urls = { + "osnet_ain_x1_0": "https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo", + "osnet_ain_x0_75": "https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM", + "osnet_ain_x0_5": "https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l", + "osnet_ain_x0_25": "https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt", +} + +# ...existing code for ConvLayer, Conv1x1, Conv1x1Linear, Conv3x3, LightConv3x3, LightConvStream, ChannelGate, OSBlock, OSBlockINin, OSNet, init_pretrained_weights, and instantiation functions... + + +class ConvLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + groups=1, + IN=False, + ): + super(ConvLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + groups=groups, + ) + if IN: + self.bn = nn.InstanceNorm2d(out_channels, affine=True) + else: + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class Conv1x1(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv1x1, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=stride, + padding=0, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class Conv1x1Linear(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, bn=True): + super(Conv1x1Linear, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=stride, padding=0, bias=False + ) + self.bn = None + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class Conv3x3(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv3x3, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class LightConv3x3(nn.Module): + def __init__(self, in_channels, out_channels): + super(LightConv3x3, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + groups=out_channels, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn(x) + return self.relu(x) + + +class LightConvStream(nn.Module): + def __init__(self, in_channels, out_channels, depth): + super(LightConvStream, self).__init__() + assert depth >= 1 + layers = [LightConv3x3(in_channels, out_channels)] + for i in range(depth - 1): + layers.append(LightConv3x3(out_channels, out_channels)) + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class ChannelGate(nn.Module): + def __init__( + self, + in_channels, + num_gates=None, + return_gates=False, + gate_activation="sigmoid", + reduction=16, + layer_norm=False, + ): + super(ChannelGate, self).__init__() + if num_gates is None: + num_gates = in_channels + self.return_gates = return_gates + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + in_channels, + in_channels // reduction, + kernel_size=1, + bias=True, + padding=0, + ) + self.norm1 = None + if layer_norm: + self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2d( + in_channels // reduction, + num_gates, + kernel_size=1, + bias=True, + padding=0, + ) + if gate_activation == "sigmoid": + self.gate_activation = nn.Sigmoid() + elif gate_activation == "relu": + self.gate_activation = nn.ReLU() + elif gate_activation == "linear": + self.gate_activation = None + else: + raise RuntimeError(f"Unknown gate activation: {gate_activation}") + + def forward(self, x): + input = x + x = self.global_avgpool(x) + x = self.fc1(x) + if self.norm1 is not None: + x = self.norm1(x) + x = self.relu(x) + x = self.fc2(x) + if self.gate_activation is not None: + x = self.gate_activation(x) + if self.return_gates: + return x + return input * x + + +class OSBlock(nn.Module): + def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): + super(OSBlock, self).__init__() + assert T >= 1 + assert out_channels >= reduction and out_channels % reduction == 0 + mid_channels = out_channels // reduction + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2 = nn.ModuleList( + [ + LightConvStream(mid_channels, mid_channels, t) + for t in range(1, T + 1) + ] + ) + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2 = 0 + for conv2_t in self.conv2: + x2_t = conv2_t(x1) + x2 = x2 + self.gate(x2_t) + x3 = self.conv3(x2) + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + return F.relu(out) + + +class OSBlockINin(nn.Module): + def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): + super(OSBlockINin, self).__init__() + assert T >= 1 + assert out_channels >= reduction and out_channels % reduction == 0 + mid_channels = out_channels // reduction + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2 = nn.ModuleList( + [ + LightConvStream(mid_channels, mid_channels, t) + for t in range(1, T + 1) + ] + ) + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + self.IN = nn.InstanceNorm2d(out_channels, affine=True) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2 = 0 + for conv2_t in self.conv2: + x2_t = conv2_t(x1) + x2 = x2 + self.gate(x2_t) + x3 = self.conv3(x2) + x3 = self.IN(x3) + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + return F.relu(out) + + +class OSNet(nn.Module): + def __init__( + self, + num_classes, + blocks, + layers, + channels, + feature_dim=512, + loss="softmax", + conv1_IN=False, + **kwargs, + ): + super(OSNet, self).__init__() + num_blocks = len(blocks) + assert num_blocks == len(layers) + assert num_blocks == len(channels) - 1 + self.loss = loss + self.feature_dim = feature_dim + self.conv1 = ConvLayer( + 3, channels[0], 7, stride=2, padding=3, IN=conv1_IN + ) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.conv2 = self._make_layer( + blocks[0], layers[0], channels[0], channels[1] + ) + self.pool2 = nn.Sequential( + Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2) + ) + self.conv3 = self._make_layer( + blocks[1], layers[1], channels[1], channels[2] + ) + self.pool3 = nn.Sequential( + Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2) + ) + self.conv4 = self._make_layer( + blocks[2], layers[2], channels[2], channels[3] + ) + self.conv5 = Conv1x1(channels[3], channels[3]) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer( + self.feature_dim, channels[3], dropout_p=None + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + self._init_params() + + def _make_layer(self, blocks, layer, in_channels, out_channels): + layers = [] + layers += [blocks[0](in_channels, out_channels)] + for i in range(1, len(blocks)): + layers += [blocks[i](out_channels, out_channels)] + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None or fc_dims < 0: + self.feature_dim = input_dim + return None + if isinstance(fc_dims, int): + fc_dims = [fc_dims] + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU()) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + self.feature_dim = fc_dims[-1] + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif ( + isinstance(m, nn.BatchNorm2d) + or isinstance(m, nn.BatchNorm1d) + or isinstance(m, nn.InstanceNorm2d) + ): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.conv3(x) + x = self.pool3(x) + x = self.conv4(x) + x = self.conv5(x) + return x + + def forward(self, x, return_featuremaps=False): + x = self.featuremaps(x) + if return_featuremaps: + return x + v = self.global_avgpool(x) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError(f"Unsupported loss: {self.loss}") + + +def init_pretrained_weights(model, key=""): + import errno + import os + from collections import OrderedDict + + import gdown + + def _get_torch_home(): + ENV_TORCH_HOME = "TORCH_HOME" + ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" + DEFAULT_CACHE_DIR = "~/.cache" + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch" + ), + ) + ) + return torch_home + + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, "checkpoints") + try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + filename = key + "_imagenet.pth" + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + gdown.download(pretrained_urls[key], cached_file, quiet=False) + state_dict = torch.load(cached_file) + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + for k, v in state_dict.items(): + if k.startswith("module."): + k = k[7:] + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + if len(matched_layers) == 0: + warnings.warn( + f'The pretrained weights from "{cached_file}" cannot be loaded, ' + "please check the key names manually " + "(** ignored and continue **)" + ) + else: + print( + f'Successfully loaded imagenet pretrained weights from "{cached_file}"' + ) + if len(discarded_layers) > 0: + print( + "** The following layers are discarded " + f"due to unmatched keys or layer size: {discarded_layers}" + ) + + +def osnet_ain_x1_0( + num_classes=1000, pretrained=True, loss="softmax", **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], + [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock], + ], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + conv1_IN=True, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_ain_x1_0") + return model + + +def osnet_ain_x0_75( + num_classes=1000, pretrained=True, loss="softmax", **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], + [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock], + ], + layers=[2, 2, 2], + channels=[48, 192, 288, 384], + loss=loss, + conv1_IN=True, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_ain_x0_75") + return model + + +def osnet_ain_x0_5( + num_classes=1000, pretrained=True, loss="softmax", **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], + [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock], + ], + layers=[2, 2, 2], + channels=[32, 128, 192, 256], + loss=loss, + conv1_IN=True, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_ain_x0_5") + return model + + +def osnet_ain_x0_25( + num_classes=1000, pretrained=True, loss="softmax", **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], + [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock], + ], + layers=[2, 2, 2], + channels=[16, 64, 96, 128], + loss=loss, + conv1_IN=True, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, key="osnet_ain_x0_25") + return model diff --git a/ethology/reid/backbones/resnet.py b/ethology/reid/backbones/resnet.py new file mode 100644 index 00000000..8ee75172 --- /dev/null +++ b/ethology/reid/backbones/resnet.py @@ -0,0 +1,707 @@ + +""" +Code source: https://github.com/pytorch/vision +""" + +from __future__ import absolute_import, division +import torch.utils.model_zoo as model_zoo +from torch import nn + +__all__ = [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "resnet50_fc512", +] + +model_urls = { + "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", +} +} + +# ...existing code for conv3x3, conv1x1, BasicBlock, Bottleneck, ResNet, init_pretrained_weights, and instantiation functions... + +<<<<<<< HEAD + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + "BasicBlock only supports groups=1 and base_width=64" + ) + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock" + ) + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + + +class ResNet(nn.Module): + def __init__( + self, + num_classes, + loss, + block, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.loss = loss + self.feature_dim = 512 * block.expansion + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}" + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + ) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=last_stride, + dilate=replace_stride_with_dilation[2], + ) + self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = self._construct_fc_layer( + fc_dims, 512 * block.expansion, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + self._init_params() + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None: + self.feature_dim = input_dim + return None + assert isinstance(fc_dims, (list, tuple)), ( + f"fc_dims must be either list or tuple, but got {type(fc_dims)}" + ) + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + self.feature_dim = fc_dims[-1] + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d + ): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError(f"Unsupported loss: {self.loss}") + + +def init_pretrained_weights(model, model_url): + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def resnet18(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=BasicBlock, + layers=[2, 2, 2, 2], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnet18"]) + return model + + +def resnet34(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=BasicBlock, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnet34"]) + return model + + +def resnet50(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnet50"]) + return model + + +def resnet101(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 23, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnet101"]) + return model + + +def resnet152(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 8, 36, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnet152"]) + return model + + +def resnext50_32x4d(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + groups=32, + width_per_group=4, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnext50_32x4d"]) + return model + + +def resnext101_32x8d(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 23, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + groups=32, + width_per_group=8, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnext101_32x8d"]) + return model + + +def resnet50_fc512(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=1, + fc_dims=[512], + dropout_p=None, + **kwargs, + ) + if pretrained: + init_pretrained_weights(model, model_urls["resnet50"]) + return model +======= +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + +class ResNet(nn.Module): + def __init__(self, num_classes, loss, block, layers, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, last_stride=2, fc_dims=None, dropout_p=None, **kwargs): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.loss = loss + self.feature_dim = 512 * block.expansion + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, dilate=replace_stride_with_dilation[2]) + self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p) + self.classifier = nn.Linear(self.feature_dim, num_classes) + self._init_params() + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) + return nn.Sequential(*layers) + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None: + self.feature_dim = input_dim + return None + assert isinstance(fc_dims, (list, tuple)), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims)) + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + self.feature_dim = fc_dims[-1] + return nn.Sequential(*layers) + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + +def init_pretrained_weights(model, model_url): + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + +def resnet18(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=BasicBlock, layers=[2, 2, 2, 2], last_stride=2, fc_dims=None, dropout_p=None, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnet18"]) + return model + +def resnet34(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=BasicBlock, layers=[3, 4, 6, 3], last_stride=2, fc_dims=None, dropout_p=None, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnet34"]) + return model + +def resnet50(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=Bottleneck, layers=[3, 4, 6, 3], last_stride=2, fc_dims=None, dropout_p=None, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnet50"]) + return model + +def resnet101(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=Bottleneck, layers=[3, 4, 23, 3], last_stride=2, fc_dims=None, dropout_p=None, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnet101"]) + return model + +def resnet152(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=Bottleneck, layers=[3, 8, 36, 3], last_stride=2, fc_dims=None, dropout_p=None, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnet152"]) + return model + +def resnext50_32x4d(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=Bottleneck, layers=[3, 4, 6, 3], last_stride=2, fc_dims=None, dropout_p=None, groups=32, width_per_group=4, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnext50_32x4d"]) + return model + +def resnext101_32x8d(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=Bottleneck, layers=[3, 4, 23, 3], last_stride=2, fc_dims=None, dropout_p=None, groups=32, width_per_group=8, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnext101_32x8d"]) + return model + +def resnet50_fc512(num_classes, loss="softmax", pretrained=True, **kwargs): + model = ResNet(num_classes=num_classes, loss=loss, block=Bottleneck, layers=[3, 4, 6, 3], last_stride=1, fc_dims=[512], dropout_p=None, **kwargs) + if pretrained: + init_pretrained_weights(model, model_urls["resnet50"]) + return model +>>>>>>> a4dd694 (style(reid): fix ruff errors in hacnn.py and mlfn.py\n\n- Add missing docstrings\n- Use super() instead of super(Class, self)\n- Avoid mutable default arguments\n- Fix long lines and other ruff issues) diff --git a/ethology/reid/backends/__init__.py b/ethology/reid/backends/__init__.py new file mode 100644 index 00000000..f032d396 --- /dev/null +++ b/ethology/reid/backends/__init__.py @@ -0,0 +1 @@ +# Backends for ReID inference diff --git a/ethology/reid/backends/base_backend.py b/ethology/reid/backends/base_backend.py new file mode 100644 index 00000000..0edb2826 --- /dev/null +++ b/ethology/reid/backends/base_backend.py @@ -0,0 +1,176 @@ +from abc import abstractmethod +from pathlib import Path + +import cv2 +import gdown +import numpy as np +import torch +from filelock import SoftFileLock + +from ethology.reid.core.registry import ReIDModelRegistry + +# from ethology.utils import logger as LOGGER # If needed, implement or set LOGGER +# from ethology.utils.checks import RequirementsChecker # If needed, implement or set RequirementsChecker + + +class BaseModelBackend: + def __init__(self, weights, device, half): + self.weights = weights[0] if isinstance(weights, list) else weights + if isinstance(self.weights, str): + self.weights = Path(self.weights) + # LOGGER.info(self.weights) + self.device = device + self.half = half + self.model = None + # Support both string and torch.device for device + if hasattr(self.device, "type"): + self.cuda = torch.cuda.is_available() and self.device.type != "cpu" + else: + self.cuda = torch.cuda.is_available() and self.device != "cpu" + + self.download_model(self.weights) + self.model_name = ReIDModelRegistry.get_model_name(self.weights) + + self.model = ReIDModelRegistry.build_model( + self.model_name, + self.weights, + num_classes=ReIDModelRegistry.get_nr_classes(self.weights), + pretrained=not (self.weights and self.weights.is_file()), + use_gpu=device, + ) + # self.checker = RequirementsChecker() + + self.load_model(self.weights) + + self.mean_array = torch.tensor( + [0.485, 0.456, 0.406], device=self.device + ).view(1, 3, 1, 1) + self.std_array = torch.tensor( + [0.229, 0.224, 0.225], device=self.device + ).view(1, 3, 1, 1) + if "clip" in self.model_name: + self.mean_array = torch.tensor( + [0.5, 0.5, 0.5], device=self.device + ).view(1, 3, 1, 1) + self.std_array = torch.tensor( + [0.5, 0.5, 0.5], device=self.device + ).view(1, 3, 1, 1) + + if "vehicleid" in self.weights.name or "veri" in self.weights.name: + input_shape = (256, 256) + elif "lmbn" in self.model_name: + input_shape = (384, 128) + elif "hacnn" in self.model_name: + input_shape = (160, 64) + else: + input_shape = (256, 128) + self.input_shape = input_shape + + def get_crops(self, xyxys, img): + h, w = img.shape[:2] + interpolation_method = cv2.INTER_LINEAR + num_crops = len(xyxys) + crops = torch.empty( + (num_crops, 3, *self.input_shape), + dtype=torch.half if self.half else torch.float, + device=self.device, + ) + for i, box in enumerate(xyxys): + x1, y1, x2, y2 = box.round().astype("int") + x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2) + crop = img[y1:y2, x1:x2] + crop = cv2.resize( + crop, + (self.input_shape[1], self.input_shape[0]), + interpolation=interpolation_method, + ) + crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) + crop = torch.from_numpy(crop).to( + self.device, dtype=torch.half if self.half else torch.float + ) + crops[i] = torch.permute(crop, (2, 0, 1)) + crops = crops / 255.0 + crops = (crops - self.mean_array) / self.std_array + return crops + + @torch.no_grad() + def get_features(self, xyxys, img): + if xyxys.size != 0: + crops = self.get_crops(xyxys, img) + crops = self.inference_preprocess(crops) + features = self.forward(crops) + features = self.inference_postprocess(features) + else: + features = np.array([]) + features = features / np.linalg.norm(features, axis=-1, keepdims=True) + return features + + def warmup(self, imgsz=[(256, 128, 3)]): + if self.device.type != "cpu": + im = np.random.randint(0, 255, *imgsz, dtype=np.uint8) + crops = self.get_crops( + xyxys=np.array([[0, 0, 64, 64], [0, 0, 128, 128]]), img=im + ) + crops = self.inference_preprocess(crops) + self.forward(crops) + + def to_numpy(self, x): + return x.cpu().numpy() if isinstance(x, torch.Tensor) else x + + def inference_preprocess(self, x): + if self.half: + if isinstance(x, torch.Tensor): + if x.dtype != torch.float16: + x = x.half() + elif isinstance(x, np.ndarray): + if x.dtype != np.float16: + x = x.astype(np.float16) + if hasattr(self, "nhwc") and self.nhwc: + if isinstance(x, torch.Tensor): + x = x.permute(0, 2, 3, 1) + elif isinstance(x, np.ndarray): + x = np.transpose(x, (0, 2, 3, 1)) + return x + + def inference_postprocess(self, features): + if isinstance(features, (list, tuple)): + return ( + self.to_numpy(features[0]) + if len(features) == 1 + else [self.to_numpy(x) for x in features] + ) + else: + return self.to_numpy(features) + + @abstractmethod + def forward(self, im_batch): + raise NotImplementedError( + "This method should be implemented by subclasses." + ) + + @abstractmethod + def load_model(self, w): + raise NotImplementedError( + "This method should be implemented by subclasses." + ) + + def download_model(self, w): + if isinstance(w, str): + w = Path(w) + if w.suffix != ".pt": + return + model_url = ReIDModelRegistry.get_model_url(w) + lock = SoftFileLock(str(w) + ".lock", timeout=300) + with lock: + if w.exists() or "openvino" in w.name: + # LOGGER.info(f"[PID {os.getpid()}] Found existing ReID weights at {w}; skipping download.") + return + if model_url: + # LOGGER.info(f"[PID {os.getpid()}] Downloading ReID weights from {model_url} → {w}") + gdown.download(model_url, str(w), quiet=False) + else: + # LOGGER.error( + # f"No URL associated with the chosen ReID weights ({w}).\n" + # f"Choose one of the following:" + # ) + ReIDModelRegistry.show_downloadable_models() diff --git a/ethology/reid/backends/onnx_backend.py b/ethology/reid/backends/onnx_backend.py new file mode 100644 index 00000000..41aefb7f --- /dev/null +++ b/ethology/reid/backends/onnx_backend.py @@ -0,0 +1,34 @@ +from ethology.reid.backends.base_backend import BaseModelBackend + + +class ONNXBackend(BaseModelBackend): + def __init__(self, weights, device, half): + super().__init__(weights, device, half) + self.nhwc = False + self.half = half + + def load_model(self, w): + # ONNXRuntime will attempt to use the first provider, and if it fails or is not + # available for some reason, it will fall back to the next provider in the list + if self.device.type == "mps": + # self.checker.check_packages(("onnxruntime-silicon==1.18.1",)) + providers = ["MPSExecutionProvider", "CPUExecutionProvider"] + elif self.device.type == "cuda": + # self.checker.check_packages(("onnxruntime-gpu==1.18.1",)) + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + # self.checker.check_packages(("onnxruntime==1.18.1",)) + providers = ["CPUExecutionProvider"] + import onnxruntime + + self.session = onnxruntime.InferenceSession( + str(w), providers=providers + ) + + def forward(self, im_batch): + im_batch = im_batch.cpu().numpy() + features = self.session.run( + [self.session.get_outputs()[0].name], + {self.session.get_inputs()[0].name: im_batch}, + )[0] + return features diff --git a/ethology/reid/backends/openvino_backend.py b/ethology/reid/backends/openvino_backend.py new file mode 100644 index 00000000..0c56a06e --- /dev/null +++ b/ethology/reid/backends/openvino_backend.py @@ -0,0 +1,49 @@ +from pathlib import Path + +from ethology.reid.backends.base_backend import BaseModelBackend + +# Note: LOGGER can be replaced with print or a local logger if needed + + +class OpenVinoBackend(BaseModelBackend): + def __init__(self, weights, device, half): + super().__init__(weights, device, half) + self.nhwc = False + self.half = half + + def load_model(self, w): + # self.checker.check_packages(("openvino>=2025.2.0",)) + + print(f"Loading {w} for OpenVINO inference...") + try: + # requires openvino-dev: https://pypi.org/project/openvino-dev/ + from openvino import Core, Layout + except ImportError: + print( + f"Running {self.__class__} with the specified OpenVINO weights\n{w.name}\n" + "requires openvino pip package to be installed!\n" + "$ pip install openvino>=2025.2.0\n" + ) + raise + ie = Core() + w = Path(w) + print(w) + if w.suffix == ".bin": + w = w.with_suffix(".xml") + + if not w.is_file(): # if not *.xml + w = next( + Path(w).glob("*.xml") + ) # get *.xml file from *_openvino_model dir + network = ie.read_model(model=w, weights=Path(w).with_suffix(".bin")) + if network.get_parameters()[0].get_layout().empty: + network.get_parameters()[0].set_layout(Layout("NCWH")) + self.executable_network = ie.compile_model( + network, device_name="CPU" + ) # device_name="MYRIAD" for Intel NCS2 + self.output_layer = next(iter(self.executable_network.outputs)) + + def forward(self, im_batch): + im_batch = im_batch.cpu().numpy() # FP32 + features = self.executable_network([im_batch])[self.output_layer] + return features diff --git a/ethology/reid/backends/pytorch_backend.py b/ethology/reid/backends/pytorch_backend.py new file mode 100644 index 00000000..2e859cc8 --- /dev/null +++ b/ethology/reid/backends/pytorch_backend.py @@ -0,0 +1,20 @@ +from ethology.reid.backends.base_backend import BaseModelBackend +from ethology.reid.core.registry import ReIDModelRegistry + + +class PyTorchBackend(BaseModelBackend): + def __init__(self, weights, device, half): + super().__init__(weights, device, half) + self.nhwc = False + self.half = half + + def load_model(self, w): + # Load a PyTorch model + if w and w.is_file(): + ReIDModelRegistry.load_pretrained_weights(self.model, w) + self.model.to(self.device).eval() + self.model.half() if self.half else self.model.float() + + def forward(self, im_batch): + features = self.model(im_batch) + return features diff --git a/ethology/reid/backends/tensorrt_backend.py b/ethology/reid/backends/tensorrt_backend.py new file mode 100644 index 00000000..4f6e95b0 --- /dev/null +++ b/ethology/reid/backends/tensorrt_backend.py @@ -0,0 +1,400 @@ +# Note: LOGGER can be replaced with print or a local logger if needed +import os +from collections import OrderedDict, namedtuple + +import numpy as np +import torch + +from ethology.reid.backends.base_backend import BaseModelBackend + +Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + + +class TensorRTBackend(BaseModelBackend): + def __init__(self, engine_path, device=None): + import hashlib + + import requests + + self.device = device or ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + self.fp16 = False + self.model_ = None + self.context = None + self.bindings = None + self.binding_addrs = None + self.is_trt10 = False + # Download engine if engine_path is a URL + if engine_path.startswith("http://") or engine_path.startswith( + "https://" + ): + # Use a hash of the URL for filename + engine_hash = hashlib.md5(engine_path.encode()).hexdigest() + filename = f"trt_engine_{engine_hash}.engine" + cache_dir = os.path.expanduser("~/.cache/ethology/tensorrt/") + os.makedirs(cache_dir, exist_ok=True) + cached_file = os.path.join(cache_dir, filename) + if not os.path.exists(cached_file): + print( + f"[TensorRT] Downloading engine from {engine_path} to {cached_file}" + ) + with requests.get(engine_path, stream=True) as r: + r.raise_for_status() + with open(cached_file, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + else: + print(f"[TensorRT] Using cached engine at {cached_file}") + self.engine_path = cached_file + else: + self.engine_path = engine_path + self.load_model(self.engine_path) + + def load_model(self, w): + print(f"Loading {w} for TensorRT inference...") + try: + import pycuda.autoinit # noqa: F401 + import pycuda.driver as cuda + import tensorrt as trt + except ImportError: + raise ImportError( + "TensorRT and pycuda are required for TensorRTBackend. Please install them and ensure libnvinfer.so.8 is available in LD_LIBRARY_PATH." + ) + + if self.device.type == "cpu": + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + else: + raise ValueError( + "CUDA device not available for TensorRT inference." + ) + + Binding = namedtuple( + "Binding", ("name", "dtype", "shape", "data", "ptr") + ) + logger = trt.Logger(trt.Logger.INFO) + + # Deserialize the engine + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + self.model_ = runtime.deserialize_cuda_engine(f.read()) + + # Execution context + self.context = self.model_.create_execution_context() + self.bindings = OrderedDict() + + self.is_trt10 = not hasattr(self.model_, "num_bindings") + num = ( + range(self.model_.num_io_tensors) + if self.is_trt10 + else range(self.model_.num_bindings) + ) + + # Parse bindings + for index in num: + if self.is_trt10: + name = self.model_.get_tensor_name(index) + dtype = trt.nptype(self.model_.get_tensor_dtype(name)) + is_input = ( + self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT + ) + if is_input and -1 in tuple( + self.model_.get_tensor_shape(name) + ): + self.context.set_input_shape( + name, + tuple( + self.model_.get_tensor_profile_shape(name, 0)[1] + ), + ) + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_tensor_shape(name)) + + else: + name = self.model_.get_binding_name(index) + dtype = trt.nptype(self.model_.get_binding_dtype(index)) + is_input = self.model_.binding_is_input(index) + + # Handle dynamic shapes + if is_input and -1 in self.model_.get_binding_shape(index): + profile_index = 0 + min_shape, opt_shape, max_shape = ( + self.model_.get_profile_shape(profile_index, index) + ) + self.context.set_binding_shape(index, opt_shape) + + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_binding_shape(index)) + data = torch.from_numpy(np.empty(shape, dtype=dtype)).to( + self.device + ) + self.bindings[name] = Binding( + name, dtype, shape, data, int(data.data_ptr()) + ) + + self.binding_addrs = OrderedDict( + (n, d.ptr) for n, d in self.bindings.items() + ) + + # Execution context + self.context = self.model_.create_execution_context() + self.bindings = OrderedDict() + + self.is_trt10 = not hasattr(self.model_, "num_bindings") + num = ( + range(self.model_.num_io_tensors) + if self.is_trt10 + else range(self.model_.num_bindings) + ) + + # Parse bindings + for index in num: + if self.is_trt10: + name = self.model_.get_tensor_name(index) + dtype = trt.nptype(self.model_.get_tensor_dtype(name)) + is_input = ( + self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT + ) + if is_input and -1 in tuple( + self.model_.get_tensor_shape(name) + ): + self.context.set_input_shape( + name, + tuple( + self.model_.get_tensor_profile_shape(name, 0)[1] + ), + ) + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_tensor_shape(name)) + + else: + name = self.model_.get_binding_name(index) + dtype = trt.nptype(self.model_.get_binding_dtype(index)) + is_input = self.model_.binding_is_input(index) + + # Handle dynamic shapes + if is_input and -1 in self.model_.get_binding_shape(index): + profile_index = 0 + min_shape, opt_shape, max_shape = ( + self.model_.get_profile_shape(profile_index, index) + ) + self.context.set_binding_shape(index, opt_shape) + + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_binding_shape(index)) + data = torch.from_numpy(np.empty(shape, dtype=dtype)).to( + self.device + ) + self.bindings[name] = Binding( + name, dtype, shape, data, int(data.data_ptr()) + ) + + self.binding_addrs = OrderedDict( + (n, d.ptr) for n, d in self.bindings.items() + ) + + def forward(self, im_batch): + temp_im_batch = im_batch.clone() + batch_array = [] + inp_batch = im_batch.shape[0] + out_batch = self.bindings["output"].shape[0] + resultant_features = [] + + # Divide batch to sub batches + while inp_batch > out_batch: + batch_array.append(temp_im_batch[:out_batch]) + temp_im_batch = temp_im_batch[out_batch:] + inp_batch = temp_im_batch.shape[0] + if temp_im_batch.shape[0] > 0: + batch_array.append(temp_im_batch) + + for temp_batch in batch_array: + # Adjust for dynamic shapes + if temp_batch.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", temp_batch.shape) + self.bindings["images"] = self.bindings["images"]._replace( + shape=temp_batch.shape + ) + self.bindings["output"].data.resize_( + tuple(self.context.get_tensor_shape("output")) + ) + else: + i_in = self.model_.get_binding_index("images") + i_out = self.model_.get_binding_index("output") + self.context.set_binding_shape(i_in, temp_batch.shape) + self.bindings["images"] = self.bindings["images"]._replace( + shape=temp_batch.shape + ) + output_shape = tuple(self.context.get_binding_shape(i_out)) + self.bindings["output"].data.resize_(output_shape) + + s = self.bindings["images"].shape + assert temp_batch.shape == s, ( + f"Input size {temp_batch.shape} does not match model size {s}" + ) + + self.binding_addrs["images"] = int(temp_batch.data_ptr()) + + # Execute inference + self.context.execute_v2(list(self.binding_addrs.values())) + features = self.bindings["output"].data + resultant_features.append(features.clone()) + + if len(resultant_features) == 1: + return resultant_features[0] + else: + rslt_features = torch.cat(resultant_features, dim=0) + rslt_features = rslt_features[: im_batch.shape[0]] + return rslt_features + + def load_model(self, w): + print(f"Loading {w} for TensorRT inference...") + # self.checker.check_packages(("nvidia-tensorrt",)) + try: + import tensorrt as trt # TensorRT library + except ImportError: + raise ImportError("Please install tensorrt to use this backend.") + + if self.device.type == "cpu": + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + else: + raise ValueError( + "CUDA device not available for TensorRT inference." + ) + + Binding = namedtuple( + "Binding", ("name", "dtype", "shape", "data", "ptr") + ) + logger = trt.Logger(trt.Logger.INFO) + + # Deserialize the engine + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + self.model_ = runtime.deserialize_cuda_engine(f.read()) + + # Execution context + self.context = self.model_.create_execution_context() + self.bindings = OrderedDict() + + self.is_trt10 = not hasattr(self.model_, "num_bindings") + num = ( + range(self.model_.num_io_tensors) + if self.is_trt10 + else range(self.model_.num_bindings) + ) + + # Parse bindings + for index in num: + if self.is_trt10: + name = self.model_.get_tensor_name(index) + dtype = trt.nptype(self.model_.get_tensor_dtype(name)) + is_input = ( + self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT + ) + if is_input and -1 in tuple( + self.model_.get_tensor_shape(name) + ): + self.context.set_input_shape( + name, + tuple( + self.model_.get_tensor_profile_shape(name, 0)[1] + ), + ) + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_tensor_shape(name)) + + else: + name = self.model_.get_binding_name(index) + dtype = trt.nptype(self.model_.get_binding_dtype(index)) + is_input = self.model_.binding_is_input(index) + + # Handle dynamic shapes + if is_input and -1 in self.model_.get_binding_shape(index): + profile_index = 0 + min_shape, opt_shape, max_shape = ( + self.model_.get_profile_shape(profile_index, index) + ) + self.context.set_binding_shape(index, opt_shape) + + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_binding_shape(index)) + data = torch.from_numpy(np.empty(shape, dtype=dtype)).to( + self.device + ) + self.bindings[name] = Binding( + name, dtype, shape, data, int(data.data_ptr()) + ) + + self.binding_addrs = OrderedDict( + (n, d.ptr) for n, d in self.bindings.items() + ) + + def forward(self, im_batch): + temp_im_batch = im_batch.clone() + batch_array = [] + inp_batch = im_batch.shape[0] + out_batch = self.bindings["output"].shape[0] + resultant_features = [] + + # Divide batch to sub batches + while inp_batch > out_batch: + batch_array.append(temp_im_batch[:out_batch]) + temp_im_batch = temp_im_batch[out_batch:] + inp_batch = temp_im_batch.shape[0] + if temp_im_batch.shape[0] > 0: + batch_array.append(temp_im_batch) + + for temp_batch in batch_array: + # Adjust for dynamic shapes + if temp_batch.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", temp_batch.shape) + self.bindings["images"] = self.bindings["images"]._replace( + shape=temp_batch.shape + ) + self.bindings["output"].data.resize_( + tuple(self.context.get_tensor_shape("output")) + ) + else: + i_in = self.model_.get_binding_index("images") + i_out = self.model_.get_binding_index("output") + self.context.set_binding_shape(i_in, temp_batch.shape) + self.bindings["images"] = self.bindings["images"]._replace( + shape=temp_batch.shape + ) + output_shape = tuple(self.context.get_binding_shape(i_out)) + self.bindings["output"].data.resize_(output_shape) + + s = self.bindings["images"].shape + assert temp_batch.shape == s, ( + f"Input size {temp_batch.shape} does not match model size {s}" + ) + + self.binding_addrs["images"] = int(temp_batch.data_ptr()) + + # Execute inference + self.context.execute_v2(list(self.binding_addrs.values())) + features = self.bindings["output"].data + resultant_features.append(features.clone()) + + if len(resultant_features) == 1: + return resultant_features[0] + else: + rslt_features = torch.cat(resultant_features, dim=0) + rslt_features = rslt_features[: im_batch.shape[0]] + return rslt_features diff --git a/ethology/reid/backends/tflite_backend.py b/ethology/reid/backends/tflite_backend.py new file mode 100644 index 00000000..eb10d4e8 --- /dev/null +++ b/ethology/reid/backends/tflite_backend.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import numpy as np +import torch + +from ethology.reid.backends.base_backend import BaseModelBackend + +# Note: LOGGER can be replaced with print or a local logger if needed + + +class TFLiteBackend(BaseModelBackend): + """A class to handle TensorFlow Lite model inference with dynamic batch size support.""" + + def __init__(self, weights: Path, device: str, half: bool): + super().__init__(weights, device, half) + self.nhwc = True + self.half = False + + def load_model(self, w): + # self.checker.check_packages(("tensorflow",)) + print(f"Loading {str(w)} for TensorFlow Lite inference...") + import tensorflow as tf + + self.interpreter = tf.lite.Interpreter(model_path=str(w)) + self.interpreter.allocate_tensors() + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + self.current_allocated_batch_size = self.input_details[0]["shape"][0] + + def forward(self, im_batch: torch.Tensor) -> np.ndarray: + im_batch = im_batch.cpu().numpy() + batch_size = im_batch.shape[0] + if batch_size != self.current_allocated_batch_size: + self.interpreter.resize_tensor_input( + self.input_details[0]["index"], [batch_size, 256, 128, 3] + ) + self.interpreter.allocate_tensors() + self.current_allocated_batch_size = batch_size + self.interpreter.set_tensor(self.input_details[0]["index"], im_batch) + self.interpreter.invoke() + features = self.interpreter.get_tensor(self.output_details[0]["index"]) + return features diff --git a/ethology/reid/backends/torchscript_backend.py b/ethology/reid/backends/torchscript_backend.py new file mode 100644 index 00000000..1142fcc4 --- /dev/null +++ b/ethology/reid/backends/torchscript_backend.py @@ -0,0 +1,21 @@ +import torch + +from ethology.reid.backends.base_backend import BaseModelBackend + +# Note: LOGGER can be replaced with print or a local logger if needed + + +class TorchscriptBackend(BaseModelBackend): + def __init__(self, weights, device, half): + super().__init__(weights, device, half) + self.nhwc = False + self.half = half + + def load_model(self, w): + print(f"Loading {w} for TorchScript inference...") + self.model = torch.jit.load(w) + self.model.half() if self.half else self.model.float() + + def forward(self, im_batch): + features = self.model(im_batch) + return features diff --git a/ethology/reid/core/__init__.py b/ethology/reid/core/__init__.py new file mode 100644 index 00000000..9dab06ec --- /dev/null +++ b/ethology/reid/core/__init__.py @@ -0,0 +1 @@ +# Core logic for ReID diff --git a/ethology/reid/core/auto_backend.py b/ethology/reid/core/auto_backend.py new file mode 100644 index 00000000..22f2c4e2 --- /dev/null +++ b/ethology/reid/core/auto_backend.py @@ -0,0 +1,89 @@ +from pathlib import Path + +import torch + +from ethology.reid.backends.onnx_backend import ONNXBackend +from ethology.reid.backends.openvino_backend import OpenVinoBackend +from ethology.reid.backends.pytorch_backend import PyTorchBackend + +try: + from ethology.reid.backends.tensorrt_backend import TensorRTBackend +except ImportError: + + class TensorRTBackend: + def __init__(self, *args, **kwargs): + raise ImportError( + "TensorRT and pycuda are required for TensorRTBackend. Please install them and ensure libcudnn.so.8 is available in LD_LIBRARY_PATH." + ) + + +from ethology.reid.backends.tflite_backend import TFLiteBackend +from ethology.reid.backends.torchscript_backend import TorchscriptBackend + +# from ethology.reid.core import export_formats # If needed, implement or copy export_formats +# from ethology.utils import WEIGHTS # If needed, implement or set WEIGHTS +# from ethology.utils import logger as LOGGER # If needed, implement or set LOGGER +# from ethology.utils.torch_utils import select_device # If needed, implement or set select_device + + +class ReidAutoBackend: + def __init__( + self, + weights: Path, + device: torch.device = torch.device("cpu"), + half: bool = False, + ): + super().__init__() + w = weights[0] if isinstance(weights, list) else weights + ( + self.pt, + self.pth, + self.jit, + self.onnx, + self.xml, + self.engine, + self.tflite, + ) = self.model_type(w) + self.weights = weights + self.device = device # For simplicity, skip select_device for now + self.half = half + self.model = self.get_backend() + + def get_backend(self): + backend_map = { + self.pt or self.pth: PyTorchBackend, + self.jit: TorchscriptBackend, + self.onnx: ONNXBackend, + self.engine: TensorRTBackend, + self.xml: OpenVinoBackend, + self.tflite: TFLiteBackend, + } + for condition, backend_class in backend_map.items(): + if condition: + return backend_class(self.weights, self.device, self.half) + raise RuntimeError("This model framework is not supported yet!") + + def check_suffix( + self, + file: Path = "osnet_x0_25_msmt17.pt", + suffix: str | tuple[str, ...] = (".pt",), + msg: str = "", + ): + suffix = [suffix] if isinstance(suffix, str) else list(suffix) + files = [file] if isinstance(file, (str, Path)) else list(file) + for f in files: + file_suffix = Path(f).suffix.lower() + if file_suffix and file_suffix not in suffix: + print( + f"File {f} does not have an acceptable suffix. Expected: {suffix}" + ) + + def model_type(self, p: Path) -> tuple[bool, ...]: + # For demo, just check for .pt + sf = [".pt", ".pth", ".jit", ".onnx", ".xml", ".engine", ".tflite"] + self.check_suffix(p, sf) + types = [str(Path(p)).endswith(s) for s in sf] + # OpenVINO explicit check + if Path(p).suffix in [".xml", ".bin"]: + types[3] = True + return tuple(types) diff --git a/ethology/reid/core/config.py b/ethology/reid/core/config.py new file mode 100644 index 00000000..dc17cc14 --- /dev/null +++ b/ethology/reid/core/config.py @@ -0,0 +1,16 @@ +MODEL_TYPES = [ + "resnet50", + "resnet101", + "mlfn", + "hacnn", + "mobilenetv2_x1_0", + "mobilenetv2_x1_4", + "osnet_x1_0", + "osnet_x0_75", + "osnet_x0_5", + "osnet_x0_25", + "osnet_ibn_x1_0", + "osnet_ain_x1_0", + "lmbn_n", + "clip", +] diff --git a/ethology/reid/core/factory.py b/ethology/reid/core/factory.py new file mode 100644 index 00000000..27406383 --- /dev/null +++ b/ethology/reid/core/factory.py @@ -0,0 +1,44 @@ +# Import model constructors from ethology's local backbones +from ethology.reid.backbones.hacnn import HACNN +from ethology.reid.backbones.mlfn import mlfn +from ethology.reid.backbones.mobilenetv2 import ( + mobilenetv2_x1_0, + mobilenetv2_x1_4, +) +from ethology.reid.backbones.osnet import ( + osnet_ibn_x1_0, + osnet_x0_5, + osnet_x0_25, + osnet_x0_75, + osnet_x1_0, +) +from ethology.reid.backbones.osnet_ain import ( + osnet_ain_x0_5, + osnet_ain_x0_25, + osnet_ain_x0_75, + osnet_ain_x1_0, +) +from ethology.reid.backbones.resnet import resnet50, resnet101 + +# from ethology.reid.backbones.lmbn.lmbn_n import LMBN_n # If present +# from ethology.reid.backbones.clip.make_model import make_model # If present + +MODEL_FACTORY = { + "resnet50": resnet50, + "resnet101": resnet101, + "mobilenetv2_x1_0": mobilenetv2_x1_0, + "mobilenetv2_x1_4": mobilenetv2_x1_4, + "hacnn": HACNN, + "mlfn": mlfn, + "osnet_x1_0": osnet_x1_0, + "osnet_x0_75": osnet_x0_75, + "osnet_x0_5": osnet_x0_5, + "osnet_x0_25": osnet_x0_25, + "osnet_ibn_x1_0": osnet_ibn_x1_0, + "osnet_ain_x1_0": osnet_ain_x1_0, + "osnet_ain_x0_75": osnet_ain_x0_75, + "osnet_ain_x0_5": osnet_ain_x0_5, + "osnet_ain_x0_25": osnet_ain_x0_25, + # "lmbn_n": LMBN_n, # Uncomment if implemented + # "clip": make_model, # Uncomment if implemented +} diff --git a/ethology/reid/core/handler.py b/ethology/reid/core/handler.py new file mode 100644 index 00000000..ba521ab2 --- /dev/null +++ b/ethology/reid/core/handler.py @@ -0,0 +1,36 @@ +# Main handler for ReID in ethology + +# Thin wrapper to use BoxMOT ReID models in ethology +from pathlib import Path + +import numpy as np + +# Import ethology's local ReID handler +from ethology.reid.core.reid_handler import ReID as EthologyReID + + +class ReIDHandler: + """Ethology ReID handler using local models and backends.""" + + def __init__(self, weights: str | Path, device="cpu", half=False): + self.model = EthologyReID(weights=weights, device=device, half=half) + + def extract_features( + self, frame: np.ndarray, dets: np.ndarray + ) -> np.ndarray: + """Extract feature embeddings for detections in a frame. + + Parameters + ---------- + frame : np.ndarray + (H, W, C) BGR image. + dets : np.ndarray + (N, 6) array of detections (x1, y1, x2, y2, conf, cls). + + Returns + ------- + np.ndarray + (N, D) feature embeddings. + + """ + return self.model(frame, dets) diff --git a/ethology/reid/core/registry.py b/ethology/reid/core/registry.py new file mode 100644 index 00000000..4b9c27fd --- /dev/null +++ b/ethology/reid/core/registry.py @@ -0,0 +1,88 @@ +from collections import OrderedDict + +import torch + +from ethology.reid.core.config import ( + MODEL_TYPES, # , NR_CLASSES_DICT, TRAINED_URLS +) +from ethology.reid.core.factory import MODEL_FACTORY + +# from ethology.utils import logger as LOGGER # If needed, implement or set LOGGER + + +class ReIDModelRegistry: + """Encapsulates model registration and related utilities.""" + + @staticmethod + def show_downloadable_models(): + # LOGGER.info("Available .pt ReID models for automatic download") + # LOGGER.info(list(TRAINED_URLS.keys())) + pass + + @staticmethod + def get_model_name(model): + for name in MODEL_TYPES: + if name in model.name: + return name + return None + + @staticmethod + def get_model_url(model): + # return TRAINED_URLS.get(model.name, None) + return None + + @staticmethod + def load_pretrained_weights(model, weight_path): + device = "cpu" if not torch.cuda.is_available() else None + checkpoint = torch.load( + weight_path, + map_location=torch.device("cpu") if device == "cpu" else None, + weights_only=False, + encoding="latin1", + ) + state_dict = checkpoint.get("state_dict", checkpoint) + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + for k, v in state_dict.items(): + key = k[7:] if k.startswith("module.") else k + if key in model_dict and model_dict[key].size() == v.size(): + new_state_dict[key] = v + matched_layers.append(key) + else: + discarded_layers.append(key) + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + + @staticmethod + def show_available_models(): + # LOGGER.info("Available models:") + # LOGGER.info(list(MODEL_FACTORY.keys())) + pass + + @staticmethod + def get_nr_classes(weights): + # dataset_key = weights.name.split("_")[1] + # return NR_CLASSES_DICT.get(dataset_key, 1) + return 1 + + @staticmethod + def build_model( + name, + weights, + num_classes, + loss="softmax", + pretrained=True, + use_gpu=True, + ): + if name not in MODEL_FACTORY: + available = list(MODEL_FACTORY.keys()) + raise KeyError( + f"Unknown model '{name}'. Must be one of {available}" + ) + return MODEL_FACTORY[name]( + num_classes=num_classes, + loss=loss, + pretrained=pretrained, + use_gpu=use_gpu, + ) diff --git a/ethology/reid/core/reid_handler.py b/ethology/reid/core/reid_handler.py new file mode 100644 index 00000000..62d42209 --- /dev/null +++ b/ethology/reid/core/reid_handler.py @@ -0,0 +1,33 @@ +from pathlib import Path + +import numpy as np + +from ethology.reid.core.auto_backend import ReidAutoBackend + + +class ReID: + def __init__(self, weights: str | Path, device="cpu", half=False): + self.weights = Path(weights) + self.device = device + self.half = half + self.backend = ReidAutoBackend( + weights=self.weights, device=device, half=half + ) + self.model = self.backend.model + + def __call__(self, frame: np.ndarray, dets: np.ndarray) -> np.ndarray: + """Extract features for detections in a frame. + + Args: + frame: (H, W, C) BGR image + dets: (N, 6) detections (x1, y1, x2, y2, conf, cls) or similar. + + Returns: + embs: (N, D) embeddings. + + """ + if dets.shape[0] == 0: + return np.empty((0, 0)) + xyxy = dets[:, :4] + embs = self.model.get_features(xyxy, frame) + return embs diff --git a/examples/reid_mot_example.py b/examples/reid_mot_example.py new file mode 100644 index 00000000..304fd9b1 --- /dev/null +++ b/examples/reid_mot_example.py @@ -0,0 +1,22 @@ +"""Example: Using the new ReID trajectory utility with MOT results + +This script demonstrates how to use the ethology reid trajectory handler with a sample MOT output. +""" + +from ethology.reid.core.reid_handler import ReIDTrajectoryHandler + +# Example: Dummy MOT results (replace with your actual MOT output) +mot_results = [ + {"id": 1, "trajectory": [(0, 0), (1, 1), (2, 2)]}, + {"id": 2, "trajectory": [(5, 5), (6, 6), (7, 7)]}, +] + +# Initialize the handler (adjust parameters as needed) +reid_handler = ReIDTrajectoryHandler(model_name="osnet", device="cpu") + +# Run re-identification on the MOT results +reid_results = reid_handler.reidentify(mot_results) + +print("ReID Results:") +for item in reid_results: + print(item) diff --git a/notebook_cotracker_offline.py b/notebook_cotracker_offline.py new file mode 100644 index 00000000..ef424920 --- /dev/null +++ b/notebook_cotracker_offline.py @@ -0,0 +1,304 @@ +"""Offline tracking with CoTracker3.""" + +# %% +# Imports +import os +from datetime import datetime +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from cotracker.utils.visualizer import read_video_from_path +from movement.io import load_bboxes, load_poses, save_poses + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + +DEFAULT_DEVICE = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) + +# %matplotlib widget + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Data paths +video_path = ( + "/home/sminano/swc/project_ethology/tap_models_crabs/" + "input/04.09.2023-04-Right_RE_test.mp4" +) + +ground_truth_data = Path( + "/home/sminano/swc/project_ethology/tap_models_crabs/input/04.09.2023-04-Right_RE_test_corrected_ST_SM_20241029_113207.csv" +) + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Parameters + +# query points +step_between_query_frames: int = 1000 +individuals_gt_ids: list[int] = [] + +# downsample video +scale_factor: float = 0.25 + +# clip video +chunk_start: int = 0 +chunk_width = 100 + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Select query points + +ds_gt = load_bboxes.from_file( + file_path=ground_truth_data, + source_software="VIA-tracks", + use_frame_numbers_from_file=False, +) + + +# ------------------ +# Select individuals to use as query points +if len(individuals_gt_ids) == 0: + ds_gt_one = ds_gt +else: + ds_gt_one = ds_gt.isel(individuals=[i - 1 for i in individuals_gt_ids]) + +print(ds_gt_one) + +# Select frames +list_frames = list(range(ds_gt_one.sizes["time"])) +frames_to_select = np.array(list_frames)[ + chunk_start : chunk_start + chunk_width : step_between_query_frames +] # every N frame +print(frames_to_select) +# -------------------- + +# Prepare query points array +# it has frame as first column +queries_array = np.vstack( + [ + np.hstack( + [ + f + * np.ones((ds_gt_one.sizes["individuals"], 1)), # frame column + ds_gt_one.position.sel(time=f).values.T, # x, y columns + ] + ) + for f in range(ds_gt_one.sizes["time"]) + ] +) + +# Remove rows with nans in position +queries_array = queries_array[~np.any(np.isnan(queries_array), axis=1), :] + +# Filter selected query points +queries_sel = queries_array[ + [col in frames_to_select for col in queries_array[:, 0]], : +] + +print(queries_sel.shape) + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Downsample queries by the same scale factor as the video +queries_downsampled = queries_sel * scale_factor +queries_downsampled[:, 0] = queries_sel[:, 0] +print(queries_downsampled.shape) # torch.Size([1, 614, 2]) +print(queries_downsampled) + +# convert to torch tensor and place on device +queries_downsampled_tensor: torch.Tensor = ( + torch.from_numpy(queries_downsampled).to(torch.float).to(DEFAULT_DEVICE) +) + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Read video +# TODO: is it faster with sleap_io? yes! but then converting +# to torch is very slow +# %time video_full = read_video_from_path(video_path) # Wall time: 13.4 s +# %time video_full = sio.load_video(video_path) # Wall time: 27.4 ms +# %time video_full = np.array(sio.load_video(video_path)) + +video_full = read_video_from_path(video_path) +print(type(video_full)) +print(video_full.shape) # (614, 2160, 4096, 3) + +# as torch tensor +video_full = torch.from_numpy(video_full).permute(0, 3, 1, 2)[None] + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Downsample video +video_downsampled = F.interpolate( + video_full[0], scale_factor=scale_factor, mode="bilinear" +)[None] + +print(video_downsampled.shape) # torch.Size([1, 614, 3, 540, 1024]) +print(video_downsampled.device) + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Select first part of the video only (to fit in GPU) +# video = video[:, : video.shape[1] // 8] +video_downsampled_chunk = video_downsampled[ + :, chunk_start : chunk_start + chunk_width, :, :, : +] # 75 frames +print(video_downsampled_chunk.shape) # torch.Size([1, 75, 3, 540, 1024]) + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Convert to float and place video on device +# Why do we need .float conversion? +# chatgpt: Mathematical operations like convolutions, normalizations, +# or matrix mults expect float32 or float16 + + +device = "cuda" +# video = video.float().to(device) +# video = video.half().to(device) # Use half precision for memory efficiency +video_downsampled_chunk = video_downsampled_chunk.to(torch.float).to( + device +) # torch.float16 + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Visualize query points over frames + +# Create a list of frame numbers corresponding to each point +frame_numbers = queries_downsampled_tensor[:, 0].unique().tolist() + +for frame_number in frame_numbers: + if frame_number in list(range(video_downsampled_chunk.shape[1])): + # get the query points for the current frame + queries_one_frame = queries_downsampled_tensor[ + queries_downsampled_tensor[:, 0] == frame_number + ] + + fig, ax = plt.subplots(1, 1) + # plot frame + ax.imshow( + video_downsampled_chunk.permute(0, 1, -2, -1, -3)[ + 0, frame_number, :, :, : + ] + .cpu() + .numpy() + .astype(np.int32) + ) # B T C H W -> H W C + # plot query points + ax.scatter( + x=queries_one_frame[:, 1].cpu(), + y=queries_one_frame[:, 2].cpu(), + s=5, + c="red", + ) + + ax.set_title(f"Frame {frame_number}") + ax.set_xlim(0, video_downsampled_chunk.shape[4]) + ax.set_ylim(0, video_downsampled_chunk.shape[3]) + ax.invert_yaxis() + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Get Offline CoTracker model +model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline") + +# Use the model in half precision and move it to the GPU +# Note: this is for memory usage +model = model.to(device) # .half().to(device) # .to(torch.float16).to(device) + +print(model.model.window_len) + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# all_half = all(p.dtype == torch.float16 for p in model.parameters()) +# print("All parameters are float16:", all_half) + +# for name, param in model.named_parameters(): +# # print(f"{name}: {param.dtype}") +# if param.dtype == torch.float32: +# param.data = param.data.to(torch.float16) +# print("PATATA") + +# for name, buffer in model.named_buffers(): +# print(f"{name}: {buffer.dtype}") + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Run CoTracker +pred_tracks, pred_visibility = model( + video_downsampled_chunk, + queries=queries_downsampled_tensor[None], + backward_tracking=True, +) # B T N 2, B T N 1 + + +# from torch.cuda.amp import autocast +# model.eval() +# with torch.no_grad(), torch.autocast(device_type="cuda"): +# pred_tracks, pred_visibility = model( +# video, queries=queries[None], #backward_tracking=True +# ) # B T N 2, B T N 1 + +# %% +# TODO: Can I upsample the results to the original video res? +print( + pred_tracks.shape +) # (1, 307, 2, 2) --> Batch, Time, N of points, 2 (x,y) +print(pred_visibility.shape) + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Upsample results to the original video resolution + +pred_tracks_upsampled = pred_tracks * 1 / scale_factor + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Save as a movement dataset +# (n_frames, n_space, n_keypoints, n_individuals) + +# assuming 1 query is 1 individual +position_array = ( + pred_tracks_upsampled.permute(1, -1, 0, -2).cpu().numpy() +) # (T, 2, 1, Nqueries) +visibility_array = pred_visibility.cpu().numpy()[0] # (T, Nqueries) + +# set position to nan if non visible +# (improve this) +for i in range(visibility_array.shape[1]): + position_array[~visibility_array[:, i], :, :, i] = np.nan + +# ----------------------------- +# # get each track from its query point +# position_array_fix = np.vstack( +# [ +# position_array[ +# frames_to_select[i]:(frames_to_select[i+1] +# if i OOM + + +# # %% +# vis = Visualizer( +# save_dir="./output", +# linewidth=1, +# mode="cool", +# tracks_leave_trace=-1, +# fps=10, +# ) + +# vis.visualize( +# video, +# pred_tracks, # .to('cpu'), +# pred_visibility, +# query_frame=grid_query_frame, +# filename=f"queries_{timestamp}", +# ) + +# %% diff --git a/tests/test_unit/test_reid_handler.py b/tests/test_unit/test_reid_handler.py new file mode 100644 index 00000000..bc8a199c --- /dev/null +++ b/tests/test_unit/test_reid_handler.py @@ -0,0 +1,16 @@ +import numpy as np + +from ethology.reid.core.handler import ReIDHandler + + +def test_extract_features_shape(): + handler = ReIDHandler(weights="osnet_x0_25_imagenet.pth") + frame = np.random.randint(0, 255, (128, 64, 3), dtype=np.uint8) + dets = np.array( + [ + [10, 10, 50, 100, 0.9, 1], + [60, 20, 100, 110, 0.8, 2], + ] + ) + feats = handler.extract_features(frame, dets) + assert feats.shape[0] == dets.shape[0]