Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Factor cost matrix in dense case #50

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def compute_all_ot_plans(
mapping.fit(
source_features=features,
target_features=barycenter_features,
source_geometry=G,
target_geometry=barycenter_geometry,
source_geometry_embedding=G,
target_geometry_embedding=barycenter_geometry,
source_weights=weights,
target_weights=barycenter_weights,
init_plan=plans[i] if plans is not None else None,
Expand Down Expand Up @@ -196,9 +196,9 @@ def fit(
can have weights with different sizes.
features_list (list of np.array): List of features. Individuals should
have the same number of features n_features.
geometry_list (list of np.array or np.array): List of kernel matrices
or just one kernel matrix if it's shared across individuals
barycenter_size (int, optional): Size of computed
geometry_list (list of np.array or np.array): List of geometry
embeddings or just one embedding if it's shared across individuals.
barycenter_size (int, optional): Size of computed
barycentric features and geometry. Defaults to None.
init_barycenter_weights (np.array, optional): Distribution weights
of barycentric points. If None, points will have uniform
Expand Down
102 changes: 67 additions & 35 deletions src/fugw/mappings/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fugw.solvers.dense import FUGWSolver
from fugw.mappings.utils import BaseMapping, console
from fugw.utils import _make_tensor, init_plan_dense
from fugw.utils import _make_tensor, init_plan_dense, _low_rank_squared_l2


class FUGW(BaseMapping):
Expand All @@ -13,12 +13,12 @@ def fit(
self,
source_features=None,
target_features=None,
source_geometry=None,
target_geometry=None,
source_geometry_embedding=None,
target_geometry_embedding=None,
source_features_val=None,
target_features_val=None,
source_geometry_val=None,
target_geometry_val=None,
source_geometry_embedding_val=None,
target_geometry_embedding_val=None,
source_weights=None,
target_weights=None,
init_plan=None,
Expand Down Expand Up @@ -52,14 +52,16 @@ def fit(
Feature maps for target subject.
**This array should be normalized**, otherwise you will
run into computational errors.
source_geometry: ndarray(n, n)
Kernel matrix of anatomical distances
between nodes of source mesh
source_geometry_embedding: ndarray(n, k), optional
Embedding X such that norm(X_i - X_j) approximates
the anatomical distance between vertices i and j
of the source mesh
**This array should be normalized**, otherwise you will
run into computational errors.
target_geometry: ndarray(m, m)
Kernel matrix of anatomical distances
between nodes of target mesh
target_geometry_embedding: ndarray(m, k), optional
Embedding X such that norm(X_i - X_j) approximates
the anatomical distance between vertices i and j
of the target mesh
**This array should be normalized**, otherwise you will
run into computational errors.
source_features_val: ndarray(n_features, n) or None
Expand All @@ -68,11 +70,11 @@ def fit(
target_features_val: ndarray(n_features, m) or None
Feature maps for target subject used for validation.
If None, target_features will be used instead.
source_geometry_val: ndarray(n, n) or None
source_geometry_embedding_val: ndarray(n, n) or None
Kernel matrix of anatomical distances
between nodes of source mesh used for validation.
If None, source_geometry will be used instead.
target_geometry_val: ndarray(m, m) or None
target_geometry_embedding_val: ndarray(m, m) or None
Kernel matrix of anatomical distances
between nodes of target mesh used for validation.
If None, target_geometry will be used instead.
Expand Down Expand Up @@ -170,17 +172,30 @@ def fit(
# Compute distance matrix between features
Fs = _make_tensor(source_features.T, device=device)
Ft = _make_tensor(target_features.T, device=device)
F = torch.cdist(Fs, Ft, p=2) ** 2
F1, F2 = _low_rank_squared_l2(Fs, Ft)
F1 = _make_tensor(F1, device=device)
F2 = _make_tensor(F2, device=device)

# Load anatomical kernels to GPU
Ds = _make_tensor(source_geometry, device=device)
Dt = _make_tensor(target_geometry, device=device)
Ds1, Ds2 = _low_rank_squared_l2(
source_geometry_embedding, source_geometry_embedding
)
Ds1 = _make_tensor(Ds1, device=device)
Ds2 = _make_tensor(Ds2, device=device)
Dt1, Dt2 = _low_rank_squared_l2(
target_geometry_embedding, target_geometry_embedding
)
Dt1 = _make_tensor(Dt1, device=device)
Dt2 = _make_tensor(Dt2, device=device)

# Do the same for validation data if it was provided
if source_features_val is not None and target_features_val is not None:
Fs_val = _make_tensor(source_features_val.T, device=device)
Ft_val = _make_tensor(target_features_val.T, device=device)
F_val = torch.cdist(Fs_val, Ft_val, p=2) ** 2
Fs_val = _make_tensor(source_features.T, device=device)
Ft_val = _make_tensor(target_features.T, device=device)
F1_val, F2_val = _low_rank_squared_l2(Fs_val, Ft_val)
F1_val = _make_tensor(F1_val, device=device)
F2_val = _make_tensor(F2_val, device=device)

elif source_features_val is not None and target_features_val is None:
raise ValueError(
Expand All @@ -195,7 +210,7 @@ def fit(
)

else:
F_val = None
F1_val, F2_val = None, None

# Raise warning if validation feature maps are not provided
if verbose:
Expand All @@ -204,27 +219,44 @@ def fit(
" Using training data instead."
)

if source_geometry_val is not None and target_geometry_val is not None:
Ds_val = _make_tensor(source_geometry_val, device=device)
Dt_val = _make_tensor(target_geometry_val, device=device)
if (
source_geometry_embedding_val is not None
and target_geometry_embedding_val is not None
):
Ds1_val, Ds2_val = _low_rank_squared_l2(
source_geometry_embedding_val, source_geometry_embedding_val
)
Ds1_val = _make_tensor(Ds1, device=device)
Ds2_val = _make_tensor(Ds2, device=device)
Dt1_val, Dt2_val = _low_rank_squared_l2(
target_geometry_embedding, target_geometry_embedding
)
Dt1_val = _make_tensor(Dt1_val, device=device)
Dt2_val = _make_tensor(Dt2_val, device=device)

elif source_geometry_val is not None and target_geometry_val is None:
elif (
source_geometry_embedding_val is not None
and target_geometry_embedding_val is None
):
raise ValueError(
"Source geometry validation data provided but not target"
" geometry validation data."
)

elif source_geometry_val is None and target_geometry_val is not None:
elif (
source_geometry_embedding_val is None
and target_geometry_embedding_val is not None
):
raise ValueError(
"Target geometry validation data provided but not source"
" geometry validation data."
)

else:
Ds_val = None
Dt_val = None
Ds1_val, Ds2_val = Ds1, Ds2
Dt1_val, Dt2_val = Dt1, Dt2

# Raise warning if validation anatomical kernelsare not provided
# Raise warning if validation anatomical kernels are not provided
if verbose:
console.log(
"Validation data for anatomical kernels is not provided."
Expand All @@ -242,12 +274,12 @@ def fit(
eps=self.eps,
reg_mode=self.reg_mode,
divergence=self.divergence,
F=F,
Ds=Ds,
Dt=Dt,
F_val=F_val,
Ds_val=Ds_val,
Dt_val=Dt_val,
F=(F1, F2),
Ds=(Ds1, Ds2),
Dt=(Dt1, Dt2),
F_val=(F1_val, F2_val),
Ds_val=(Ds1_val, Ds2_val),
Dt_val=(Dt1_val, Dt2_val),
ws=ws,
wt=wt,
init_plan=pi_init,
Expand All @@ -265,9 +297,9 @@ def fit(
self.loss_val = res["loss_val"]

# Free allocated GPU memory
del Fs, Ft, F, Ds, Dt
del F1, F2, Ds1, Ds2, Dt1, Dt2
if source_features_val is not None:
del Fs_val, Ft_val, F_val, Ds_val, Dt_val
del Ds1_val, Ds2_val, Dt1_val, Dt2_val
if torch.cuda.is_available():
torch.cuda.empty_cache()

Expand Down
30 changes: 17 additions & 13 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,19 +373,23 @@ def fit(
source_geometry_embeddings = _make_tensor(source_geometry_embeddings)
target_geometry_embeddings = _make_tensor(target_geometry_embeddings)

# Compute anatomical kernels
source_geometry_kernel = torch.cdist(
source_geometry_embeddings[source_sample],
source_geometry_embeddings[source_sample],
p=2,
# Compute coarse geometry embeddings
source_coarse_embedding = source_geometry_embeddings[source_sample]
target_coarse_embedding = target_geometry_embeddings[target_sample]

# Normalize embeddings
source_coarse_embedding = (
source_coarse_embedding
/ (source_coarse_embedding @ source_coarse_embedding.T)
.norm(dim=1)
.max()
)
source_geometry_kernel /= source_geometry_kernel.max()
target_geometry_kernel = torch.cdist(
target_geometry_embeddings[target_sample],
target_geometry_embeddings[target_sample],
p=2,
target_coarse_embedding = (
target_coarse_embedding
/ (target_coarse_embedding @ target_coarse_embedding.T)
.norm(dim=1)
.max()
)
target_geometry_kernel /= target_geometry_kernel.max()

# Sampled weights
if source_weights is None:
Expand All @@ -408,8 +412,8 @@ def fit(
coarse_mapping.fit(
source_features[:, source_sample],
target_features[:, target_sample],
source_geometry=source_geometry_kernel,
target_geometry=target_geometry_kernel,
source_geometry_embedding=source_coarse_embedding,
target_geometry_embedding=target_coarse_embedding,
source_weights=source_weights_sampled,
target_weights=target_weights_sampled,
solver=coarse_mapping_solver,
Expand Down
Loading
Loading