Commit c1ca9ba

Merge pull request #31 from hand10ryo/feature/adaptors
Feature/adaptors
2 parents 9e96b1e + a051ac6 commit c1ca9ba

21 files changed: +1356 −229 lines changed

PyTorchCML/adaptors/BaseAdaptor.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from torch import nn
+
+
+class BaseAdaptor(nn.Module):
+    """Abstract class of module for domain adaptation."""
+
+    def __init__(self, weight):
+        """Set some parameters.
+
+        Args:
+            weight (float): Loss weight for domain adaptation (e.g., 1e-3).
+        """
+        super().__init__()
+        self.weight = weight
+
+    def forward(self, indices, embeddings):
+        """Method to calculate loss for domain adaptation.
+
+        Args:
+            indices (torch.Tensor): Indices of users or items. size = (n_user, n_sample)
+            embeddings (torch.Tensor): The embeddings corresponding to indices. size = (n_user, n_sample, n_dim)
+
+        Raises:
+            NotImplementedError: This is an abstract method; subclasses must override it.
+        """
+        raise NotImplementedError
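
BaseAdaptor only fixes the interface: a subclass maps index and embedding tensors to a weighted scalar penalty. A minimal sketch of a custom subclass (hypothetical, not part of this PR):

import torch
from PyTorchCML.adaptors import BaseAdaptor


class L2Adaptor(BaseAdaptor):
    # Hypothetical adaptor: shrink embeddings toward the origin, purely to
    # illustrate the (indices, embeddings) -> scalar loss contract.
    def forward(self, indices: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        return self.weight * torch.norm(embeddings, dim=2).sum()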

PyTorchCML/adaptors/MLPAdaptor.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+
+from torch import nn
+
+from .BaseAdaptor import BaseAdaptor
+
+
+class MLPAdaptor(BaseAdaptor):
+    """Class of module for domain adaptation with MLP."""
+
+    def __init__(
+        self,
+        features: torch.Tensor,
+        n_dim: int = 20,
+        n_hidden: list = [100],
+        weight: float = 1e-3,
+    ):
+        """Set MLP model for domain adaptation.
+
+        Args:
+            features (torch.Tensor): Features of users or items. size = (n_user, n_feature)
+            n_dim (int, optional): The number of dimensions of the embeddings. Defaults to 20.
+            n_hidden (list, optional): The number of neurons in each hidden layer. Defaults to [100].
+            weight (float, optional): Loss weight for domain adaptation. Defaults to 1e-3.
+        """
+        super().__init__(weight)
+        self.features_embeddings = nn.Embedding.from_pretrained(features)
+        self.features_embeddings.weight.requires_grad = False
+
+        self.n_input = features.shape[1]
+        self.n_hidden = n_hidden
+        self.n_output = n_dim
+
+        projection_layers = [nn.Linear(self.n_input, self.n_hidden[0]), nn.ReLU()]
+        for i in range(len(self.n_hidden) - 1):
+            layer = [nn.Linear(self.n_hidden[i], self.n_hidden[i + 1]), nn.ReLU()]
+            projection_layers += layer
+        projection_layers += [nn.Linear(self.n_hidden[-1], self.n_output)]
+
+        self.projector = nn.Sequential(*projection_layers)
+
+    def forward(self, indices: torch.Tensor, embeddings: torch.Tensor):
+        """Method to calculate loss for domain adaptation.
+
+        Args:
+            indices (torch.Tensor): Indices of users or items. size = (n_user, n_sample)
+            embeddings (torch.Tensor): The embeddings corresponding to indices. size = (n_user, n_sample, n_dim)
+
+        Returns:
+            torch.Tensor: Loss for domain adaptation. dim = 0.
+        """
+        features = self.features_embeddings(indices)
+        projection = self.projector(features)
+        dist = torch.sqrt(torch.pow(projection - embeddings, 2).sum(axis=2))
+        return self.weight * dist.sum()
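
For orientation, a usage sketch of the new module (the feature tensor and shapes here are illustrative, not taken from this PR):

import torch
from PyTorchCML.adaptors import MLPAdaptor

# Toy side information: 100 items with 10 attributes each.
item_features = torch.rand(100, 10)
adaptor = MLPAdaptor(item_features, n_dim=20, n_hidden=[64], weight=1e-3)

# A batch of 8 users with 2 sampled items each; the corresponding
# embeddings have size (8, 2, 20).
indices = torch.randint(0, 100, (8, 2))
embeddings = torch.rand(8, 2, 20)
loss = adaptor(indices, embeddings)  # scalar: weight * sum of L2 distances

The frozen nn.Embedding.from_pretrained lookup turns indices into feature rows, and the MLP projects them into the embedding space, so the returned term pulls each embedding toward the projection of its features.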

PyTorchCML/adaptors/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# flake8: noqa
+from .BaseAdaptor import BaseAdaptor
+from .MLPAdaptor import MLPAdaptor

PyTorchCML/losses/BaseLoss.py

Lines changed: 17 additions & 10 deletions
@@ -11,6 +11,13 @@ def __init__(self, regularizers: list = []):

     def forward(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
+    ) -> torch.Tensor:
+        loss = self.main(embeddings_dict, batch, column_names)
+        loss += self.regularize(embeddings_dict)
+        return loss
+
+    def main(
+        self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
         """
         Args:
@@ -34,18 +41,18 @@ def forward(

         --- example code ---

-        # embeddings_dict = {
-        #     "user_embedding": user_emb,
-        #     "pos_item_embedding": pos_item_emb,
-        #     "neg_item_embedding": neg_item_emb,
-        #     "user_bias": user_bias,
-        #     "pos_item_bias": pos_item_bias,
-        #     "neg_item_bias": neg_item_bias,
-        # }
+        embeddings_dict = {
+            "user_embedding": user_embedding,
+            "pos_item_embedding": pos_item_embedding,
+            "neg_item_embedding": neg_item_embedding,
+            "user_bias": user_bias,
+            "pos_item_bias": pos_item_bias,
+            "neg_item_bias": neg_item_bias,
+        }

         loss = loss_function(embeddings_dict, batch, column_names)
-        reg = self.regularize(embeddings_dict)
-        return loss + reg
+
+        return loss
         """

         raise NotImplementedError
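
The effect of this refactor is that regularization now lives once in forward, and each concrete loss only implements main. A hedged sketch of a custom loss under the new contract (the class and its formula are illustrative, assuming BaseLoss is exported from PyTorchCML.losses as the concrete losses suggest):

import torch
from PyTorchCML.losses import BaseLoss


class DotProductLoss(BaseLoss):
    # Illustrative subclass: BaseLoss.forward adds the regularization term,
    # so only the main loss needs to be defined here.
    def main(
        self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
    ) -> torch.Tensor:
        user = embeddings_dict["user_embedding"]
        pos = embeddings_dict["pos_item_embedding"]
        # Reward high inner products between users and positive items.
        return -(user * pos).sum(dim=-1).mean()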

PyTorchCML/losses/LogitPairwiseLoss.py

Lines changed: 3 additions & 4 deletions
@@ -11,10 +11,10 @@ def __init__(self, regularizers: list = []):
         super().__init__(regularizers)
         self.LogSigmoid = nn.LogSigmoid()

-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss

         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values.
@@ -55,6 +55,5 @@ def forward(
         neg_loss = -nn.LogSigmoid()(-neg_y_hat).sum()

         loss = (pos_loss + neg_loss) / (n_batch * (n_pos + n_neg))
-        reg = self.regularize(embeddings_dict)

-        return loss + reg
+        return loss
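
Incidentally, the log-sigmoid terms in this loss are ordinary binary cross-entropy with logits; a quick sanity check (illustrative values, not part of the diff):

import torch
from torch import nn

y_hat = torch.tensor([0.3, -1.2])
bce = nn.BCEWithLogitsLoss(reduction="none")
# -log sigmoid(y_hat) is BCE against label 1; -log sigmoid(-y_hat) against label 0.
assert torch.allclose(-nn.LogSigmoid()(y_hat), bce(y_hat, torch.ones(2)))
assert torch.allclose(-nn.LogSigmoid()(-y_hat), bce(y_hat, torch.zeros(2)))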

PyTorchCML/losses/MSEPairwiseLoss.py

Lines changed: 3 additions & 4 deletions
@@ -6,10 +6,10 @@
 class MSEPairwiseLoss(BaseLoss):
     """Class of loss for MSE in implicit feedback"""

-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss

         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values.
@@ -53,6 +53,5 @@ def forward(
         neg_loss = (torch.sigmoid(neg_r_hat) ** 2).sum()

         loss = (pos_loss + neg_loss) / (n_batch * (n_pos + n_neg))
-        reg = self.regularize(embeddings_dict)

-        return loss + reg
+        return loss

PyTorchCML/losses/MinTripletLoss.py

Lines changed: 3 additions & 4 deletions
@@ -11,10 +11,10 @@ def __init__(self, margin: float = 1, regularizers: list = []):
         self.margin = margin
         self.ReLU = nn.ReLU()

-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss

         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values.
@@ -41,6 +41,5 @@ def forward(
         pairwiseloss = self.ReLU(self.margin + pos_dist ** 2 - min_neg_dist.values ** 2)

         loss = torch.mean(pairwiseloss)
-        reg = self.regularize(embeddings_dict)

-        return loss + reg
+        return loss
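
The min over the negative axis means only the hardest (closest) negative enters the hinge. A toy example with made-up distances:

import torch
from torch import nn

margin = 1.0
pos_dist = torch.tensor([[0.5]])            # size (n_batch, 1)
neg_dist = torch.tensor([[2.0, 0.8, 1.5]])  # size (n_batch, n_neg)
min_neg_dist = torch.min(neg_dist, dim=1)
# Only the closest negative (0.8) matters: 1 + 0.25 - 0.64 = 0.61
loss = torch.mean(nn.ReLU()(margin + pos_dist ** 2 - min_neg_dist.values ** 2))
print(loss)  # tensor(0.6100)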

PyTorchCML/losses/RelevancePairwiseLoss.py

Lines changed: 3 additions & 4 deletions
@@ -28,10 +28,10 @@ def __init__(self, regularizers: list = [], delta: str = "logistic"):
         else:
             raise NotImplementedError

-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss

         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values.
@@ -79,6 +79,5 @@ def forward(
         neg_loss = self.delta_neg(neg_r_hat).sum()

         loss = (pos_loss + neg_loss) / (n_batch * (n_pos + n_neg))
-        reg = self.regularize(embeddings_dict)

-        return loss + reg
+        return loss

PyTorchCML/losses/SumTripletLoss.py

Lines changed: 3 additions & 4 deletions
@@ -12,10 +12,10 @@ def __init__(self, margin: float = 1, regularizers: list = []):
         self.margin = margin
         self.ReLU = nn.ReLU()

-    def forward(
+    def main(
         self, embeddings_dict: dict, batch: torch.Tensor, column_names: dict
     ) -> torch.Tensor:
-        """Method of forwarding loss
+        """Method of forwarding main loss

         Args:
             embeddings_dict (dict): A dictionary of embeddings which has the following keys and values
@@ -39,6 +39,5 @@ def forward(

         tripletloss = self.ReLU(self.margin + pos_dist ** 2 - neg_dist ** 2)
         loss = torch.mean(tripletloss)
-        reg = self.regularize(embeddings_dict)

-        return loss + reg
+        return loss
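
By contrast, the sum triplet loss averages the hinge over every sampled pair. A toy check of the term above (illustrative numbers):

import torch
from torch import nn

margin = 1.0
pos_dist = torch.tensor([0.5, 1.0])
neg_dist = torch.tensor([2.0, 1.0])
tripletloss = nn.ReLU()(margin + pos_dist ** 2 - neg_dist ** 2)
# First triplet satisfies the margin (1 + 0.25 - 4 < 0), the second violates it.
print(tripletloss)               # tensor([0., 1.])
print(torch.mean(tripletloss))   # tensor(0.5000)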

PyTorchCML/models/BaseEmbeddingModel.py

Lines changed: 7 additions & 1 deletion
@@ -3,6 +3,8 @@
 import torch
 from torch import nn

+from ..adaptors import BaseAdaptor
+

 class BaseEmbeddingModel(nn.Module):
     """Class of abstract embeddings model getting embedding or predict relevance from indices."""
@@ -15,6 +17,8 @@ def __init__(
         max_norm: Optional[float] = 1,
         user_embedding_init: Optional[torch.Tensor] = None,
         item_embedding_init: Optional[torch.Tensor] = None,
+        user_adaptor: Optional[BaseAdaptor] = None,
+        item_adaptor: Optional[BaseAdaptor] = None,
     ):
         """Set embeddings.

@@ -31,6 +35,8 @@
         self.n_item = n_item
         self.n_dim = n_dim
         self.max_norm = max_norm
+        self.user_adaptor = user_adaptor
+        self.item_adaptor = item_adaptor

         if user_embedding_init is None:
             self.user_embedding = nn.Embedding(
@@ -71,7 +77,7 @@ def predict(self, pairs: torch.Tensor) -> torch.Tensor:
         """Method of predicting relevance for each pair of user and item.

         Args:
-            pairs (torch.Tensor): Tensor which columns are [user_id, item_id]
+            pairs (torch.Tensor): Tensor whose columns are [user_id, item_id]

         Raises:
             NotImplementedError: [description]
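
Tying it together, adaptors are handed to the embedding model at construction time. A hedged wiring sketch (it assumes the concrete models.CollaborativeMetricLearning subclass forwards these keyword arguments to BaseEmbeddingModel.__init__; that wiring is not shown in this diff):

import torch
from PyTorchCML import adaptors, models

# Toy side information: 50 items with 10 attributes each.
item_features = torch.rand(50, 10)
item_adaptor = adaptors.MLPAdaptor(item_features, n_dim=20, n_hidden=[64])

# Assumes the concrete model passes item_adaptor through to
# BaseEmbeddingModel.__init__ as added above.
model = models.CollaborativeMetricLearning(
    n_user=30, n_item=50, n_dim=20, item_adaptor=item_adaptor
)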
