885 changes: 885 additions & 0 deletions notebooks/denoising_demo.ipynb

Large diffs are not rendered by default.

199 changes: 199 additions & 0 deletions src/flat_mae/models_mae.py
@@ -25,6 +25,7 @@
from huggingface_hub import PyTorchModelHubMixin
from jaxtyping import Float, Int
from timm.layers import to_2tuple, to_ntuple
from einops import rearrange

from .modules import (
Block,
@@ -755,6 +756,204 @@ def forward_embedding(
):
return self.encoder.forward_embedding(x, mask, mask_ratio)

def forward_denoise(
self,
imgs,
mask_ratio=0.75,
        n_samples=3,
img_mask=None,
generator=None,
):
"""
Explicit denoising via multiple masked reconstructions.

runs mae n_samples + 1 times with partitioned patch masks, then aggregates
reconstructions. each patch is reconstructed either n_samples or n_samples + 1 times due to remainder handling.
structured signal will be consistent across runs, while unstructured noise will be inconsistent.

args:
imgs: input fmri volume [N, C, T, H, W]
mask_ratio: fraction of patches to mask per run (1-mask_ratio are visible)
n_samples: number of masked reconstructions to run
img_mask: optional binary mask of valid voxels
generator: torch.generator for reproducibility

returns:
denoised_imgs: [N, C, T, H, W] denoised fmri volume
"""
# convert input to model's dtype to avoid dtype mismatch
model_dtype = next(self.parameters()).dtype
imgs = imgs.to(dtype=model_dtype)
if img_mask is not None:
img_mask = img_mask.to(dtype=model_dtype)

N, C, T, H, W = imgs.shape
        # if n_samples is too high, each run will not have enough visible
        # patches to satisfy mask_ratio.
        assert 1 / (n_samples + 1) >= (1 - mask_ratio), (
            "n_samples too large for mask_ratio"
        )
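        # e.g. with mask_ratio=0.75 a quarter of the patches must be visible
        # per run, so at most n_samples=3 (each of the 4 runs shows 1/4 of the
        # patches); n_samples=4 would leave only 1/5 visible and trip the
        # assert above.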

# get patch dimensions (extract 2D spatial parts, handling both 2D and 3D cases)
grid_size = self.encoder.patchify.grid_size
patch_size = self.encoder.patchify.patch_size
# for 3D: grid_size is (t, h, w), patch_size is (p_t, p_h, p_w)
# for 2D: grid_size is (h, w), patch_size is (p_h, p_w)
# we only need spatial dimensions (last 2)
h, w = grid_size[-2:]
ph, pw = patch_size[-2:]

        # we assume a single img_mask shared across the batch for simplicity.
        if img_mask is not None:
            # img_mask should be [H, W] or [1, 1, H, W] or similar
            img_mask_squeezed = img_mask.squeeze()
            assert img_mask_squeezed.shape == (H, W), (
                "denoising requires a single spatial [H, W] mask shared across the batch"
            )
            # patchify the spatial mask directly with rearrange; this works on
            # 2D spatial patches regardless of the temporal dimension
            patch_mask = rearrange(
                img_mask_squeezed, "(h p) (w q) -> (h w) (p q)", h=h, p=ph, w=w, q=pw
            )
            patch_mask = patch_mask.sum(dim=-1).clip(max=1)
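            # e.g. H = W = 64 with 8x8 patches gives h = w = 8, so patch_mask
            # is [64 patches, 64 voxels]; the sum + clip marks a patch as
            # valid if any of its voxels is valid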
img_mask = img_mask.expand_as(imgs)
else:
patch_mask = torch.ones(h * w, device=imgs.device)

# generate random patch permutation (using generator for reproducibility)
valid_ids = patch_mask.flatten().nonzero().flatten()
patch_permutation = valid_ids[
torch.randperm(len(valid_ids), generator=generator, device=valid_ids.device)
]

        # targets do not depend on the run, so prepare them once up front
        targets_patches, targets_stats = self.prepare_targets(imgs, img_mask)

        # run n_samples + 1 masked reconstructions
        reconstructions = []
        decoder_masks = []

        for run_idx in range(n_samples + 1):
            # create the visible mask for this run from its partition group
            run_visible_mask = self._generate_visible_mask(
                imgs,
                run_idx,
                n_samples + 1,
                patch_permutation,
            )

# encode with visible mask
cls_embeds, reg_embeds, patch_embeds, visible_mask_out, visible_ids = self.encoder(
imgs, mask=run_visible_mask, mask_ratio=None # use provided mask, no random masking
)

# prepare prediction mask (predict non-visible patches)
pred_mask_patches, pred_ids = self.prepare_pred_mask(
visible_mask_out,
pred_mask=None,
pred_mask_ratio=None, # predict all non-visible patches
)

# decode
preds = self.forward_decoder(patch_embeds, reg_embeds, visible_ids, pred_ids)

# unnormalize predictions
if targets_stats is not None:
targets_mean, targets_std = targets_stats
# gather stats for predicted patches
pred_mean = targets_mean.gather(1, pred_ids.unsqueeze(-1).expand(-1, -1, preds.shape[-1]))
pred_std = targets_std.gather(1, pred_ids.unsqueeze(-1).expand(-1, -1, preds.shape[-1]))
pred_unnorm = preds * pred_std + pred_mean
else:
pred_unnorm = preds

# expand predictions to full patch space for aggregation
# pred_unnorm is [N, Q, P], need to expand to [N, N_patches, P]
N_patches = self.pred_patchify.num_patches
P = pred_unnorm.shape[-1]
pred_expanded = torch.zeros((N, N_patches, P), dtype=pred_unnorm.dtype, device=pred_unnorm.device)
pred_expanded.scatter_(1, pred_ids.unsqueeze(-1).expand(-1, -1, P), pred_unnorm)

            # create a decoder mask tracking which patches were predicted,
            # reusing the model dtype computed above
            decoder_mask = torch.zeros((N, N_patches), device=pred_ids.device, dtype=model_dtype)
            decoder_mask.scatter_(1, pred_ids, torch.ones_like(pred_ids, dtype=model_dtype))

reconstructions.append(pred_expanded)
decoder_masks.append(decoder_mask)

# aggregate reconstructions using proper weighted mean
# each patch may be reconstructed a different number of times
stacked_reconstructions = torch.stack(reconstructions) # [n_runs, N, L, D]
stacked_decoder_masks = torch.stack(decoder_masks) # [n_runs, N, L]

# compute weighted mean accounting for how many times each patch was reconstructed
# decoder_mask: 1 = reconstructed, 0 = not reconstructed
weights = stacked_decoder_masks.unsqueeze(-1) # [n_runs, N, L, 1]

# sum of reconstructions and sum of weights for each patch
weighted_sum = (stacked_reconstructions * weights).sum(dim=0) # [N, L, D]
weight_sum = weights.sum(dim=0) # [N, L, 1]

final_patches = weighted_sum / weight_sum.clamp(min=1.0)

# convert back to image space using unpatchify
# note: t_pred_stride > 1 is automatically handled by StridedPatchify3D.unpatchify()
# which repeats temporal frames to fill all time steps
denoised_imgs = self.pred_patchify.unpatchify(final_patches)

return denoised_imgs
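    # Usage sketch (illustrative names, not part of this module): given a
    # trained model `mae` and an fMRI batch `imgs` of shape [N, C, T, H, W],
    # with a generator created on the same device as `imgs`:
    #
    #     gen = torch.Generator(device=imgs.device).manual_seed(0)
    #     with torch.no_grad():
    #         denoised = mae.forward_denoise(imgs, n_samples=3, generator=gen)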

def _generate_visible_mask(
self,
imgs,
run_idx,
total_runs,
patch_permutation,
):
"""
Partition-based visible mask with correct temporal handling.
- Randomly partition patch_permutation into total_runs.
- Select run_idx group as visible.
- Expand 2D patch-level mask to voxel space using unpatchify without manual repeats.
"""
N, C, T, H, W = imgs.shape

# Split shuffled permutation into groups
groups = torch.tensor_split(patch_permutation, total_runs)
visible_patch_indices = groups[run_idx]
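        # e.g. with 10 valid patches and total_runs=4, tensor_split yields
        # groups of sizes (3, 3, 2, 2); every patch lands in exactly one
        # group, so it is visible in exactly one run and predicted in the rest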

# Use patch dimensions from the model configuration
# extract 2D spatial parts (handling both 2D and 3D cases)
patch_size = self.encoder.patchify.patch_size
grid_size = self.encoder.patchify.grid_size
# for 3D: patch_size is (p_t, p_h, p_w), grid_size is (t, h, w)
# for 2D: patch_size is (p_h, p_w), grid_size is (h, w)
# we only need spatial dimensions (last 2 elements work for both cases)
ph, pw = patch_size[-2:] # spatial patch size
h, w = grid_size[-2:] # spatial grid size
# for temporal: if 3D, use first dim; if 2D, assume single frame
u = patch_size[0] if len(patch_size) == 3 else 1
t = grid_size[0] if len(grid_size) == 3 else 1

# Create patch_mask in patch space, matching model's parameter dtype
# use the dtype of the first model parameter to ensure compatibility
model_dtype = next(self.parameters()).dtype
patch_mask = torch.zeros((N, h * w), device=imgs.device, dtype=model_dtype)
patch_mask[:, visible_patch_indices] = 1.0
patch_mask = patch_mask.unsqueeze(1).expand(-1, t, -1).flatten(1)

# Expand to include voxel content: (N, t*h*w, u*ph*pw*C)
patch_mask_expanded = patch_mask.unsqueeze(-1).expand(-1, -1, u * ph * pw * C)

# use encoder's unpatchify to convert patch mask back to image space
# patch ordering matches patchify3d: (t h w) flattened to [t*h*w]
visible_mask = self.encoder.patchify.unpatchify(patch_mask_expanded)

# mask is already in model_dtype from patch_mask creation, just return it
return visible_mask


class MaskedViT(MaskedEncoder, PyTorchModelHubMixin):
def __init__(