Merge pull request #256 from WenjieDu/dev
Fixing CSDI `gt_mask` issue, and setting a fixed random seed for testing cases
WenjieDu authored Dec 5, 2023
2 parents c597bb5 + 2542a0c commit 5c0ce3e
Showing 6 changed files with 26 additions and 20 deletions.
6 changes: 3 additions & 3 deletions pypots/imputation/csdi/data.py
@@ -69,14 +69,14 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
         X = self.X[idx].to(torch.float32)
         X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate)

-        observed_data = X_intact
-        observed_mask = missing_mask + indicating_mask
+        observed_data = X_intact  # i.e. originally observed data
+        observed_mask = missing_mask + indicating_mask  # i.e. originally missing masks
         observed_tp = (
             torch.arange(0, self.n_steps, dtype=torch.float32)
             if self.time_points is None
             else self.time_points[idx].to(torch.float32)
         )
-        gt_mask = indicating_mask
+        gt_mask = missing_mask  # missing mask with ground truth masked for validation
         for_pattern_mask = (
             gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
         )
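The `gt_mask` fix above changes which values CSDI may condition on. A minimal sketch (not part of this commit) of the mask algebra, assuming the four-tuple `mcar` signature shown in the hunk; the tensor values are made up for illustration:

import torch

# 1 = value present, 0 = value absent
missing_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])     # observed after artificial masking
indicating_mask = torch.tensor([0.0, 0.0, 0.0, 1.0])  # artificially masked for validation

observed_mask = missing_mask + indicating_mask  # everything originally observed
gt_mask = missing_mask                          # what the model may condition on

# The held-out validation targets are exactly the indicated positions;
# the old `gt_mask = indicating_mask` had the two roles swapped.
assert torch.equal(observed_mask - gt_mask, indicating_mask)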
4 changes: 2 additions & 2 deletions pypots/imputation/csdi/model.py
@@ -128,9 +128,9 @@ def __init__(
         d_time_embedding: int,
         d_feature_embedding: int,
         d_diffusion_embedding: int,
-        is_unconditional: bool = False,
-        target_strategy: str = "random",
         n_diffusion_steps: int = 50,
+        target_strategy: str = "random",
+        is_unconditional: bool = False,
         schedule: str = "quad",
         beta_start: float = 0.0001,
         beta_end: float = 0.5,
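The hunk above only reorders parameters. A tiny illustration (a stand-in function, not from the repo; keyword-only here just for emphasis) of why callers that pass keywords, like the test further below, are unaffected:

def make_model(*, n_diffusion_steps=50, target_strategy="random", is_unconditional=False):
    # stand-in for the reordered __init__ above; with named arguments the order is cosmetic
    return n_diffusion_steps, target_strategy, is_unconditional

assert make_model(is_unconditional=False, target_strategy="random", n_diffusion_steps=50) == make_model(
    n_diffusion_steps=50, target_strategy="random", is_unconditional=False
)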
4 changes: 3 additions & 1 deletion pypots/imputation/csdi/modules/core.py
@@ -127,7 +127,9 @@ def get_side_info(self, observed_tp, cond_mask):
         )  # (K,emb)
         feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)

-        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
+        side_info = torch.cat(
+            [time_embed, feature_embed], dim=-1
+        )  # (B,L,K,emb+d_feature_embedding)
         side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)

         if not self.is_unconditional:
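A standalone shape check (not from the repo) for the concat/permute above; B, L, K and the embedding widths are arbitrary example values:

import torch

B, L, K = 2, 24, 5         # batch size, n_steps, n_features
d_time, d_feature = 32, 3  # example embedding widths

time_embed = torch.randn(B, L, K, d_time)
feature_embed = torch.randn(B, L, K, d_feature)

side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,d_time+d_feature)
side_info = side_info.permute(0, 3, 2, 1)                   # (B,d_time+d_feature,K,L)
assert side_info.shape == (B, d_time + d_feature, K, L)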
28 changes: 14 additions & 14 deletions pypots/imputation/csdi/modules/submodules.py
@@ -38,14 +38,6 @@ def __init__(self, n_diffusion_steps, d_embedding=128, d_projection=None):
         self.projection1 = nn.Linear(d_embedding, d_projection)
         self.projection2 = nn.Linear(d_projection, d_projection)

-    def forward(self, diffusion_step):
-        x = self.embedding[diffusion_step]
-        x = self.projection1(x)
-        x = F.silu(x)
-        x = self.projection2(x)
-        x = F.silu(x)
-        return x
-
     @staticmethod
     def _build_embedding(n_steps, d_embedding=64):
         steps = torch.arange(n_steps).unsqueeze(1)  # (T,1)
@@ -58,6 +50,14 @@ def _build_embedding(n_steps, d_embedding=64):
         table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T,dim*2)
         return table

+    def forward(self, diffusion_step: int):
+        x = self.embedding[diffusion_step]
+        x = self.projection1(x)
+        x = F.silu(x)
+        x = self.projection2(x)
+        x = F.silu(x)
+        return x
+

 class ResidualBlock(nn.Module):
     def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads):
@@ -73,7 +73,7 @@ def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads):
         )

     def forward_time(self, y, base_shape):
-        B, channel, K, L = base_shape
+        B, channel, K, L = base_shape  # bz, 2, n_features, n_steps
         if L == 1:
             return y
         y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)
@@ -82,7 +82,7 @@ def forward_time(self, y, base_shape):
         return y

     def forward_feature(self, y, base_shape):
-        B, channel, K, L = base_shape
+        B, channel, K, L = base_shape  # bz, 2, n_features, n_steps
         if K == 1:
             return y
         y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)
@@ -98,8 +98,8 @@ def forward(self, x, cond_info, diffusion_emb):
         diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(
             -1
         )  # (B,channel,1)
-        y = x + diffusion_emb

+        y = x + diffusion_emb
         y = self.forward_time(y, base_shape)
         y = self.forward_feature(y, base_shape)  # (B,channel,K*L)
         y = self.mid_projection(y)  # (B,2*channel,K*L)
@@ -155,12 +155,12 @@ def __init__(
         self.n_channels = n_channels

     def forward(self, x, cond_info, diffusion_step):
-        B, input_dim, K, L = x.shape
+        B, input_dim, K, L = x.shape  # bz, 2, n_features, n_steps

         x = x.reshape(B, input_dim, K * L)
-        x = self.input_projection(x)
+        x = self.input_projection(x)  # bz, n_channels, n_features*n_steps
         x = F.relu(x)
-        x = x.reshape(B, self.n_channels, K, L)
+        x = x.reshape(B, self.n_channels, K, L)  # bz, n_channels, n_features, n_steps

         diffusion_emb = self.diffusion_embedding(diffusion_step)

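For reference, a self-contained sketch of the sinusoidal table that `_build_embedding` above produces. The geometric frequency spacing is an assumption taken from the original CSDI reference implementation, since the hunk only shows the first and last lines of the method:

import torch

def build_embedding(n_steps: int, d_embedding: int = 64) -> torch.Tensor:
    steps = torch.arange(n_steps).unsqueeze(1)  # (T,1)
    # assumed frequency spacing: 10^(4*j/(d_embedding-1)), j = 0..d_embedding-1
    frequencies = 10.0 ** (torch.arange(d_embedding) / (d_embedding - 1) * 4.0).unsqueeze(0)
    table = steps * frequencies  # (T,d_embedding)
    return torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T,dim*2)

print(build_embedding(50).shape)  # torch.Size([50, 128]); row t embeds diffusion step t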
3 changes: 3 additions & 0 deletions tests/global_test_config.py
@@ -11,6 +11,9 @@

 from pypots.data.generating import gene_random_walk
 from pypots.utils.logging import logger
+from pypots.utils.random import set_random_seed
+
+set_random_seed(2023)

 # Generate the unified data for testing and cache it first, DATA here is a singleton
 # Otherwise, file lock will cause bug if running test parallely with pytest-xdist.
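Seeding at import time makes every test module that pulls in this config deterministic. A minimal sketch of the idea in plain torch (`set_random_seed` itself is pypots' own helper and may seed additional generators):

import torch

torch.manual_seed(2023)
a = torch.randn(3)
torch.manual_seed(2023)
b = torch.randn(3)
assert torch.equal(a, b)  # identical draws -> reproducible test data and results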
1 change: 1 addition & 0 deletions tests/imputation/csdi.py
@@ -48,6 +48,7 @@ class TestCSDI(unittest.TestCase):
             d_time_embedding=32,
             d_feature_embedding=3,
             d_diffusion_embedding=32,
+            n_diffusion_steps=10,
             n_heads=1,
             epochs=EPOCHS,
             saving_path=saving_path,
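`n_diffusion_steps=10` shrinks the diffusion (and thus sampling) loop for the test. A rough sketch of the saving, assuming the quadratic beta schedule implied by the model defaults above (`schedule="quad"`, `beta_start=0.0001`, `beta_end=0.5`):

import numpy as np

def quad_betas(n_steps: int, beta_start: float = 0.0001, beta_end: float = 0.5) -> np.ndarray:
    # assumed "quad" schedule: linear in sqrt(beta), then squared
    return np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_steps) ** 2

# Sampling iterates once per step, so 10 steps cost roughly 1/5 of the default 50.
print(len(quad_betas(10)), len(quad_betas(50)))  # 10 50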
