75 changes: 58 additions & 17 deletions src/move/tasks/identify_associations.py
@@ -212,6 +212,17 @@ def _bayes_approach(
# Last appended dataloader is the baseline
baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

# Build the baseline continuous matrix once, at the start of _bayes_approach,
# right after getting baseline_dataset
if hasattr(baseline_dataset, "con_list"):
# Concatenate all continuous arrays into a single samples-by-features matrix
orig_con = np.concatenate([con.numpy() for con in baseline_dataset.con_list], axis=1)
else:
# Fall back to con_all when the continuous data are already concatenated
orig_con = baseline_dataset.con_all.numpy()


for j in range(task_config.num_refits):
# Initialize model
model: VAE = hydra.utils.instantiate(
@@ -262,43 +273,73 @@ def _bayes_approach(
[min_baseline, min_perturb], axis=0
), np.max([max_baseline, max_perturb], axis=0)



# Calculate Bayes factors
logger.info("Identifying significant features")
logger.info("Identifying significant features / Modified Algo")
logger.info("Setting manual sig-threshold")
task_config.sig_threshold = 0.3
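# NOTE: this hard-coded value overrides any sig_threshold set in the task configuration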
bayes_k = np.empty((num_perturbed, num_continuous))
bayes_mask = np.zeros(np.shape(bayes_k))
for i in range(num_perturbed):
# nan_mask excludes NA values from the Bayes factor calculation;
# feature_mask excludes the values of the perturbed feature itself
mask = feature_mask[:, [i]] | nan_mask # 2D: N x C

# Optional debug logging:
# logger.debug(f"Mask head (first 5 rows): {mask[:5]}")
diff = np.ma.masked_array(mean_diff[i, :, :], mask=mask) # 2D: N x C
prob = np.ma.compressed(np.mean(diff > 1e-8, axis=0)) # 1D: C
prob = np.mean(diff > 1e-8, axis=0).filled(np.nan) # 1D: C; fully masked features become NaN
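# Note: unlike np.ma.compressed, .filled(np.nan) keeps the feature axis at length C,
# so prob stays column-aligned with bayes_k; fully masked features become NaN here
# and are removed by the NaN filter further below.
# bayes_k[i, :] below is then the log-odds of a positive reconstruction shift under
# perturbation i; the 1e-8 offsets keep the logarithms finite when prob is 0 or 1.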
bayes_k[i, :] = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8)
if task_config.target_value in CONTINUOUS_TARGET_VALUE:
bayes_mask[i, :] = (
baseline_dataloader.dataset.con_all[0, :]
orig_con[0, :]
- dataloaders[i].dataset.con_all[0, :]
)

bayes_mask[bayes_mask != 0] = 1
bayes_mask = np.array(bayes_mask, dtype=bool)

# Calculate Bayes probabilities
bayes_abs = np.abs(bayes_k)
bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs)) # 2D: N x C
bayes_abs[bayes_mask] = np.min(
bayes_abs
) # Bring feature_i vs feature_i (self) associations to the minimum
sort_ids = np.argsort(bayes_abs, axis=None)[::-1] # 1D: N x C
prob = np.take(bayes_p, sort_ids) # 1D: N x C
logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")
# Filter out features that have zero unmasked values before sorting.
# Note: this reuses `mask` from the final loop iteration (the last perturbed i).
unmasked_counts = np.sum(~mask, axis=0)
valid_features = unmasked_counts > 0
bayes_k_valid = bayes_k[:, valid_features]
bayes_mask_valid = bayes_mask[:, valid_features]
bayes_abs_valid = np.abs(bayes_k_valid)
bayes_p_valid = np.exp(bayes_abs_valid) / (1 + np.exp(bayes_abs_valid))
bayes_abs_valid[bayes_mask_valid] = np.min(bayes_abs_valid)
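# As in the original path above, push perturbed-feature self-associations
# down to the minimum so they fall to the bottom of the ranking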

# Filter out feature columns whose probability is NaN for any perturbed feature
if bayes_p_valid.ndim == 2:
not_nan_mask = ~np.isnan(bayes_p_valid).any(axis=0)
else:
not_nan_mask = ~np.isnan(bayes_p_valid)
bayes_k_valid = bayes_k_valid[:, not_nan_mask]
bayes_mask_valid = bayes_mask_valid[:, not_nan_mask]
bayes_abs_valid = bayes_abs_valid[:, not_nan_mask]
bayes_p_valid = bayes_p_valid[:, not_nan_mask]

sort_ids_valid = np.argsort(bayes_abs_valid, axis=None)[::-1]
prob_valid = np.take(bayes_p_valid, sort_ids_valid)
logger.debug(f"Bayes proba range: [{prob_valid[-1]:.3f} {prob_valid[0]:.3f}]")

# Sort Bayes
bayes_k = np.take(bayes_k, sort_ids) # 1D: N x C
bayes_k_valid_sorted = np.take(bayes_k_valid, sort_ids_valid) # 1D: valid features only

# Calculate FDR
fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1) # 1D
idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")
fdr_valid = np.cumsum(1 - prob_valid) / np.arange(1, prob_valid.size + 1) # 1D
idx_valid = np.argmin(np.abs(fdr_valid - task_config.sig_threshold))
logger.debug(f"FDR range: [{fdr_valid[0]:.3f} {fdr_valid[-1]:.3f}]")

return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx]
return (
sort_ids_valid[:idx_valid],
prob_valid[:idx_valid],
fdr_valid[:idx_valid],
bayes_k_valid_sorted[:idx_valid],
)
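For reference, the ranking-and-threshold step above reduces to: map each |Bayes factor| to a probability with a sigmoid, sort by decreasing |Bayes factor|, and accumulate (1 - probability) to estimate the FDR at every cutoff. A minimal, self-contained sketch with toy values (names such as toy_prob are illustrative assumptions, not identifiers from this module):

import numpy as np

# Toy effect probabilities for six hypothetical feature pairs (made-up values)
toy_prob = np.array([0.99, 0.95, 0.90, 0.60, 0.55, 0.50])

# Log-odds (Bayes factors) and the probabilities implied by their absolute value,
# mirroring the formulas used in _bayes_approach
toy_bayes_abs = np.abs(np.log(toy_prob + 1e-8) - np.log(1 - toy_prob + 1e-8))
toy_bayes_p = np.exp(toy_bayes_abs) / (1 + np.exp(toy_bayes_abs))

# Rank by decreasing |Bayes factor| and estimate the FDR at each cutoff
order = np.argsort(toy_bayes_abs)[::-1]
sorted_p = toy_bayes_p[order]
fdr = np.cumsum(1 - sorted_p) / np.arange(1, sorted_p.size + 1)

# Keep associations up to the index where the FDR is closest to a 0.3 threshold
idx = np.argmin(np.abs(fdr - 0.3))
print(order[:idx], np.round(fdr, 3))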


def _ttest_approach(