Skip to content

Commit

Permalink
commit code
Browse files Browse the repository at this point in the history
  • Loading branch information
Leminhbinh0209 authored Nov 17, 2023
1 parent f90a15e commit 619c623
Show file tree
Hide file tree
Showing 15 changed files with 1,591 additions and 1 deletion.
55 changes: 54 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,54 @@
# Gradient Alignment for Cross-Domain Face Anti-Spoofing
# Gradient Alignment for Cross-Domain Face Anti-Spoofing (ID: 1186)

## Overview

<p align="center">
<img src="asset/objective.png" width="900" alt="overall pipeline">
<p>

## 1. Installation
- Ubuntu 18.04.5 LTS
- CUDA 11.3
- Python 3.6.12
- pytorch == 1.10.1
## 2. Dataset

- Idiap Replay Attack [[paper](https://ieeexplore.ieee.org/document/6313548)]
- OULU-NPU [[paper](http://ieeexplore.ieee.org/document/7961798)]
- CASIA-MFSD [[paper](https://ieeexplore.ieee.org/document/6199754)]
- MSU-MFSD [[paper](http://biometrics.cse.msu.edu/Publications/Face/WenHanJain_FaceSpoofDetection_TIFS15.pdf)]

#### Data pre-processing: Follow the preprocessing steps in [SAFAS](https://github.com/sunyiyou/SAFAS).

## 3. Training
Runing
```python train.py --config ./configs/ICM2O.yaml```

| Methods | $\mathrm{ICM} \rightarrow \mathrm{O}$ | | $\mathrm{OCM} \rightarrow \mathrm{I}$ | | $\mathrm{OCI} \rightarrow \mathrm{M}$ | | $\mathrm{OMI} \rightarrow \mathrm{C}$ | |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| | HTER $\downarrow$ | $\mathrm{AUC} \uparrow$ | HTER $\downarrow$ | $\mathrm{AUC} \uparrow$ | HTER $\downarrow$ | $\mathrm{AUC} \uparrow$ | HTER $\downarrow$ | $\mathrm{AUC} \uparrow$ |
| MMD-AAE | 40.98 | 63.08 | 31.58 | 75.18 | 27.08 | 83.19 | 44.59 | 58.29 |
| MADDG | 27.98 | 80.02 | 22.19 | 84.99 | 17.69 | 88.06 | 24.50 | 84.51 |
| RFM | 16.45 | 91.16 | 17.30 | 90.48 | 13.89 | 93.98 | 20.27 | 88.16 |
| SSDG-M | 25.17 | 81.83 | 18.21 | 94.61 | 16.67 | 90.47 | 23.11 | 85.45 |
| SSDG-R | 15.61 | 91.54 | 11.71 | 96.59 | 7.38 | 97.17 | 10.44 | 95.94 |
| $\mathrm{D}^2 \mathrm{AM} $ | 15.27 | 90.87 | 15.43 | 91.22 | 12.70 | 95.66 | 20.98 | 85.58 |
| SDA | 23.10 | 84.30 | 15.60 | 90.10 | 15.40 | 91.80 | 24.50 | 84.40 |
| DRDG | 15.63 | 91.75 | 15.56 | 91.79 | 12.43 | 95.81 | 19.05 | 88.79 |
| ANRL | 15.67 | 91.90 | 16.03 | 91.04 | 10.83 | 96.75 | 17.85 | 89.26 |
| SSAN | 13.72 | 93.63 | 8.88 | 96.79 | 6.67 | 98.75 | 10.00 | 96.67 |
| AMEL | 11.31 | 93.96 | 18.60 | 88.79 | 10.23 | 96.62 | 11.88 | 94.39 |
| EBDG | 15.66 | 92.02 | 18.69 | 92.28 | 9.56 | 97.17 | 18.34 | 90.01 |
| PathNet | 11.82 | 95.07 | 13.40 | 95.67 | 7.10 | 98.46 | 11.33 | 94.58 |
| IADG | 8.86 | 97.14 | 10.62 | 94.50 | 5.41 | 98.19 | 8.70 | 96.40 |
| SA-FAS | 10.00 | 96.23 | 6.58 | 97.54 | 5.95 | 96.55 | 8.78 | 95.37 |
| UDG-FAS | 10.97 | 95.36 | 5.86 | 98.62 | 5.95 | 98.47 | 9.82 | 96.76 |
| GAC-FAS (**ours**) | $8.60^{0.28}$ | $97.16^{0.40}$ | $4.29^{0.83}$ | $98.87^{0.60}$ | $5.00^{0.00}$ | $97.56^{0.06}$ | $8.20^{0.43}$ | $95.16^{0.09}$ |

## 4. Landscape visualization

[[paper](https://proceedings.neurips.cc/paper/7875-visualizing-the-loss-landscape-of-neural-nets.pdf)][[code](https://github.com/tomgoldstein/loss-landscape)][[software](http://paraview.org/)]

<p align="center">
<img src="asset/loss_landscape.png" width="900" alt="overall pipeline">
<p>
Binary file added asset/loss_landscape.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/objective.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
50 changes: 50 additions & 0 deletions configs/ICM2O.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
protocol: "ICM2O"
train_set: ["Replay_attack", "CASIA_MFSD", "MSU_MFSD"]
test_set: ["OULU"]

running_name: ""

PATH:
data_folder: "./datasets/"
output_folder: "./logs/"

SYS:
num_gpus: 1
GPUs: "0"
num_workers: 4
MODEL:
model_name: "resnet18"
norm: True
usebias: False
image_size: 256
num_classes: 1

TRAIN:
pretrained: "imagenet"
batch_size: 96
lr: 0.005
fc_lr_scale: 10
weight_decay: 0.0001
momentum: 0.9
lr_step_size: 40
lr_gamma: 0.5
optimizer: "SGD"
scheduler: "step"
warming_epochs: 1
epochs: 150
loss_func: "bce"
logit_scale: 12
rotate: True
cutout: True
feat_loss: "supcon"
lambda_constrast: 0.2

minimizer: "gac-fas"
minimizer_warming: 10

GAC:
rho: 0.1
eta: 0.0
alpha: 0.0002
TEST:
eval_preq: 5
270 changes: 270 additions & 0 deletions dataloaders/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import os
import io
import torch
from torch.utils.data import Dataset
import math
from glob import glob
import re
from .meta import DEVICE_INFOS
import numpy as np
from PIL import Image
import random

def list_dirs_at_depth(root_dir, depth):
if depth < 0:
return []
elif depth == 0:
return [root_dir]
else:
sub_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
return [d for sub_dir in sub_dirs for d in list_dirs_at_depth(sub_dir, depth-1)]

class FaceDataset(Dataset):
def __init__(self,
dataset_name,
root_dir,
is_train=True,
label=None,
transform=None,
map_size=32,
UUID=-1,
img_size=256,
test_per_video=1):
self.is_train = is_train
self.video_list = [folder for folder in list_dirs_at_depth(os.path.join(root_dir, 'train' if is_train else 'test'), 2) if len(os.listdir(folder)) > 0]
if label is not None and label != 'all':
self.video_list = list(filter(lambda x: label in x, self.video_list))
print(f"({root_dir.split('/')[-1]}) Total video: {len(self.video_list)}: {len([u for u in self.video_list if 'live' in u])} vs. {len([u for u in self.video_list if 'live' not in u])}" )

self.dataset_name = dataset_name
self.root_dir = root_dir
self.transform = transform
self.map_size = map_size
self.UUID = UUID
self.image_size = img_size

if not is_train:
self.frame_per_video = test_per_video
self.video_list = sum([self.video_list]*test_per_video, [])
else:
self.frame_per_video = 1

self.init_frame_list()
def __len__(self):
return len(self.video_list)

def shuffle(self):
if self.is_train:
random.shuffle(self.video_list)

def init_frame_list(self):
"""
Create dictionary of
"""
self.video_frame_list = dict(zip([os.path.join(self.root_dir, video_name) for video_name in self.video_list],
[[] for _ in self.video_list]))
for video_path in self.video_frame_list:

if not self.is_train:
"""
In the mode test, we only need on face per video
"""
all_crop_faces = glob(os.path.join(video_path, "crop_*.jpg"))
assert len(all_crop_faces) > 2, f"Cannot find the image in folder {video_path}"
# all_crop_faces.sort()
self.video_frame_list[video_path] = all_crop_faces # [len(all_crop_faces)//2:len(all_crop_faces)//2+1] # Select only one middle frame for reproducible
else:
all_crop_faces = glob(os.path.join(video_path, "crop_*.jpg"))
assert len(all_crop_faces) > 2, f"Cannot find the image in folder {video_path}"
self.video_frame_list[video_path] = all_crop_faces

return True

def get_client_from_video_name(self, video_name):
video_name = video_name.split('/')[-1]
if 'msu' in self.dataset_name.lower() or 'replay' in self.dataset_name.lower():
match = re.findall('client(\d\d\d)', video_name)
if len(match) > 0:
client_id = match[0]
else:
raise RuntimeError('no client')
elif 'oulu' in self.dataset_name.lower():
match = re.findall('(\d+)_\d$', video_name)
if len(match) > 0:
client_id = match[0]
else:
raise RuntimeError('no client')
elif 'casia' in self.dataset_name.lower():

match = re.findall('(\d+)_[H|N][R|M]_\d$', video_name)
if len(match) > 0:
client_id = match[0]
else:
print(f"Cannot find client from : {video_name}")
raise RuntimeError('no client')
else:
raise RuntimeError("no dataset found")
return client_id

def __getitem__(self, idx):
idx = idx % len(self.video_list) # Incase testing with many frame per video
video_name = self.video_list[idx]
spoofing_label = int('live' in video_name)
if self.dataset_name in DEVICE_INFOS:
if 'live' in video_name:
patterns = DEVICE_INFOS[self.dataset_name]['live']
elif 'spoof' in video_name:
patterns = DEVICE_INFOS[self.dataset_name]['spoof']
else:
raise RuntimeError(f"Cannot find the label infor from the video: {video_name}")
device_tag = None
for pattern in patterns:
if len(re.findall(pattern, video_name)) > 0:
if device_tag is not None:
raise RuntimeError("Multiple Match")
device_tag = pattern
if device_tag is None:
raise RuntimeError("No Match")
else:
device_tag = 'live' if spoofing_label else 'spoof'

client_id = self.get_client_from_video_name(video_name)

image_dir = os.path.join(self.root_dir, video_name)

if self.is_train:
image_x, _, _, = self.sample_image(image_dir, is_train=True)
transformed_image1 = self.transform(image_x)
transformed_image2 = self.transform(image_x, )


else:
image_x, _, _ = self.sample_image(image_dir, is_train=False, rep=None)
transformed_image1 = transformed_image2 = self.transform(image_x)



sample = {"image_x_v1": transformed_image1,
"image_x_v2": transformed_image2,
"label": spoofing_label,
"UUID": self.UUID,
'device_tag': device_tag,
'video': video_name,
'client_id': client_id}
return sample


def sample_image(self, image_dir, is_train=False, rep=None):
"""
rep is the parameter from the __getitem__ function to reduce randomness of test phase
"""
image_path = np.random.choice(self.video_frame_list[image_dir])
image_id = int(image_path.split('/')[-1].split('_')[-1].split('.')[0])

info_name = f"infov1_{image_id:04d}.npy"
info_path = os.path.join(image_dir, info_name)

try:
info = None
image = Image.open(image_path)
except:
if is_train:
return self.sample_image(image_dir, is_train)
else:
raise ValueError(f"Error in the file {info_path}")
return image, info, image_id * 5

class Identity(): # used for skipping transforms
def __call__(self, im):
return im

class RandomCutout(object):
def __init__(self, n_holes, p=0.5):
"""
Args:
n_holes (int): Number of patches to cut out of each image.
p (int): probability to apply cutout
"""
self.n_holes = n_holes
self.p = p

def rand_bbox(self, W, H, lam):
"""
Return a random box
"""
cut_rat = np.sqrt(1. - lam)
cut_w = np.int(W * cut_rat)
cut_h = np.int(H * cut_rat)

# uniform
cx = np.random.randint(W)
cy = np.random.randint(H)

bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)

return bbx1, bby1, bbx2, bby2

def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
if np.random.rand(1) > self.p:
return img

h = img.size(1)
w = img.size(2)
lam = np.random.beta(1.0, 1.0)
bbx1, bby1, bbx2, bby2 = self.rand_bbox(w, h, lam)
for n in range(self.n_holes):
img[:,bby1:bby2, bbx1:bbx2] = img[:,bby1:bby2, bbx1:bbx2].mean(dim=[-2,-1],keepdim=True)
return img

class RandomJPEGCompression(object):
def __init__(self, quality_min=30, quality_max=90, p=0.5):
assert 0 <= quality_min <= 100 and 0 <= quality_max <= 100
self.quality_min = quality_min
self.quality_max = quality_max
self.p = p
def __call__(self, img):
if np.random.rand(1) > self.p:
return img
# Choose a random quality for JPEG compression
quality = np.random.randint(self.quality_min, self.quality_max)

# Save the image to a bytes buffer using JPEG format
buffer = io.BytesIO()
img.save(buffer, format='JPEG', quality=quality)

# Reload the image from the buffer
img = Image.open(buffer)
return img

class RoundRobinDataset(Dataset):
def __init__(self, datasets):
self.datasets = datasets
self.lengths = [len(dataset) for dataset in datasets]
self.total_len = sum(self.lengths)

def __getitem__(self, index):
# Determine which dataset to sample from
dataset_id = index % len(self.datasets)

# Adjust index to fit within the chosen dataset's length
inner_index = index // len(self.datasets)
inner_index = inner_index % self.lengths[dataset_id]
return self.datasets[dataset_id][inner_index]

def shuffle(self):
for dataset in self.datasets:
dataset.shuffle()

def __len__(self):
# Return the length of the largest dataset times the number of datasets
return max(self.lengths) * len(self.datasets)
Loading

0 comments on commit 619c623

Please sign in to comment.