Skip to content

Commit 545b5ee

Browse files
committed
fix: return latent vars if not in the training stage;
1 parent f80467d commit 545b5ee

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

pypots/imputation/saits/model.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -347,21 +347,18 @@ def predict(
347347
results = self.model.forward(
348348
inputs, diagonal_attention_mask, training=False
349349
)
350-
imputed_data = results["imputed_data"]
351-
imputation_collector.append(imputed_data)
350+
imputation_collector.append(results["imputed_data"])
352351

353352
if return_latent_vars:
354-
first_DMSA_attn_weights = (
353+
first_DMSA_attn_weights_collector.append(
355354
results["first_DMSA_attn_weights"].cpu().numpy()
356355
)
357-
second_DMSA_attn_weights = (
356+
second_DMSA_attn_weights_collector.append(
358357
results["second_DMSA_attn_weights"].cpu().numpy()
359358
)
360-
combining_weights = results["combining_weights"].cpu().numpy()
361-
362-
first_DMSA_attn_weights_collector.append(first_DMSA_attn_weights)
363-
second_DMSA_attn_weights_collector.append(second_DMSA_attn_weights)
364-
combining_weights_collector.append(combining_weights)
359+
combining_weights_collector.append(
360+
results["combining_weights"].cpu().numpy()
361+
)
365362

366363
# Step 3: output collection and return
367364
imputation = torch.cat(imputation_collector).cpu().detach().numpy()

pypots/imputation/saits/modules/core.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,15 @@ def forward(
174174
[first_DMSA_attn_weights, second_DMSA_attn_weights, combining_weights],
175175
) = self._process(inputs, diagonal_attention_mask)
176176

177+
results = {
178+
"first_DMSA_attn_weights": first_DMSA_attn_weights,
179+
"second_DMSA_attn_weights": second_DMSA_attn_weights,
180+
"combining_weights": combining_weights,
181+
"imputed_data": imputed_data,
182+
}
177183
if not training:
178184
# if not in training mode, return the classification result only
179-
return {
180-
"imputed_data": imputed_data,
181-
}
185+
return results
182186

183187
ORT_loss = 0
184188
ORT_loss += self.customized_loss_func(X_tilde_1, X, masks)
@@ -193,13 +197,10 @@ def forward(
193197
# `loss` is always the item for backward propagating to update the model
194198
loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
195199

196-
results = {
197-
"first_DMSA_attn_weights": first_DMSA_attn_weights,
198-
"second_DMSA_attn_weights": second_DMSA_attn_weights,
199-
"combining_weights": combining_weights,
200-
"imputed_data": imputed_data,
201-
"ORT_loss": ORT_loss,
202-
"MIT_loss": MIT_loss,
203-
"loss": loss, # will be used for backward propagating to update the model
204-
}
200+
results["ORT_loss"] = ORT_loss
201+
results["MIT_loss"] = MIT_loss
202+
203+
# will be used for backward propagating to update the model
204+
results["loss"] = loss
205+
205206
return results

0 commit comments

Comments
 (0)