Skip to content

Commit

Permalink
Merge pull request #875 from marrlab/coral
Browse files Browse the repository at this point in the history
coral is just cross domain MMD distance
  • Loading branch information
smilesun authored Sep 17, 2024
2 parents 9da2dee + 5a164a9 commit 77b64d6
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 0 deletions.
48 changes: 48 additions & 0 deletions domainlab/algos/trainers/mmd_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Base trainer with maximum mean discrepancy (MMD) helpers
[au] Alexej, Xudong
"""
import torch
from domainlab.algos.trainers.train_basic import TrainerBasic


class TrainerMMDBase(TrainerBasic):
    """
    Base trainer providing maximum mean discrepancy (MMD) helpers.

    The class only adds the distance/kernel/MMD computations; subclasses
    decide between which feature batches the MMD is evaluated.
    """

    # Bandwidth factors of the Gaussian kernels that are summed when the
    # caller does not supply its own ``gamma``.
    GAMMAS_DEFAULT = (0.001, 0.01, 0.1, 1, 10, 100, 1000)

    def my_cdist(self, x1, x2):
        """
        Pairwise squared Euclidean distance between the rows of x1 and x2.

        Relies on ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2 so that the
        cross term is a single matrix product.

        :param x1: matrix [n1, dimfeat] (torch.addmm needs 2-D operands)
        :param x2: matrix [n2, dimfeat]
        :return: matrix [n1, n2] of squared distances, bounded below by 1e-30
        """
        # squared norms along the last (feature) dimension, kept as columns
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        # addmm: input + alpha * (mat1 @ mat2)
        #   input = ||x2||^2 as a row vector (broadcast over rows)
        #   alpha * mat1 @ mat2 = -2 * x1 @ x2^T
        # afterwards ||x1||^2 (column vector) is added, broadcast over columns
        res = torch.addmm(
            x2_norm.transpose(-2, -1),
            x1,
            x2.transpose(-2, -1),
            alpha=-2,
        ).add_(x1_norm)
        # rounding can push tiny distances to zero or below; keep them positive
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y, gamma=None):
        """
        Sum of Gaussian kernels exp(-g * ||x_i - y_j||^2) over several g.

        :param x: matrix [n1, dimfeat]
        :param y: matrix [n2, dimfeat]
        :param gamma: iterable of bandwidth factors g; ``None`` (default)
            selects ``GAMMAS_DEFAULT``
        :return: kernel matrix [n1, n2]
        """
        if gamma is None:
            gamma = self.GAMMAS_DEFAULT
        dist = self.my_cdist(x, y)
        tensor = torch.zeros_like(dist)
        for g in gamma:
            tensor.add_(torch.exp(dist.mul(-g)))
        return tensor

    def mmd(self, x, y, gamma=None):
        """
        Maximum mean discrepancy between two batches of features, computed
        as mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y)).

        :param x: matrix [n1, dimfeat]
        :param y: matrix [n2, dimfeat]
        :param gamma: optional bandwidth factors forwarded to
            ``gaussian_kernel``; ``None`` keeps the default bandwidths
        :return: scalar tensor
        """
        kxx = self.gaussian_kernel(x, x, gamma).mean()
        kyy = self.gaussian_kernel(y, y, gamma).mean()
        kxy = self.gaussian_kernel(x, y, gamma).mean()
        return kxx + kyy - 2 * kxy
49 changes: 49 additions & 0 deletions domainlab/algos/trainers/train_coral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Deep CORAL: Correlation Alignment for Deep
Domain Adaptation
[au] Alexej, Xudong
"""
from domainlab.algos.trainers.mmd_base import TrainerMMDBase
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerCoral(TrainerMMDBase):
    """
    Trainer that penalises the MMD between the features of every pair of
    training domains, on top of the per-domain task loss.
    """

    def cross_domain_mmd(self, tuple_data_domains_batch):
        """
        Compute the task loss of every domain and the MMD of every pair of
        distinct domains.

        :param tuple_data_domains_batch: one mini-batch per training domain;
            each mini-batch is a sequence whose first two entries are the
            input data and the labels
        :return: tuple (list of per-domain task losses, list of MMD values
            with one entry per unordered pair of distinct domains)
        """
        # NOTE(review): the batches are passed to the model as delivered by
        # the loaders; confirm whether they have to be moved to the model's
        # device first.
        model = self.get_model()
        list_domain_erm_loss = []
        list_feat = []
        # one feature extraction per domain instead of one per pair
        for data, label, *_ in tuple_data_domains_batch:
            list_feat.append(model.extract_semantic_feat(data))
            list_domain_erm_loss.append(sum(model.cal_task_loss(data, label)))

        list_cross_domain_mmd = []
        num_domains = len(list_feat)
        for ind_domain_a in range(num_domains):
            # start at a + 1: comparing a domain with itself adds no
            # alignment signal and only costs computation
            for ind_domain_b in range(ind_domain_a + 1, num_domains):
                list_cross_domain_mmd.append(
                    self.mmd(list_feat[ind_domain_a], list_feat[ind_domain_b])
                )
        return list_domain_erm_loss, list_cross_domain_mmd

    def tr_epoch(self, epoch):
        """
        Train for one epoch on batches drawn jointly from all domains.

        :param epoch: index of the current epoch
        :return: flag returned by the observer's ``update``
        """
        list_loaders = list(self.dict_loader_tr.values())
        # zip stops as soon as the shortest domain loader is exhausted
        loaders_zip = zip(*list_loaders)
        self.model.train()
        self.model.convert4backpack()
        self.epo_loss_tr = 0

        for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip):
            self.optimizer.zero_grad()
            list_domain_erm_loss, list_cross_domain_mmd = self.cross_domain_mmd(
                tuple_data_domains_batch
            )
            # with a single domain the MMD list is empty and sum() yields 0
            loss = sum(list_domain_erm_loss) + get_gamma_reg(
                self.aconf, self.name
            ) * sum(list_cross_domain_mmd)
            loss.backward()
            self.optimizer.step()
            self.epo_loss_tr += loss.detach().item()
            self.after_batch(epoch, ind_batch)

        flag_stop = self.observer.update(epoch)  # notify observer
        return flag_stop
2 changes: 2 additions & 0 deletions domainlab/algos/trainers/zoo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from domainlab.algos.trainers.train_fishr import TrainerFishr
from domainlab.algos.trainers.train_irm import TrainerIRM
from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
from domainlab.algos.trainers.train_coral import TrainerCoral


class TrainerChainNodeGetter(object):
Expand Down Expand Up @@ -57,6 +58,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
chain = TrainerIRM(chain)
chain = TrainerHyperScheduler(chain)
chain = TrainerCausalIRL(chain)
chain = TrainerCoral(chain)
node = chain.handle(self.request)
head = node
while self._list_str_trainer:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_coral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
end-to-end test
"""
from tests.utils_test import utils_test_algo


def test_coral():
    """
    Run the coral trainer through the shared algorithm test helper.
    """
    # adjacent literals are joined at compile time into one argument string
    cli_args = (
        "--te_d 0 --tr_d 3 7 --bs=32 --debug --task=mnistcolor10 "
        "--model=erm --nname=conv_bn_pool_2 --trainer=coral"
    )
    utils_test_algo(cli_args)

0 comments on commit 77b64d6

Please sign in to comment.