```python
import torch
from torch import nn

from .BaseAdaptor import BaseAdaptor


class MLPAdaptor(BaseAdaptor):
    """Module for domain adaptation with an MLP projector."""

    def __init__(
        self,
        features: torch.Tensor,
        n_dim: int = 20,
        n_hidden: list = [100],
        weight: float = 1e-3,
    ):
        """Set up the MLP model for domain adaptation.

        Args:
            features (torch.Tensor): Feature matrix of users or items. size = (n_user, n_feature)
            n_dim (int, optional): Dimensionality of the embeddings. Defaults to 20.
            n_hidden (list, optional): Sizes of the hidden layers. Defaults to [100].
            weight (float, optional): Loss weight for the domain adaptation term. Defaults to 1e-3.
        """
        super().__init__(weight)
        # Fixed side information: from_pretrained freezes the weights by
        # default, and requires_grad = False makes that explicit.
        self.features_embeddings = nn.Embedding.from_pretrained(features)
        self.features_embeddings.weight.requires_grad = False

        self.n_input = features.shape[1]
        self.n_hidden = n_hidden
        self.n_output = n_dim

        # MLP projecting raw features into the embedding space:
        # n_input -> n_hidden[0] -> ... -> n_hidden[-1] -> n_dim.
        projection_layers = [nn.Linear(self.n_input, self.n_hidden[0]), nn.ReLU()]
        for i in range(len(self.n_hidden) - 1):
            projection_layers += [nn.Linear(self.n_hidden[i], self.n_hidden[i + 1]), nn.ReLU()]
        projection_layers += [nn.Linear(self.n_hidden[-1], self.n_output)]

        self.projector = nn.Sequential(*projection_layers)

    def forward(self, indices: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute the loss for domain adaptation.

        Args:
            indices (torch.Tensor): Indices of users or items. size = (n_user, n_sample)
            embeddings (torch.Tensor): Embeddings corresponding to the indices. size = (n_user, n_sample, n_dim)

        Returns:
            torch.Tensor: Scalar loss for domain adaptation.
        """
        features = self.features_embeddings(indices)  # (n_user, n_sample, n_feature)
        projection = self.projector(features)  # (n_user, n_sample, n_dim)
        # Euclidean distance between projected features and the given embeddings.
        dist = torch.sqrt(torch.pow(projection - embeddings, 2).sum(dim=2))
        return self.weight * dist.sum()
```
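
`BaseAdaptor` itself is not part of this diff. Judging from the `super().__init__(weight)` call and the use of `self.weight` in `forward`, a minimal sketch of the assumed base class could look like this (an inference, not the repository's actual code):

```python
# Assumed minimal BaseAdaptor, inferred from how MLPAdaptor uses it.
from torch import nn


class BaseAdaptor(nn.Module):
    """Base class that stores the loss weight for domain adaptation."""

    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight
```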
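
As a quick sanity check, here is a hypothetical usage sketch; all names, sizes, and shapes below are illustrative and not taken from the diff:

```python
import torch

n_user, n_feature, n_dim, n_sample = 1000, 50, 20, 4

features = torch.randn(n_user, n_feature)  # fixed side information per user
adaptor = MLPAdaptor(features, n_dim=n_dim, n_hidden=[100], weight=1e-3)

indices = torch.randint(0, n_user, (n_user, n_sample))
embeddings = torch.randn(n_user, n_sample, n_dim, requires_grad=True)

loss = adaptor(indices, embeddings)  # scalar, already scaled by `weight`
loss.backward()  # gradients reach the projector and `embeddings`, not `features`
```

The loss pulls the MLP projection of each entity's features toward its learned embedding, so the fixed side information acts as a regularizer on the embedding space.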