diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py
index 7b48affe..e431dc4c 100644
--- a/src/fugw/mappings/barycenter.py
+++ b/src/fugw/mappings/barycenter.py
@@ -146,8 +146,8 @@ def compute_all_ot_plans(
         mapping.fit(
             source_features=features,
             target_features=barycenter_features,
-            source_geometry=G,
-            target_geometry=barycenter_geometry,
+            source_geometry_embedding=G,
+            target_geometry_embedding=barycenter_geometry,
             source_weights=weights,
             target_weights=barycenter_weights,
             init_plan=plans[i] if plans is not None else None,
@@ -196,9 +196,9 @@ def fit(
                 can have weights with different sizes.
             features_list (list of np.array): List of features. Individuals
                 should have the same number of features n_features.
-            geometry_list (list of np.array or np.array): List of kernel matrices
-                or just one kernel matrix if it's shared across individuals
-            barycenter_size (int, optional): Size of computed
+            geometry_list (list of np.array or np.array): List of geometry
+                embeddings, or just one embedding if it is shared across
+                individuals.
             barycenter_size (int, optional): Size of computed
                 barycentric features and geometry. Defaults to None.
             init_barycenter_weights (np.array, optional): Distribution
                 weights of barycentric points. If None, points will have uniform
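The switch from dense kernel matrices to geometry embeddings rests on the classical rank-(d + 2) factorization of a squared Euclidean distance matrix, which is what `fugw.utils._low_rank_squared_l2` presumably computes (a minimal sketch under that assumption — the library's exact implementation may differ; the factor width d + 2 matches the shapes documented in the solver docstring below):

```python
import torch


def low_rank_squared_l2(X, Y):
    """Sketch: factor D2[i, j] = ||X_i - Y_j||^2 as A1 @ A2.T.

    X: (n, d), Y: (m, d). Returns A1: (n, d + 2), A2: (m, d + 2),
    since ||X_i - Y_j||^2 = ||X_i||^2 + ||Y_j||^2 - 2 X_i . Y_j.
    """
    norms_x = (X**2).sum(dim=1, keepdim=True)  # (n, 1)
    norms_y = (Y**2).sum(dim=1, keepdim=True)  # (m, 1)
    A1 = torch.cat([norms_x, torch.ones_like(norms_x), -2 * X], dim=1)
    A2 = torch.cat([torch.ones_like(norms_y), norms_y, Y], dim=1)
    return A1, A2


X, Y = torch.rand(10, 3), torch.rand(8, 3)
A1, A2 = low_rank_squared_l2(X, Y)
# The factors reconstruct the dense squared-distance matrix exactly
assert torch.allclose(A1 @ A2.T, torch.cdist(X, Y) ** 2, atol=1e-6)
```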
diff --git a/src/fugw/mappings/dense.py b/src/fugw/mappings/dense.py
index 0115c1e3..a90cf0d3 100644
--- a/src/fugw/mappings/dense.py
+++ b/src/fugw/mappings/dense.py
@@ -3,7 +3,7 @@
 from fugw.solvers.dense import FUGWSolver
 from fugw.mappings.utils import BaseMapping, console
-from fugw.utils import _make_tensor, init_plan_dense
+from fugw.utils import _make_tensor, init_plan_dense, _low_rank_squared_l2


 class FUGW(BaseMapping):
@@ -13,12 +13,12 @@ def fit(
         self,
         source_features=None,
         target_features=None,
-        source_geometry=None,
-        target_geometry=None,
+        source_geometry_embedding=None,
+        target_geometry_embedding=None,
         source_features_val=None,
         target_features_val=None,
-        source_geometry_val=None,
-        target_geometry_val=None,
+        source_geometry_embedding_val=None,
+        target_geometry_embedding_val=None,
         source_weights=None,
         target_weights=None,
         init_plan=None,
@@ -52,14 +52,16 @@ def fit(
             Feature maps for target subject.
             **This array should be normalized**, otherwise you will
             run into computational errors.
-        source_geometry: ndarray(n, n)
-            Kernel matrix of anatomical distances
-            between nodes of source mesh
+        source_geometry_embedding: ndarray(n, k), optional
+            Embedding X such that norm(X_i - X_j) approximates
+            the anatomical distance between vertices i and j
+            of the source mesh.
             **This array should be normalized**, otherwise you will
             run into computational errors.
-        target_geometry: ndarray(m, m)
-            Kernel matrix of anatomical distances
-            between nodes of target mesh
+        target_geometry_embedding: ndarray(m, k), optional
+            Embedding X such that norm(X_i - X_j) approximates
+            the anatomical distance between vertices i and j
+            of the target mesh.
             **This array should be normalized**, otherwise you will
             run into computational errors.
         source_features_val: ndarray(n_features, n) or None
@@ -68,14 +70,12 @@ def fit(
             Feature maps for source subject used for validation.
             If None, source_features will be used instead.
         target_features_val: ndarray(n_features, m) or None
             Feature maps for target subject used for validation.
             If None, target_features will be used instead.
-        source_geometry_val: ndarray(n, n) or None
-            Kernel matrix of anatomical distances
-            between nodes of source mesh used for validation.
-            If None, source_geometry will be used instead.
+        source_geometry_embedding_val: ndarray(n, k) or None
+            Geometry embedding of the source mesh used for validation.
+            If None, source_geometry_embedding will be used instead.
-        target_geometry_val: ndarray(m, m) or None
-            Kernel matrix of anatomical distances
-            between nodes of target mesh used for validation.
-            If None, target_geometry will be used instead.
+        target_geometry_embedding_val: ndarray(m, k) or None
+            Geometry embedding of the target mesh used for validation.
+            If None, target_geometry_embedding will be used instead.
@@ -170,17 +172,29 @@ def fit(
         # Compute distance matrix between features
         Fs = _make_tensor(source_features.T, device=device)
         Ft = _make_tensor(target_features.T, device=device)
-        F = torch.cdist(Fs, Ft, p=2) ** 2
+        F1, F2 = _low_rank_squared_l2(Fs, Ft)
+        F1 = _make_tensor(F1, device=device)
+        F2 = _make_tensor(F2, device=device)

         # Load anatomical kernels to GPU
-        Ds = _make_tensor(source_geometry, device=device)
-        Dt = _make_tensor(target_geometry, device=device)
+        Ds1, Ds2 = _low_rank_squared_l2(
+            source_geometry_embedding, source_geometry_embedding
+        )
+        Ds1 = _make_tensor(Ds1, device=device)
+        Ds2 = _make_tensor(Ds2, device=device)
+        Dt1, Dt2 = _low_rank_squared_l2(
+            target_geometry_embedding, target_geometry_embedding
+        )
+        Dt1 = _make_tensor(Dt1, device=device)
+        Dt2 = _make_tensor(Dt2, device=device)

         # Do the same for validation data if it was provided
         if source_features_val is not None and target_features_val is not None:
             Fs_val = _make_tensor(source_features_val.T, device=device)
             Ft_val = _make_tensor(target_features_val.T, device=device)
-            F_val = torch.cdist(Fs_val, Ft_val, p=2) ** 2
+            F1_val, F2_val = _low_rank_squared_l2(Fs_val, Ft_val)
+            F1_val = _make_tensor(F1_val, device=device)
+            F2_val = _make_tensor(F2_val, device=device)

         elif source_features_val is not None and target_features_val is None:
             raise ValueError(
@@ -195,11 +210,11 @@ def fit(
             )

         else:
-            F_val = None
+            F1_val, F2_val = None, None

             # Raise warning if validation feature maps are not provided
             if verbose:
                 console.log(
                     "Validation data for feature maps is not provided."
                     " Using training data instead."
                 )
@@ -204,27 +219,44 @@ def fit(
-        if source_geometry_val is not None and target_geometry_val is not None:
-            Ds_val = _make_tensor(source_geometry_val, device=device)
-            Dt_val = _make_tensor(target_geometry_val, device=device)
+        if (
+            source_geometry_embedding_val is not None
+            and target_geometry_embedding_val is not None
+        ):
+            Ds1_val, Ds2_val = _low_rank_squared_l2(
+                source_geometry_embedding_val, source_geometry_embedding_val
+            )
+            Ds1_val = _make_tensor(Ds1_val, device=device)
+            Ds2_val = _make_tensor(Ds2_val, device=device)
+            Dt1_val, Dt2_val = _low_rank_squared_l2(
+                target_geometry_embedding_val, target_geometry_embedding_val
+            )
+            Dt1_val = _make_tensor(Dt1_val, device=device)
+            Dt2_val = _make_tensor(Dt2_val, device=device)

-        elif source_geometry_val is not None and target_geometry_val is None:
+        elif (
+            source_geometry_embedding_val is not None
+            and target_geometry_embedding_val is None
+        ):
             raise ValueError(
                 "Source geometry validation data provided but not target"
                 " geometry validation data."
             )

-        elif source_geometry_val is None and target_geometry_val is not None:
+        elif (
+            source_geometry_embedding_val is None
+            and target_geometry_embedding_val is not None
+        ):
             raise ValueError(
                 "Target geometry validation data provided but not source"
                 " geometry validation data."
             )

         else:
-            Ds_val = None
-            Dt_val = None
+            Ds1_val, Ds2_val = Ds1, Ds2
+            Dt1_val, Dt2_val = Dt1, Dt2

-            # Raise warning if validation anatomical kernelsare not provided
+            # Raise warning if validation anatomical kernels are not provided
             if verbose:
                 console.log(
                     "Validation data for anatomical kernels is not provided."
@@ -242,12 +274,12 @@ def fit(
             eps=self.eps,
             reg_mode=self.reg_mode,
             divergence=self.divergence,
-            F=F,
-            Ds=Ds,
-            Dt=Dt,
-            F_val=F_val,
-            Ds_val=Ds_val,
-            Dt_val=Dt_val,
+            F=(F1, F2),
+            Ds=(Ds1, Ds2),
+            Dt=(Dt1, Dt2),
+            F_val=(F1_val, F2_val),
+            Ds_val=(Ds1_val, Ds2_val),
+            Dt_val=(Dt1_val, Dt2_val),
             ws=ws,
             wt=wt,
             init_plan=pi_init,
@@ -265,9 +297,10 @@ def fit(
         self.loss_val = res["loss_val"]

         # Free allocated GPU memory
-        del Fs, Ft, F, Ds, Dt
+        del Fs, Ft, F1, F2, Ds1, Ds2, Dt1, Dt2
         if source_features_val is not None:
-            del Fs_val, Ft_val, F_val, Ds_val, Dt_val
+            del Fs_val, Ft_val, F1_val, F2_val
+            del Ds1_val, Ds2_val, Dt1_val, Dt2_val
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
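With the new signature, a dense mapping is fitted from geometry embeddings rather than precomputed kernels. A minimal usage sketch; the import path, the constructor arguments, and the normalization step are assumptions based on the library's general conventions, not taken from this diff:

```python
import numpy as np
from fugw.mappings import FUGW  # import path assumed

n, m, k, n_features = 100, 120, 3, 10

# Random stand-ins for real feature maps and geometry embeddings
source_features = np.random.rand(n_features, n)
target_features = np.random.rand(n_features, m)
source_embedding = np.random.rand(n, k)
target_embedding = np.random.rand(m, k)

# The docstring above asks for normalized inputs; one simple choice
source_features /= np.linalg.norm(source_features, axis=0)
target_features /= np.linalg.norm(target_features, axis=0)

mapping = FUGW(alpha=0.5, rho=1, eps=1e-4)  # constructor args assumed
mapping.fit(
    source_features=source_features,
    target_features=target_features,
    source_geometry_embedding=source_embedding,
    target_geometry_embedding=target_embedding,
    solver="mm",
)
```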
diff --git a/src/fugw/scripts/coarse_to_fine.py b/src/fugw/scripts/coarse_to_fine.py
index 8f1379e0..49424ffd 100644
--- a/src/fugw/scripts/coarse_to_fine.py
+++ b/src/fugw/scripts/coarse_to_fine.py
@@ -373,19 +373,18 @@ def fit(
     source_geometry_embeddings = _make_tensor(source_geometry_embeddings)
     target_geometry_embeddings = _make_tensor(target_geometry_embeddings)

-    # Compute anatomical kernels
-    source_geometry_kernel = torch.cdist(
-        source_geometry_embeddings[source_sample],
-        source_geometry_embeddings[source_sample],
-        p=2,
-    )
-    source_geometry_kernel /= source_geometry_kernel.max()
-    target_geometry_kernel = torch.cdist(
-        target_geometry_embeddings[target_sample],
-        target_geometry_embeddings[target_sample],
-        p=2,
-    )
-    target_geometry_kernel /= target_geometry_kernel.max()
+    # Compute coarse geometry embeddings
+    source_coarse_embedding = source_geometry_embeddings[source_sample]
+    target_coarse_embedding = target_geometry_embeddings[target_sample]
+
+    # Normalize embeddings so that the maximum pairwise distance
+    # on the sampled vertices is 1
+    source_coarse_embedding = source_coarse_embedding / torch.cdist(
+        source_coarse_embedding, source_coarse_embedding
+    ).max()
+    target_coarse_embedding = target_coarse_embedding / torch.cdist(
+        target_coarse_embedding, target_coarse_embedding
+    ).max()

     # Sampled weights
     if source_weights is None:
@@ -408,8 +412,8 @@ def fit(
     coarse_mapping.fit(
         source_features[:, source_sample],
         target_features[:, target_sample],
-        source_geometry=source_geometry_kernel,
-        target_geometry=target_geometry_kernel,
+        source_geometry_embedding=source_coarse_embedding,
+        target_geometry_embedding=target_coarse_embedding,
         source_weights=source_weights_sampled,
         target_weights=target_weights_sampled,
         solver=coarse_mapping_solver,
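The normalization above rescales each coarse embedding so that its largest pairwise distance is 1, matching the kernel normalization it replaces. A standalone sketch; the random-pair variant is an assumption for large meshes where an exact `cdist` would be too expensive, and it only estimates the maximum:

```python
import torch


def normalize_embedding(X, n_pairs=None):
    """Scale X so that its largest pairwise distance is (approximately) 1.

    Exact on small (coarse) samples; with n_pairs set, the max is
    estimated from random pairs, avoiding the O(n^2) distance matrix.
    """
    if n_pairs is None:
        d_max = torch.cdist(X, X).max()
    else:
        i = torch.randint(0, X.shape[0], (n_pairs,))
        j = torch.randint(0, X.shape[0], (n_pairs,))
        d_max = (X[i] - X[j]).norm(dim=1).max()
    return X / d_max


coarse = torch.rand(500, 3)
coarse_normalized = normalize_embedding(coarse)
assert torch.isclose(
    torch.cdist(coarse_normalized, coarse_normalized).max(), torch.tensor(1.0)
)
```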
diff --git a/src/fugw/solvers/dense.py b/src/fugw/solvers/dense.py
index e5bef02a..d5acc581 100644
--- a/src/fugw/solvers/dense.py
+++ b/src/fugw/solvers/dense.py
@@ -1,4 +1,5 @@
 from functools import partial
+from copy import deepcopy

 import time

@@ -40,26 +41,35 @@ def local_biconvex_cost(
     rho_s, rho_t, eps, alpha, reg_mode, divergence = hyperparams
     ws, wt, ws_dot_wt = tuple_weights
-    X_sqr, Y_sqr, X, Y, D = data_const
+
+    (
+        (Ds_sqr_1, Ds_sqr_2),
+        (Dt_sqr_1, Dt_sqr_2),
+        (Ds1, Ds2),
+        (Dt1, Dt2),
+        (K1, K2),
+    ) = data_const
     if transpose:
-        X_sqr, Y_sqr, X, Y = X_sqr.T, Y_sqr.T, X.T, Y.T
+        Ds_sqr_1, Ds_sqr_2 = Ds_sqr_2, Ds_sqr_1
+        Dt_sqr_1, Dt_sqr_2 = Dt_sqr_2, Dt_sqr_1
+        Ds1, Ds2 = Ds2, Ds1
+        Dt1, Dt2 = Dt2, Dt1

     pi1, pi2 = pi.sum(1), pi.sum(0)
     cost = torch.zeros_like(pi)

     # Avoid unnecessary calculation of UGW when alpha = 0
-    if alpha != 1 and D is not None:
-        wasserstein_cost = D / 2
+    if alpha != 1 and K1 is not None and K2 is not None:
+        wasserstein_cost = (K1 @ K2.T) / 2
         cost += (1 - alpha) * wasserstein_cost

     # or UOT when alpha = 1
     if alpha != 0:
-        A = X_sqr @ pi1
-        B = Y_sqr @ pi2
-        gromov_wasserstein_cost = (
-            A[:, None] + B[None, :] - 2 * X @ pi @ Y.T
-        )
+        A = Ds_sqr_1 @ (Ds_sqr_2.T @ pi1)
+        B = Dt_sqr_1 @ (Dt_sqr_2.T @ pi2)
+        C1, C2 = Ds1, ((Ds2.T @ (pi @ Dt2)) @ Dt1.T).T
+        gromov_wasserstein_cost = A[:, None] + B[None, :] - 2 * C1 @ C2.T

         cost += alpha * gromov_wasserstein_cost
@@ -116,7 +126,13 @@ def fugw_loss(
     """
     rho_s, rho_t, eps, alpha, reg_mode, divergence = hyperparams
     ws, wt, ws_dot_wt = tuple_weights
-    X_sqr, Y_sqr, X, Y, D = data_const
+    (
+        (Ds_sqr_1, Ds_sqr_2),
+        (Dt_sqr_1, Dt_sqr_2),
+        (Ds1, Ds2),
+        (Dt1, Dt2),
+        (K1, K2),
+    ) = data_const

     pi1, pi2 = pi.sum(1), pi.sum(0)
     gamma1, gamma2 = gamma.sum(1), gamma.sum(0)
@@ -128,14 +144,14 @@ def fugw_loss(
     loss_regularization = torch.zeros(1)
     loss = 0

-    if alpha != 1 and D is not None:
-        loss_wasserstein = ((D * pi).sum() + (D * gamma).sum()) / 2
+    if alpha != 1 and K1 is not None and K2 is not None:
+        loss_wasserstein = ((K1 @ K2.T) * (pi + gamma)).sum() / 2
         loss += (1 - alpha) * loss_wasserstein

     if alpha != 0:
-        A = (X_sqr @ gamma1).dot(pi1)
-        B = (Y_sqr @ gamma2).dot(pi2)
-        C = (X @ gamma @ Y.T) * pi
+        A = (Ds_sqr_1 @ (Ds_sqr_2.T @ gamma1)).dot(pi1)
+        B = (Dt_sqr_1 @ (Dt_sqr_2.T @ gamma2)).dot(pi2)
+        C = (Ds1 @ ((Ds2.T @ gamma @ Dt2) @ Dt1.T)) * pi
         loss_gromov_wasserstein = A + B - 2 * C.sum()
         loss += alpha * loss_gromov_wasserstein
@@ -206,11 +222,11 @@ def solve(
         reg_mode="joint",
         divergence="kl",
-        F=None,
-        Ds=None,
-        Dt=None,
-        F_val=None,
-        Ds_val=None,
-        Dt_val=None,
+        F=(None, None),
+        Ds=(None, None),
+        Dt=(None, None),
+        F_val=(None, None),
+        Ds_val=(None, None),
+        Dt_val=(None, None),
         ws=None,
         wt=None,
         init_plan=None,
@@ -229,14 +245,16 @@ def solve(
         eps: float, optional
         reg_mode: string, optional
         divergence: string, optional
-        F: matrix of size n x m.
-            Kernel matrix between the source and target training features.
-        Ds: matrix of size n x n
-        Dt: matrix of size m x m
-        F_val: matrix of size n x m, None
-            Kernel matrix between the source and target validation features.
-        Ds_val: matrix of size n x n, None
-        Dt_val: matrix of size m x m, None
+        F: (ndarray(n, d+2), ndarray(m, d+2)) or (None, None)
+            Low-rank factorization of the squared distance matrix
+            between source and target training features.
+        Ds: (ndarray(n, k+2), ndarray(n, k+2)) or (None, None)
+        Dt: (ndarray(m, k+2), ndarray(m, k+2)) or (None, None)
+        F_val: (ndarray(n, d+2), ndarray(m, d+2)) or (None, None)
+            Low-rank factorization of the squared distance matrix
+            between source and target validation features.
+        Ds_val: (ndarray(n, k+2), ndarray(n, k+2)) or (None, None)
+        Dt_val: (ndarray(m, k+2), ndarray(m, k+2)) or (None, None)
         ws: ndarray(n), None
             Measures assigned to source points.
         wt: ndarray(m), None
@@ -310,31 +325,62 @@ def solve(
        if solver == "sinkhorn" and eps == 0:
            solver = "ibpp"

-        device, dtype = Ds.device, Ds.dtype
+        n, m = Ds[0].shape[0], Dt[0].shape[0]
+        device, dtype = Ds[0].device, Ds[0].dtype

         # constant data variables
-        Ds_sqr = Ds**2
-        Dt_sqr = Dt**2
+        Ds1, Ds2 = Ds
+        Ds_sqr = (
+            torch.einsum("ij,il->ijl", Ds1, Ds1).reshape(
+                Ds1.shape[0], Ds1.shape[1] ** 2
+            ),
+            torch.einsum("ij,il->ijl", Ds2, Ds2).reshape(
+                Ds2.shape[0], Ds2.shape[1] ** 2
+            ),
+        )
+
+        Dt1, Dt2 = Dt
+        Dt_sqr = (
+            torch.einsum("ij,il->ijl", Dt1, Dt1).reshape(
+                Dt1.shape[0], Dt1.shape[1] ** 2
+            ),
+            torch.einsum("ij,il->ijl", Dt2, Dt2).reshape(
+                Dt2.shape[0], Dt2.shape[1] ** 2
+            ),
+        )

         # Same for validation data if provided
-        if Ds_val is not None and Dt_val is not None:
-            Ds_sqr_val = Ds_val**2
-            Dt_sqr_val = Dt_val**2
+        if Ds_val != (None, None) and Dt_val != (None, None):
+            Ds1_val, Ds2_val = Ds_val
+            Ds_sqr_val = (
+                torch.einsum("ij,il->ijl", Ds1_val, Ds1_val).reshape(
+                    Ds1_val.shape[0], Ds1_val.shape[1] ** 2
+                ),
+                torch.einsum("ij,il->ijl", Ds2_val, Ds2_val).reshape(
+                    Ds2_val.shape[0], Ds2_val.shape[1] ** 2
+                ),
+            )
+
+            Dt1_val, Dt2_val = Dt_val
+            Dt_sqr_val = (
+                torch.einsum("ij,il->ijl", Dt1_val, Dt1_val).reshape(
+                    Dt1_val.shape[0], Dt1_val.shape[1] ** 2
+                ),
+                torch.einsum("ij,il->ijl", Dt2_val, Dt2_val).reshape(
+                    Dt2_val.shape[0], Dt2_val.shape[1] ** 2
+                ),
+            )
         else:
             Ds_val, Dt_val = Ds, Dt
             Ds_sqr_val, Dt_sqr_val = Ds_sqr, Dt_sqr

-        if alpha == 1 or F is None:
+        if alpha == 1 or F[0] is None or F[1] is None:
             alpha = 1
-            F = None
+            F = (None, None)

         # measures on rows and columns
         if ws is None:
-            n = Ds.shape[0]
             ws = torch.ones(n).to(device).to(dtype) / n
         if wt is None:
-            m = Dt.shape[0]
             wt = torch.ones(m).to(device).to(dtype) / m
         ws_dot_wt = ws[:, None] * wt[None, :]
@@ -387,6 +437,7 @@ def solve(
         self_solver_mm_l2 = partial(
             solver_mm_l2,
             train_params=(self.nits_uot, self.tol_uot, self.eval_uot),
+            verbose=verbose,
         )

         self_get_params_uot_l2 = partial(
@@ -400,12 +451,14 @@ def solve(
             solver_sinkhorn,
             tuple_weights=(ws, wt, ws_dot_wt),
             train_params=(self.nits_uot, self.tol_uot, self.eval_uot),
+            verbose=verbose,
         )

         self_solver_mm_kl = partial(
             solver_mm,
             tuple_weights=(ws, wt),
             train_params=(self.nits_uot, self.tol_uot, self.eval_uot),
+            verbose=verbose,
         )

         self_solver_ibpp = partial(
@@ -424,7 +477,7 @@ def solve(
         # Initialize loss
         current_loss = compute_fugw_loss(pi, gamma)

-        if F_val is not None:
+        if F_val != (None, None):
             current_loss_validation = compute_fugw_loss_validation(pi, gamma)
         else:
-            current_loss_validation = current_loss
+            current_loss_validation = deepcopy(current_loss)
@@ -505,12 +558,12 @@ def solve(
             if idx % self.eval_bcd == 0:
                 current_loss = compute_fugw_loss(pi, gamma)

-                if F_val is not None:
+                if F_val != (None, None):
                     current_loss_validation = compute_fugw_loss_validation(
                         pi, gamma
                     )
                 else:
-                    current_loss_validation = current_loss
+                    current_loss_validation = deepcopy(current_loss)

                 loss_steps.append(idx + 1)
                 loss = _add_dict(loss, current_loss)
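The factored Gromov-Wasserstein terms above can be checked against the dense formulas they replace: with Ds = Ds1 @ Ds2.T, the einsum factors reconstruct the elementwise square Ds**2, and C1 @ C2.T reconstructs Ds @ pi @ Dt.T. A verification sketch, reusing the same assumed factorization sketched earlier:

```python
import torch


def low_rank_squared_l2(X, Y):
    # assumed factorization: A1 @ A2.T == squared euclidean distances
    nx = (X**2).sum(1, keepdim=True)
    ny = (Y**2).sum(1, keepdim=True)
    A1 = torch.cat([nx, torch.ones_like(nx), -2 * X], dim=1)
    A2 = torch.cat([torch.ones_like(ny), ny, Y], dim=1)
    return A1, A2


torch.manual_seed(0)
n, m, k = 6, 5, 3
Xs, Xt = torch.rand(n, k), torch.rand(m, k)
Ds1, Ds2 = low_rank_squared_l2(Xs, Xs)
Dt1, Dt2 = low_rank_squared_l2(Xt, Xt)
Ds, Dt = Ds1 @ Ds2.T, Dt1 @ Dt2.T  # dense squared-distance matrices

pi = torch.rand(n, m)
pi = pi / pi.sum()
pi1, pi2 = pi.sum(1), pi.sum(0)

# Dense local GW cost (what the solver computed before this patch)
A_dense = (Ds**2) @ pi1
B_dense = (Dt**2) @ pi2
dense_cost = A_dense[:, None] + B_dense[None, :] - 2 * Ds @ pi @ Dt.T

# Factored equivalent, as in local_biconvex_cost above
r = Ds1.shape[1]
Ds_sqr_1 = torch.einsum("ij,il->ijl", Ds1, Ds1).reshape(n, r**2)
Ds_sqr_2 = torch.einsum("ij,il->ijl", Ds2, Ds2).reshape(n, r**2)
Dt_sqr_1 = torch.einsum("ij,il->ijl", Dt1, Dt1).reshape(m, r**2)
Dt_sqr_2 = torch.einsum("ij,il->ijl", Dt2, Dt2).reshape(m, r**2)

A = Ds_sqr_1 @ (Ds_sqr_2.T @ pi1)
B = Dt_sqr_1 @ (Dt_sqr_2.T @ pi2)
C1, C2 = Ds1, ((Ds2.T @ (pi @ Dt2)) @ Dt1.T).T
low_rank_cost = A[:, None] + B[None, :] - 2 * C1 @ C2.T

assert torch.allclose(dense_cost, low_rank_cost, atol=1e-4)
```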
diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py
index bc9305e6..b87b063a 100644
--- a/tests/mappings/test_barycenter.py
+++ b/tests/mappings/test_barycenter.py
@@ -23,12 +23,12 @@ def test_fugw_barycenter(device):
     weights_list = []

     for _ in range(n_subjects):
-        weights, features, geometry, _ = _init_mock_distribution(
+        weights, features, _, geometry_embedding = _init_mock_distribution(
            n_features, n_voxels
        )
         weights_list.append(weights)
         features_list.append(features)
-        geometry_list.append(geometry)
+        geometry_list.append(geometry_embedding)

     fugw_barycenter = FUGWBarycenter()
     fugw_barycenter.fit(
diff --git a/tests/mappings/test_dense_mapping.py b/tests/mappings/test_dense_mapping.py
index 78b71c86..3de73438 100644
--- a/tests/mappings/test_dense_mapping.py
+++ b/tests/mappings/test_dense_mapping.py
@@ -32,10 +32,10 @@
 )
 def test_dense_mapping(device, return_numpy, solver, callback):
     # Generate random training data for source and target
-    _, source_features_train, source_geometry, _ = _init_mock_distribution(
+    _, source_features_train, _, source_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_source, return_numpy=return_numpy
     )
-    _, target_features_train, target_geometry, _ = _init_mock_distribution(
+    _, target_features_train, _, target_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_target, return_numpy=return_numpy
     )

@@ -43,8 +43,8 @@ def test_dense_mapping(device, return_numpy, solver, callback):
     fugw.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
         solver=solver,
         solver_params={
             "nits_bcd": 3,
@@ -100,24 +100,30 @@ def test_dense_mapping(device, return_numpy, solver, callback):
 def test_validation_mapping(validation):
     # Generate random training data for source and target
     # and random validation data for source and target
-    _, source_features_train, source_geometry, _ = _init_mock_distribution(
+    _, source_features_train, _, source_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_source, return_numpy=False
     )
-    _, target_features_train, target_geometry, _ = _init_mock_distribution(
+    _, target_features_train, _, target_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_target, return_numpy=False
     )

-    _, source_features_train_val, source_geometry_val, _ = (
-        _init_mock_distribution(
-            n_features_train, n_voxels_source, return_numpy=False
-        )
+    (
+        _,
+        source_features_train_val,
+        _,
+        source_embeddings_val,
+    ) = _init_mock_distribution(
+        n_features_train, n_voxels_source, return_numpy=False
     )
-    _, target_features_train_val, target_geometry_val, _ = (
-        _init_mock_distribution(
-            n_features_train, n_voxels_target, return_numpy=False
-        )
+    (
+        _,
+        target_features_train_val,
+        _,
+        target_embeddings_val,
+    ) = _init_mock_distribution(
+        n_features_train, n_voxels_target, return_numpy=False
     )

     fugw = FUGW()
@@ -126,8 +132,8 @@ def test_validation_mapping(validation):
     fugw.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
         solver="sinkhorn",
         solver_params={
             "nits_bcd": 3,
@@ -141,8 +147,8 @@ def test_validation_mapping(validation):
     fugw.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
         source_features_val=source_features_train_val,
         target_features_val=target_features_train_val,
         solver="sinkhorn",
@@ -158,10 +164,10 @@ def test_validation_mapping(validation):
     fugw.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
-        source_geometry_val=source_geometry_val,
-        target_geometry_val=target_geometry_val,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
+        source_geometry_embedding_val=source_embeddings_val,
+        target_geometry_embedding_val=target_embeddings_val,
         solver="sinkhorn",
         solver_params={
             "nits_bcd": 3,
@@ -175,12 +181,12 @@ def test_validation_mapping(validation):
     fugw.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
         source_features_val=source_features_train_val,
         target_features_val=target_features_train_val,
-        source_geometry_val=source_geometry_val,
-        target_geometry_val=target_geometry_val,
+        source_geometry_embedding_val=source_embeddings_val,
+        target_geometry_embedding_val=target_embeddings_val,
         solver="sinkhorn",
         solver_params={
             "nits_bcd": 3,
@@ -193,11 +199,11 @@ def test_validation_mapping(validation):

 @pytest.mark.parametrize("solver", ["sinkhorn", "ibpp"])
 def test_available_l2_solver(solver):
-    _, source_features_train, source_geometry, _ = _init_mock_distribution(
+    _, source_features_train, _, source_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_source, return_numpy=False
     )
-    _, target_features_train, target_geometry, _ = _init_mock_distribution(
+    _, target_features_train, _, target_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_target, return_numpy=False
     )

@@ -209,8 +215,8 @@ def test_available_l2_solver(solver):
     mapping.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
         solver=solver,
     )
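The solver tests that follow normalize each factored cost by the maximum entry of the matrix it reconstructs. A small sketch of that pattern; note that dividing both factors by the same scalar rescales the reconstructed matrix by the square of that scalar, while dividing a single factor would pin its maximum at exactly 1 — either convention keeps the costs well scaled:

```python
import torch

torch.manual_seed(0)
F1, F2 = torch.rand(104, 12), torch.rand(151, 12)  # any factored cost
c = (F1 @ F2.T).max()

# As in the tests below: dividing both factors by c scales the
# reconstructed matrix by 1 / c**2 ...
F1_both, F2_both = F1 / c, F2 / c
assert torch.isclose((F1_both @ F2_both.T).max(), 1 / c)

# ... while dividing a single factor pins its maximum at exactly 1
assert torch.isclose(((F1 / c) @ F2.T).max(), torch.tensor(1.0))
```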
diff --git a/tests/solvers/test_dense_solver.py b/tests/solvers/test_dense_solver.py
index 7355400f..82c3c54b 100644
--- a/tests/solvers/test_dense_solver.py
+++ b/tests/solvers/test_dense_solver.py
@@ -5,7 +5,11 @@
 import torch

 from fugw.solvers import FUGWSolver
+from fugw.utils import _low_rank_squared_l2

+devices = [torch.device("cpu")]
+if torch.cuda.is_available():
+    devices.append(torch.device("cuda:0"))

 callbacks = [None, lambda x: x["gamma"]]

@@ -13,19 +17,16 @@
 @pytest.mark.parametrize(
-    "solver,callback,alpha",
-    product(["sinkhorn", "mm", "ibpp"], callbacks, alphas),
+    "solver,device,callback,alpha",
+    product(["sinkhorn", "mm", "ibpp"], devices, callbacks, alphas),
 )
-def test_dense_solvers(solver, callback, alpha):
+def test_dense_solvers(solver, device, callback, alpha):
     torch.manual_seed(0)
-
-    use_cuda = torch.cuda.is_available()
-    device = torch.device("cuda:0" if use_cuda else "cpu")
     torch.backends.cudnn.benchmark = True

-    ns = 104
+    ns = 150
     ds = 3
-    nt = 151
+    nt = 200
     dt = 7
     nf = 10

@@ -34,13 +35,17 @@ def test_dense_solvers(solver, callback, alpha):
     source_embeddings = torch.rand(ns, ds).to(device)
     target_embeddings = torch.rand(nt, dt).to(device)

-    F = torch.cdist(source_features, target_features)
-    Ds = torch.cdist(source_embeddings, source_embeddings)
-    Dt = torch.cdist(target_embeddings, target_embeddings)
+    F = _low_rank_squared_l2(source_features, target_features)
+    Ds = _low_rank_squared_l2(source_embeddings, source_embeddings)
+    Dt = _low_rank_squared_l2(target_embeddings, target_embeddings)

-    Ds_normalized = Ds / Ds.max()
-    Dt_normalized = Dt / Dt.max()
-    F_normalized = F / F.max()
+    F_norm = (F[0] @ F[1].T).max()
+    Ds_norm = (Ds[0] @ Ds[1].T).max()
+    Dt_norm = (Dt[0] @ Dt[1].T).max()
+
+    F_normalized = (F[0] / F_norm, F[1] / F_norm)
+    Ds_normalized = (Ds[0] / Ds_norm, Ds[1] / Ds_norm)
+    Dt_normalized = (Dt[0] / Dt_norm, Dt[1] / Dt_norm)

     nits_bcd = 100
     eval_bcd = 2
@@ -57,7 +62,7 @@ def test_dense_solvers(solver, callback, alpha):
     rho_s = 2
     rho_t = 3
-    eps = 0.02
+    eps = 0.5

     res = fugw.solve(
         alpha=alpha,
@@ -68,7 +73,6 @@ def test_dense_solvers(solver, callback, alpha):
         F=F_normalized,
         Ds=Ds_normalized,
         Dt=Dt_normalized,
-        init_plan=None,
         solver=solver,
         callback_bcd=callback,
         verbose=True,
@@ -82,19 +86,16 @@ def test_dense_solvers(solver, callback, alpha):
     loss_steps = res["loss_steps"]
     loss_times = res["loss_times"]

-    assert pi.shape == (ns, nt)
-    assert gamma.shape == (ns, nt)
+    assert pi.size() == (ns, nt)
+    assert gamma.size() == (ns, nt)

     if solver == "mm":
         assert duals_pi is None
         assert duals_gamma is None
-    else:
+    elif solver == "ibpp":
         assert len(duals_pi) == 2
         assert duals_pi[0].shape == (ns,)
         assert duals_pi[1].shape == (nt,)
-        assert len(duals_gamma) == 2
-        assert duals_gamma[0].shape == (ns,)
-        assert duals_gamma[1].shape == (nt,)

     assert len(loss_steps) - 1 <= nits_bcd // eval_bcd + 1
     assert len(loss_times) == len(loss_steps)
@@ -126,17 +127,18 @@ def test_dense_solvers(solver, callback, alpha):

     # Loss should decrease
     assert np.all(
-        np.sign(np.array(loss["total"][1:]) - np.array(loss["total"][:-1]))
+        np.sign(
+            np.array(loss["total"][1:]) - np.array(loss["total"][:-1]) - 1e-6
+        )  # numerical tolerance
         == -1
     )


-@pytest.mark.parametrize("reg_mode", ["independent", "joint"])
-def test_dense_solvers_l2(reg_mode):
+@pytest.mark.parametrize(
+    "reg_mode,device", product(["independent", "joint"], devices)
+)
+def test_dense_solvers_l2(reg_mode, device):
     torch.manual_seed(0)
-
-    use_cuda = torch.cuda.is_available()
-    device = torch.device("cuda:0" if use_cuda else "cpu")
     torch.backends.cudnn.benchmark = True

     ns = 204
@@ -148,13 +150,19 @@ def test_dense_solvers_l2(reg_mode, device):
     source_embeddings = torch.rand(ns, ds).to(device)
     target_embeddings = source_embeddings

-    F = torch.cdist(source_features, target_features)
-    Ds = torch.cdist(source_embeddings, source_embeddings)
-    Dt = torch.cdist(target_embeddings, target_embeddings)
+    F = _low_rank_squared_l2(source_features, target_features)
+    Ds = _low_rank_squared_l2(source_embeddings, source_embeddings)
+    Dt = _low_rank_squared_l2(target_embeddings, target_embeddings)

-    Ds_normalized = Ds / Ds.max()
-    Dt_normalized = Dt / Dt.max()
-    F_normalized = F / F.max()
+    F_norm = (F[0] @ F[1].T).max()
+    Ds_norm = (Ds[0] @ Ds[1].T).max()
+    Dt_norm = (Dt[0] @ Dt[1].T).max()
+
+    F_normalized = (F[0] / F_norm, F[1] / F_norm)
+    Ds_normalized = (Ds[0] / Ds_norm, Ds[1] / Ds_norm)
+    Dt_normalized = (Dt[0] / Dt_norm, Dt[1] / Dt_norm)
+
+    init_plan = (torch.ones(ns, ns) / ns).to(device)

     nits_bcd = 100
     eval_bcd = 2
@@ -178,8 +186,9 @@ def test_dense_solvers_l2(reg_mode, device):
         F=F_normalized,
         Ds=Ds_normalized,
         Dt=Dt_normalized,
+        init_plan=init_plan,
         solver="mm",
-        verbose=True,
+        verbose=False,
     )

     pi = res["pi"]
@@ -208,24 +217,27 @@ def test_dense_solvers_l2(reg_mode, device):
     ]:
         assert len(loss[key]) == len(loss_steps)

     # Loss should decrease
-    # assert np.all(np.sign(np.array(loss[1:]) - np.array(loss[:-1])) == -1)
+    assert np.all(
+        np.sign(
+            np.array(loss["total"][1:]) - np.array(loss["total"][:-1]) - 1e-6
+        )  # numerical tolerance
+        == -1
+    )

     # Check if we can recover ground truth optimal plan (identity matrix)
     pi_true = np.eye(ns, ns) / ns
     pi_np = pi.cpu().detach().numpy()
     gamma_np = gamma.cpu().detach().numpy()
-    np.testing.assert_allclose(pi_true, pi_np, atol=1e-04)
-    np.testing.assert_allclose(pi_true, gamma_np, atol=1e-04)
+    np.testing.assert_allclose(pi_true, pi_np, atol=1e-02)
+    np.testing.assert_allclose(pi_true, gamma_np, atol=1e-02)


 @pytest.mark.parametrize(
-    "validation", ["None", "features", "geometries", "Both"]
+    "validation,device",
+    product(["None", "features", "geometries", "Both"], devices),
 )
-def test_validation_solver(validation):
+def test_validation_solver(validation, device):
     torch.manual_seed(0)
-
-    use_cuda = torch.cuda.is_available()
-    device = torch.device("cuda:0" if use_cuda else "cpu")
     torch.backends.cudnn.benchmark = True

     ns = 104
     ds = 3
     nt = 151
     dt = 7
     nf = 10

@@ -238,42 +250,62 @@ def test_validation_solver(validation):
     target_features = torch.rand(nt, nf).to(device)
     source_embeddings = torch.rand(ns, ds).to(device)
     target_embeddings = torch.rand(nt, dt).to(device)
+    source_features_val = torch.rand(ns, nf).to(device)
+    target_features_val = torch.rand(nt, nf).to(device)
+    source_embeddings_val = torch.rand(ns, ds).to(device)
+    target_embeddings_val = torch.rand(nt, dt).to(device)

-    F = torch.cdist(source_features, target_features)
-    Ds = torch.cdist(source_embeddings, source_embeddings)
-    Dt = torch.cdist(target_embeddings, target_embeddings)
+    F = _low_rank_squared_l2(source_features, target_features)
+    Ds = _low_rank_squared_l2(source_embeddings, source_embeddings)
+    Dt = _low_rank_squared_l2(target_embeddings, target_embeddings)

-    Ds_normalized = Ds / Ds.max()
-    Dt_normalized = Dt / Dt.max()
-    F_normalized = F / F.max()
+    F_norm = (F[0] @ F[1].T).max()
+    Ds_norm = (Ds[0] @ Ds[1].T).max()
+    Dt_norm = (Dt[0] @ Dt[1].T).max()
+
+    F_normalized = (F[0] / F_norm, F[1] / F_norm)
+    Ds_normalized = (Ds[0] / Ds_norm, Ds[1] / Ds_norm)
+    Dt_normalized = (Dt[0] / Dt_norm, Dt[1] / Dt_norm)

     if validation == "None":
-        F_val = None
-        Ds_val = None
-        Dt_val = None
+        F_val_normalized = None, None
+        Ds_val_normalized = None, None
+        Dt_val_normalized = None, None
     elif validation == "features":
-        source_features_val = torch.rand(ns, nf).to(device)
-        target_features_val = torch.rand(nt, nf).to(device)
-        F_val = torch.cdist(source_features_val, target_features_val)
-        Ds_val = None
-        Dt_val = None
+        F_val = _low_rank_squared_l2(source_features_val, target_features_val)
+        F_norm_val = (F_val[0] @ F_val[1].T).max()
+        F_val_normalized = (F_val[0] / F_norm_val, F_val[1] / F_norm_val)
+        Ds_val_normalized = None, None
+        Dt_val_normalized = None, None
     elif validation == "geometries":
-        source_embeddings_val = torch.rand(ns, ds).to(device)
-        target_embeddings_val = torch.rand(nt, dt).to(device)
-        F_val = None
-        Ds_val = torch.cdist(source_embeddings_val, source_embeddings_val)
-        Dt_val = torch.cdist(target_embeddings_val, target_embeddings_val)
+        F_val_normalized = None, None
+        Ds_val = _low_rank_squared_l2(
+            source_embeddings_val, source_embeddings_val
+        )
+        Dt_val = _low_rank_squared_l2(
+            target_embeddings_val, target_embeddings_val
+        )
+        Ds_norm_val = (Ds_val[0] @ Ds_val[1].T).max()
+        Dt_norm_val = (Dt_val[0] @ Dt_val[1].T).max()
+        Ds_val_normalized = (Ds_val[0] / Ds_norm_val, Ds_val[1] / Ds_norm_val)
+        Dt_val_normalized = (Dt_val[0] / Dt_norm_val, Dt_val[1] / Dt_norm_val)
     elif validation == "Both":
-        source_features_val = torch.rand(ns, nf).to(device)
-        target_features_val = torch.rand(nt, nf).to(device)
-        F_val = torch.cdist(source_features_val, target_features_val)
-        source_embeddings_val = torch.rand(ns, ds).to(device)
-        target_embeddings_val = torch.rand(nt, dt).to(device)
-        Ds_val = torch.cdist(source_embeddings_val, source_embeddings_val)
-        Dt_val = torch.cdist(target_embeddings_val, target_embeddings_val)
+        F_val = _low_rank_squared_l2(source_features_val, target_features_val)
+        F_norm_val = (F_val[0] @ F_val[1].T).max()
+        F_val_normalized = (F_val[0] / F_norm_val, F_val[1] / F_norm_val)
+        Ds_val = _low_rank_squared_l2(
+            source_embeddings_val, source_embeddings_val
+        )
+        Dt_val = _low_rank_squared_l2(
+            target_embeddings_val, target_embeddings_val
+        )
+        Ds_norm_val = (Ds_val[0] @ Ds_val[1].T).max()
+        Dt_norm_val = (Dt_val[0] @ Dt_val[1].T).max()
+        Ds_val_normalized = (Ds_val[0] / Ds_norm_val, Ds_val[1] / Ds_norm_val)
+        Dt_val_normalized = (Dt_val[0] / Dt_norm_val, Dt_val[1] / Dt_norm_val)

     nits_bcd = 100
     eval_bcd = 2
@@ -285,11 +317,13 @@ def test_validation_solver(validation):
         tol_loss=1e-5,
         eval_bcd=eval_bcd,
         eval_uot=10,
+        # Set a high value of ibpp_eps_base, otherwise nans appear in
+        # the coupling. This will generally increase the computed fugw loss.
         ibpp_eps_base=1e2,
     )

     res = fugw.solve(
-        alpha=0.8,
+        alpha=0.2,
         rho_s=2,
         rho_t=3,
         eps=0.02,
@@ -297,11 +331,10 @@ def test_validation_solver(validation):
         F=F_normalized,
         Ds=Ds_normalized,
         Dt=Dt_normalized,
-        F_val=F_val,
-        Ds_val=Ds_val,
-        Dt_val=Dt_val,
-        init_plan=None,
-        solver="sinkhorn",
+        F_val=F_val_normalized,
+        Ds_val=Ds_val_normalized,
+        Dt_val=Dt_val_normalized,
+        solver="mm",
         verbose=True,
     )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9020d238..fbed8c5b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -93,10 +93,10 @@ def test__make_tensor_preserve_type():
     "device,return_numpy,solver", product(devices, return_numpys, solvers)
 )
 def test_saving_and_loading(device, return_numpy, solver):
-    _, source_features_train, source_geometry, _ = _init_mock_distribution(
+    _, source_features_train, _, source_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_source, return_numpy=return_numpy
     )
-    _, target_features_train, target_geometry, _ = _init_mock_distribution(
+    _, target_features_train, _, target_embeddings = _init_mock_distribution(
         n_features_train, n_voxels_target, return_numpy=return_numpy
     )

@@ -104,8 +104,8 @@ def test_saving_and_loading(device, return_numpy, solver):
     fugw.fit(
         source_features=source_features_train,
         target_features=target_features_train,
-        source_geometry=source_geometry,
-        target_geometry=target_geometry,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
         solver=solver,
         solver_params={
             "nits_bcd": 3,