diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py
index a983c5a4..b2798ca0 100644
--- a/pypots/imputation/csdi/data.py
+++ b/pypots/imputation/csdi/data.py
@@ -69,14 +69,14 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
         X = self.X[idx].to(torch.float32)
         X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate)
 
-        observed_data = X_intact
-        observed_mask = missing_mask + indicating_mask
+        observed_data = X_intact  # i.e. originally observed data
+        observed_mask = missing_mask + indicating_mask  # i.e. originally missing masks
         observed_tp = (
             torch.arange(0, self.n_steps, dtype=torch.float32)
             if self.time_points is None
             else self.time_points[idx].to(torch.float32)
         )
-        gt_mask = indicating_mask
+        gt_mask = missing_mask  # missing mask with ground truth masked for validation
         for_pattern_mask = (
             gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
         )
diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py
index 67764c55..70bf7734 100644
--- a/pypots/imputation/csdi/model.py
+++ b/pypots/imputation/csdi/model.py
@@ -128,9 +128,9 @@ def __init__(
         d_time_embedding: int,
         d_feature_embedding: int,
         d_diffusion_embedding: int,
-        is_unconditional: bool = False,
-        target_strategy: str = "random",
         n_diffusion_steps: int = 50,
+        target_strategy: str = "random",
+        is_unconditional: bool = False,
         schedule: str = "quad",
         beta_start: float = 0.0001,
         beta_end: float = 0.5,
diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py
index b25fd190..958fb65d 100644
--- a/pypots/imputation/csdi/modules/core.py
+++ b/pypots/imputation/csdi/modules/core.py
@@ -127,7 +127,9 @@ def get_side_info(self, observed_tp, cond_mask):
         )  # (K,emb)
         feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)
 
-        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
+        side_info = torch.cat(
+            [time_embed, feature_embed], dim=-1
+        )  # (B,L,K,emb+d_feature_embedding)
         side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)
 
         if not self.is_unconditional:
diff --git a/pypots/imputation/csdi/modules/submodules.py b/pypots/imputation/csdi/modules/submodules.py
index 31e71fff..68248d7c 100644
--- a/pypots/imputation/csdi/modules/submodules.py
+++ b/pypots/imputation/csdi/modules/submodules.py
@@ -38,14 +38,6 @@ def __init__(self, n_diffusion_steps, d_embedding=128, d_projection=None):
         self.projection1 = nn.Linear(d_embedding, d_projection)
         self.projection2 = nn.Linear(d_projection, d_projection)
 
-    def forward(self, diffusion_step):
-        x = self.embedding[diffusion_step]
-        x = self.projection1(x)
-        x = F.silu(x)
-        x = self.projection2(x)
-        x = F.silu(x)
-        return x
-
     @staticmethod
     def _build_embedding(n_steps, d_embedding=64):
         steps = torch.arange(n_steps).unsqueeze(1)  # (T,1)
@@ -58,6 +50,14 @@ def _build_embedding(n_steps, d_embedding=64):
         table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T,dim*2)
         return table
 
+    def forward(self, diffusion_step: int):
+        x = self.embedding[diffusion_step]
+        x = self.projection1(x)
+        x = F.silu(x)
+        x = self.projection2(x)
+        x = F.silu(x)
+        return x
+
 
 class ResidualBlock(nn.Module):
     def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads):
@@ -73,7 +73,7 @@ def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads):
         )
 
     def forward_time(self, y, base_shape):
-        B, channel, K, L = base_shape
+        B, channel, K, L = base_shape  # bz, 2, n_features, n_steps
         if L == 1:
             return y
         y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)
@@ -82,7 +82,7 @@ def forward_time(self, y, base_shape):
         return y
 
     def forward_feature(self, y, base_shape):
-        B, channel, K, L = base_shape
+        B, channel, K, L = base_shape  # bz, 2, n_features, n_steps
         if K == 1:
             return y
         y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)
@@ -98,8 +98,8 @@ def forward(self, x, cond_info, diffusion_emb):
         diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(
             -1
         )  # (B,channel,1)
-        y = x + diffusion_emb
 
+        y = x + diffusion_emb
         y = self.forward_time(y, base_shape)
         y = self.forward_feature(y, base_shape)  # (B,channel,K*L)
         y = self.mid_projection(y)  # (B,2*channel,K*L)
@@ -155,12 +155,12 @@ def __init__(
         self.n_channels = n_channels
 
     def forward(self, x, cond_info, diffusion_step):
-        B, input_dim, K, L = x.shape
+        B, input_dim, K, L = x.shape  # bz, 2, n_features, n_steps
 
         x = x.reshape(B, input_dim, K * L)
-        x = self.input_projection(x)
+        x = self.input_projection(x)  # bz, n_channels, n_features*n_steps
         x = F.relu(x)
-        x = x.reshape(B, self.n_channels, K, L)
+        x = x.reshape(B, self.n_channels, K, L)  # bz, n_channels, n_features, n_steps
 
         diffusion_emb = self.diffusion_embedding(diffusion_step)
 
diff --git a/tests/global_test_config.py b/tests/global_test_config.py
index c0d8afdb..3258ab6d 100644
--- a/tests/global_test_config.py
+++ b/tests/global_test_config.py
@@ -11,6 +11,9 @@
 
 from pypots.data.generating import gene_random_walk
 from pypots.utils.logging import logger
+from pypots.utils.random import set_random_seed
+
+set_random_seed(2023)
 
 # Generate the unified data for testing and cache it first, DATA here is a singleton
 # Otherwise, file lock will cause bug if running test parallely with pytest-xdist.
diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py
index ff37ec55..3bfcd888 100644
--- a/tests/imputation/csdi.py
+++ b/tests/imputation/csdi.py
@@ -48,6 +48,7 @@ class TestCSDI(unittest.TestCase):
         d_time_embedding=32,
         d_feature_embedding=3,
         d_diffusion_embedding=32,
+        n_diffusion_steps=10,
         n_heads=1,
         epochs=EPOCHS,
         saving_path=saving_path,