12
12
13
13
14
14
import os
15
- from typing import Tuple , Union , Optional
15
+ from typing import Union , Optional
16
16
17
17
import numpy as np
18
18
import torch
@@ -443,6 +443,15 @@ def predict(
443
443
inputs = self ._assemble_input_for_testing (data )
444
444
results = self .model .forward (inputs , training = False )
445
445
446
+ mu_tilde = results ["mu_tilde" ].cpu ().numpy ()
447
+ mu_tilde_collector .append (mu_tilde )
448
+ mu = results ["mu" ].cpu ().numpy ()
449
+ mu_collector .append (mu )
450
+ var = results ["var" ].cpu ().numpy ()
451
+ var_collector .append (var )
452
+ phi = results ["phi" ].cpu ().numpy ()
453
+ phi_collector .append (phi )
454
+
446
455
def func_to_apply (
447
456
mu_t_ : np .ndarray ,
448
457
mu_ : np .ndarray ,
@@ -465,16 +474,8 @@ def func_to_apply(
465
474
clustering_results_collector .append (clustering_results )
466
475
467
476
if return_latent_vars :
468
- mu_tilde = results ["mu_tilde" ].cpu ().numpy ()
469
- mu_tilde_collector .append (mu_tilde )
470
477
stddev_tilde = results ["stddev_tilde" ].cpu ().numpy ()
471
478
stddev_tilde_collector .append (stddev_tilde )
472
- mu = results ["mu" ].cpu ().numpy ()
473
- mu_collector .append (mu )
474
- var = results ["var" ].cpu ().numpy ()
475
- var_collector .append (var )
476
- phi = results ["phi" ].cpu ().numpy ()
477
- phi_collector .append (phi )
478
479
z = results ["z" ].cpu ().numpy ()
479
480
z_collector .append (z )
480
481
imputation_latent = results ["imputation_latent" ].cpu ().numpy ()
0 commit comments