Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ethology/reid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# ReID module for ethology
1 change: 1 addition & 0 deletions ethology/reid/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Backbones for ReID models
337 changes: 337 additions & 0 deletions ethology/reid/backbones/hacnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
"""HACNN backbone for person re-identification."""

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ["HACNN"]


class ConvBlock(nn.Module):
def __init__(self, in_c, out_c, k, s=1, p=0):
"""Convolutional block with batch norm and ReLU."""
super().__init__()
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
self.bn = nn.BatchNorm2d(out_c)

def forward(self, x):
return F.relu(self.bn(self.conv(x)))


class InceptionA(nn.Module):
def __init__(self, in_channels, out_channels):
"""InceptionA block."""
super().__init__()
mid_channels = out_channels // 4
self.stream1 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
)
self.stream2 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
)
self.stream3 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
)
self.stream4 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1),
ConvBlock(in_channels, mid_channels, 1),
)

def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
s4 = self.stream4(x)
y = torch.cat([s1, s2, s3, s4], dim=1)
return y


class InceptionB(nn.Module):
def __init__(self, in_channels, out_channels):
"""InceptionB block."""
super().__init__()
mid_channels = out_channels // 4
self.stream1 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
)
self.stream2 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
)
self.stream3 = nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
ConvBlock(in_channels, mid_channels * 2, 1),
)

def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
y = torch.cat([s1, s2, s3], dim=1)
return y


class SpatialAttn(nn.Module):
def __init__(self):
"""Spatial attention block."""
super().__init__()
self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
self.conv2 = ConvBlock(1, 1, 1)

def forward(self, x):
x = x.mean(1, keepdim=True)
x = self.conv1(x)
x = F.interpolate(
x,
(x.size(2) * 2, x.size(3) * 2),
mode="bilinear",
align_corners=True,
)
x = self.conv2(x)
return x


class ChannelAttn(nn.Module):
def __init__(self, in_channels, reduction_rate=16):
"""Channel attention block."""
super().__init__()
assert in_channels % reduction_rate == 0
self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)

def forward(self, x):
x = F.avg_pool2d(x, x.size()[2:])
x = self.conv1(x)
x = self.conv2(x)
return x


class SoftAttn(nn.Module):
def __init__(self, in_channels):
"""Soft attention block."""
super().__init__()
self.spatial_attn = SpatialAttn()
self.channel_attn = ChannelAttn(in_channels)
self.conv = ConvBlock(in_channels, in_channels, 1)

def forward(self, x):
y_spatial = self.spatial_attn(x)
y_channel = self.channel_attn(x)
y = y_spatial * y_channel
y = torch.sigmoid(self.conv(y))
return y


class HardAttn(nn.Module):
def __init__(self, in_channels):
"""Hard attention block."""
super().__init__()
self.fc = nn.Linear(in_channels, 4 * 2)
self.init_params()

def init_params(self):
self.fc.weight.data.zero_()
self.fc.bias.data.copy_(
torch.tensor(
[0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
)
)

def forward(self, x):
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
theta = torch.tanh(self.fc(x))
theta = theta.view(-1, 4, 2)
return theta


class HarmAttn(nn.Module):
def __init__(self, in_channels):
"""Harmonious attention block."""
super().__init__()
self.soft_attn = SoftAttn(in_channels)
self.hard_attn = HardAttn(in_channels)

def forward(self, x):
y_soft_attn = self.soft_attn(x)
theta = self.hard_attn(x)
return y_soft_attn, theta


class HACNN(nn.Module):
def __init__(
self,
num_classes,
loss="softmax",
nchannels=None,
feat_dim=512,
learn_region=True,
use_gpu=True,
**kwargs,
):
"""Harmonious Attention Convolutional Neural Network (HACNN) for person re-identification."""
super().__init__()
if nchannels is None:
nchannels = [128, 256, 384]
self.loss = loss
self.learn_region = learn_region
self.use_gpu = use_gpu
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
self.inception1 = nn.Sequential(
InceptionA(32, nchannels[0]),
InceptionB(nchannels[0], nchannels[0]),
)
self.ha1 = HarmAttn(nchannels[0])
self.inception2 = nn.Sequential(
InceptionA(nchannels[0], nchannels[1]),
InceptionB(nchannels[1], nchannels[1]),
)
self.ha2 = HarmAttn(nchannels[1])
self.inception3 = nn.Sequential(
InceptionA(nchannels[1], nchannels[2]),
InceptionB(nchannels[2], nchannels[2]),
)
self.ha3 = HarmAttn(nchannels[2])
self.fc_global = nn.Sequential(
nn.Linear(nchannels[2], feat_dim),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
)
self.classifier_global = nn.Linear(feat_dim, num_classes)
if self.learn_region:
self.init_scale_factors()
self.local_conv1 = InceptionB(32, nchannels[0])
self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
self.fc_local = nn.Sequential(
nn.Linear(nchannels[2] * 4, feat_dim),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
)
self.classifier_local = nn.Linear(feat_dim, num_classes)
self.feat_dim = feat_dim * 2
else:
self.feat_dim = feat_dim

def init_scale_factors(self):
"""Initialize scale factors for STN."""
self.scale_factors = []
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)

def stn(self, x, theta):
"""Spatial transformer network."""
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x

def transform_theta(self, theta_i, region_idx):
"""Transform theta for a given region."""
scale_factors = self.scale_factors[region_idx]
theta = torch.zeros(theta_i.size(0), 2, 3)
theta[:, :, :2] = scale_factors
theta[:, :, -1] = theta_i
if self.use_gpu:
theta = theta.to(next(self.parameters()).device)
return theta

def forward(self, x):
"""Forward pass."""
assert x.size(2) == 160 and x.size(3) == 64, (
f"Input size does not match, expected (160, 64) but got ({x.size(2)}, {x.size(3)})"
)
x = self.conv(x)
x1 = self.inception1(x)
x1_attn, x1_theta = self.ha1(x1)
x1_out = x1 * x1_attn
if self.learn_region:
x1_local_list = []
for region_idx in range(4):
x1_theta_i = x1_theta[:, region_idx, :]
x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
x1_trans_i = self.stn(x, x1_theta_i)
x1_trans_i = F.interpolate(
x1_trans_i, (24, 28), mode="bilinear", align_corners=True
)
x1_local_i = self.local_conv1(x1_trans_i)
x1_local_list.append(x1_local_i)
x2 = self.inception2(x1_out)
x2_attn, x2_theta = self.ha2(x2)
x2_out = x2 * x2_attn
if self.learn_region:
x2_local_list = []
for region_idx in range(4):
x2_theta_i = x2_theta[:, region_idx, :]
x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
x2_trans_i = self.stn(x1_out, x2_theta_i)
x2_trans_i = F.interpolate(
x2_trans_i, (12, 14), mode="bilinear", align_corners=True
)
x2_local_i = x2_trans_i + x1_local_list[region_idx]
x2_local_i = self.local_conv2(x2_local_i)
x2_local_list.append(x2_local_i)
x3 = self.inception3(x2_out)
x3_attn, x3_theta = self.ha3(x3)
x3_out = x3 * x3_attn
if self.learn_region:
x3_local_list = []
for region_idx in range(4):
x3_theta_i = x3_theta[:, region_idx, :]
x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
x3_trans_i = self.stn(x2_out, x3_theta_i)
x3_trans_i = F.interpolate(
x3_trans_i, (6, 7), mode="bilinear", align_corners=True
)
x3_local_i = x3_trans_i + x2_local_list[region_idx]
x3_local_i = self.local_conv3(x3_local_i)
x3_local_list.append(x3_local_i)
x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(
x3_out.size(0), x3_out.size(1)
)
x_global = self.fc_global(x_global)
if self.learn_region:
x_local_list = []
for region_idx in range(4):
x_local_i = x3_local_list[region_idx]
x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(
x_local_i.size(0), -1
)
x_local_list.append(x_local_i)
x_local = torch.cat(x_local_list, 1)
x_local = self.fc_local(x_local)
if not self.training:
if self.learn_region:
x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
return torch.cat([x_global, x_local], 1)
else:
return x_global
prelogits_global = self.classifier_global(x_global)
if self.learn_region:
prelogits_local = self.classifier_local(x_local)
if self.loss == "softmax":
if self.learn_region:
return (prelogits_global, prelogits_local)
else:
return prelogits_global
elif self.loss == "triplet":
if self.learn_region:
return (prelogits_global, prelogits_local), (x_global, x_local)
else:
return prelogits_global, x_global
else:
raise KeyError(f"Unsupported loss: {self.loss}")
Loading