diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py index e0cfc894..b2798ca0 100644 --- a/pypots/imputation/csdi/data.py +++ b/pypots/imputation/csdi/data.py @@ -69,14 +69,14 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: X = self.X[idx].to(torch.float32) X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate) - observed_data = X_intact - observed_mask = missing_mask + indicating_mask + observed_data = X_intact # i.e. originally observed data + observed_mask = missing_mask + indicating_mask # i.e. originally missing masks observed_tp = ( torch.arange(0, self.n_steps, dtype=torch.float32) if self.time_points is None else self.time_points[idx].to(torch.float32) ) - gt_mask = missing_mask + gt_mask = missing_mask # missing mask with ground truth masked for validation for_pattern_mask = ( gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx] ) diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 67764c55..70bf7734 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -128,9 +128,9 @@ def __init__( d_time_embedding: int, d_feature_embedding: int, d_diffusion_embedding: int, - is_unconditional: bool = False, - target_strategy: str = "random", n_diffusion_steps: int = 50, + target_strategy: str = "random", + is_unconditional: bool = False, schedule: str = "quad", beta_start: float = 0.0001, beta_end: float = 0.5, diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py index b25fd190..958fb65d 100644 --- a/pypots/imputation/csdi/modules/core.py +++ b/pypots/imputation/csdi/modules/core.py @@ -127,7 +127,9 @@ def get_side_info(self, observed_tp, cond_mask): ) # (K,emb) feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1) - side_info = torch.cat([time_embed, feature_embed], dim=-1) # (B,L,K,*) + side_info = torch.cat( + [time_embed, feature_embed], dim=-1 + ) # (B,L,K,emb+d_feature_embedding) side_info = side_info.permute(0, 3, 2, 1) # (B,*,K,L) if not self.is_unconditional: diff --git a/pypots/imputation/csdi/modules/submodules.py b/pypots/imputation/csdi/modules/submodules.py index 31e71fff..68248d7c 100644 --- a/pypots/imputation/csdi/modules/submodules.py +++ b/pypots/imputation/csdi/modules/submodules.py @@ -38,14 +38,6 @@ def __init__(self, n_diffusion_steps, d_embedding=128, d_projection=None): self.projection1 = nn.Linear(d_embedding, d_projection) self.projection2 = nn.Linear(d_projection, d_projection) - def forward(self, diffusion_step): - x = self.embedding[diffusion_step] - x = self.projection1(x) - x = F.silu(x) - x = self.projection2(x) - x = F.silu(x) - return x - @staticmethod def _build_embedding(n_steps, d_embedding=64): steps = torch.arange(n_steps).unsqueeze(1) # (T,1) @@ -58,6 +50,14 @@ def _build_embedding(n_steps, d_embedding=64): table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) # (T,dim*2) return table + def forward(self, diffusion_step: int): + x = self.embedding[diffusion_step] + x = self.projection1(x) + x = F.silu(x) + x = self.projection2(x) + x = F.silu(x) + return x + class ResidualBlock(nn.Module): def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads): @@ -73,7 +73,7 @@ def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads): ) def forward_time(self, y, base_shape): - B, channel, K, L = base_shape + B, channel, K, L = base_shape # bz, 2, n_features, n_steps if L == 1: return y y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L) @@ -82,7 +82,7 @@ def forward_time(self, y, base_shape): return y def forward_feature(self, y, base_shape): - B, channel, K, L = base_shape + B, channel, K, L = base_shape # bz, 2, n_features, n_steps if K == 1: return y y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K) @@ -98,8 +98,8 @@ def forward(self, x, cond_info, diffusion_emb): diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze( -1 ) # (B,channel,1) - y = x + diffusion_emb + y = x + diffusion_emb y = self.forward_time(y, base_shape) y = self.forward_feature(y, base_shape) # (B,channel,K*L) y = self.mid_projection(y) # (B,2*channel,K*L) @@ -155,12 +155,12 @@ def __init__( self.n_channels = n_channels def forward(self, x, cond_info, diffusion_step): - B, input_dim, K, L = x.shape + B, input_dim, K, L = x.shape # bz, 2, n_features, n_steps x = x.reshape(B, input_dim, K * L) - x = self.input_projection(x) + x = self.input_projection(x) # bz, n_channels, n_features*n_steps x = F.relu(x) - x = x.reshape(B, self.n_channels, K, L) + x = x.reshape(B, self.n_channels, K, L) # bz, n_channels, n_features, n_steps diffusion_emb = self.diffusion_embedding(diffusion_step) diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py index ff37ec55..3bfcd888 100644 --- a/tests/imputation/csdi.py +++ b/tests/imputation/csdi.py @@ -48,6 +48,7 @@ class TestCSDI(unittest.TestCase): d_time_embedding=32, d_feature_embedding=3, d_diffusion_embedding=32, + n_diffusion_steps=10, n_heads=1, epochs=EPOCHS, saving_path=saving_path,