diff --git a/tests/test_heat.py b/tests/test_heat.py
new file mode 100644
index 0000000..ab59427
--- /dev/null
+++ b/tests/test_heat.py
@@ -0,0 +1,47 @@
+import pytest
+import torch
+
+from torchcfm.diffusion_distance import HeatKernelKNN, torch_knn_from_data
+
+DEVICES = ["cpu"]
+if torch.cuda.is_available():
+    DEVICES.append("cuda")
+
+
+def gt_heat_kernel_knn(data, t, k):
+    """Exact heat kernel exp(-t * L) via a full eigendecomposition (ground truth)."""
+    L = torch_knn_from_data(data, k=k, projection=False, proj_dim=10)
+    # eigendecomposition of the symmetric normalized Laplacian
+    eigvals, eigvecs = torch.linalg.eigh(L)
+    # exact heat kernel
+    heat_kernel = eigvecs @ torch.diag(torch.exp(-t * eigvals)) @ eigvecs.T
+    heat_kernel = (heat_kernel + heat_kernel.T) / 2
+    heat_kernel[heat_kernel < 0] = 0.0
+    return heat_kernel
+
+
+@pytest.mark.parametrize("t", [0.1, 1.0])
+@pytest.mark.parametrize("order", [10, 30, 50])
+@pytest.mark.parametrize("k", [10, 20])
+@pytest.mark.parametrize("device", DEVICES)
+def test_heat_kernel_knn(t, order, k, device):
+    tol = 2e-1 if t >= 1.0 else 1e-1
+    data = torch.randn(100, 5).to(device)
+    # use the same torch kNN graph as the ground truth so that only the
+    # Chebyshev approximation, not the graph construction, is under test
+    heat_op = HeatKernelKNN(k=k, t=t, order=order, graph_type="torch")
+    heat_kernel = heat_op(data)
+
+    # the heat kernel must be symmetric
+    assert torch.allclose(heat_kernel, heat_kernel.T)
+
+    # and non-negative after clipping
+    assert torch.all(heat_kernel >= 0)
+
+    # the approximation should be close to the exact kernel
+    gt_heat_kernel = gt_heat_kernel_knn(data, t=t, k=k)
+    assert torch.allclose(heat_kernel, gt_heat_kernel, atol=tol, rtol=tol)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/torchcfm/cheb_approx.py b/torchcfm/cheb_approx.py
new file mode 100644
index 0000000..2f5ba3d
--- /dev/null
+++ b/torchcfm/cheb_approx.py
@@ -0,0 +1,62 @@
+import typing as T
+
+import numpy as np
+import torch
+from scipy.special import ive
+
+
+def expm_multiply(
+    L: torch.Tensor,
+    X: torch.Tensor,
+    coeff: torch.Tensor,
+    eigval: T.Union[torch.Tensor, np.ndarray],
+):
+    """Approximate exp(-t * L) @ X with a Chebyshev polynomial expansion.
+
+    ``coeff`` holds the Chebyshev coefficients from
+    ``compute_chebychev_coeff_all`` and ``eigval`` is half the largest
+    eigenvalue of L, so that B = L / eigval - I has spectrum in [-1, 1].
+    """
+
+    def body(carry, c):
+        T0, T1, Y = carry
+        # Chebyshev recurrence: T_{k+1}(B) X = 2 B T_k(B) X - T_{k-1}(B) X
+        T2 = (2.0 / eigval) * torch.matmul(L, T1) - 2.0 * T1 - T0
+        Y = Y + c * T2
+        return (T1, T2, Y)
+
+    T0 = X
+    Y = 0.5 * coeff[0] * T0
+    # T_1(B) X = B X = (L / eigval - I) X
+    T1 = (1.0 / eigval) * torch.matmul(L, X) - T0
+    Y = Y + coeff[1] * T1
+
+    state = (T0, T1, Y)
+    for c in coeff[2:]:
+        state = body(state, c)
+
+    _, _, Y = state
+    return Y
+
+
+@torch.no_grad()
+def compute_chebychev_coeff_all(eigval, t, K):
+    """Chebyshev coefficients 2 * ive(k, -t * eigval) of exp(-t * x) on [0, 2 * eigval]."""
+    eigval = eigval.detach().cpu()
+    coeff = 2.0 * ive(np.arange(0, K + 1), -t * eigval.numpy())
+    return torch.from_numpy(coeff)
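+
+
+if __name__ == "__main__":
+    # Minimal sanity-check sketch (values here are illustrative, not part of
+    # the API): compare the Chebyshev approximation of exp(-t * L) against
+    # torch's exact dense matrix exponential on a small random Laplacian.
+    A = torch.rand(50, 50)
+    A = (A + A.T) / 2  # symmetric affinity matrix
+    L = torch.diag(A.sum(dim=1)) - A  # unnormalized graph Laplacian
+    t, K = 1.0, 50
+    max_eigval = torch.linalg.eigvalsh(L).max()
+    coeff = compute_chebychev_coeff_all(0.5 * max_eigval, t, K)
+    approx = expm_multiply(L, torch.eye(50), coeff, 0.5 * max_eigval)
+    exact = torch.linalg.matrix_exp(-t * L)
+    print("max abs error:", (approx - exact).abs().max().item())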
diff --git a/torchcfm/diffusion_distance.py b/torchcfm/diffusion_distance.py
new file mode 100644
index 0000000..127b864
--- /dev/null
+++ b/torchcfm/diffusion_distance.py
@@ -0,0 +1,147 @@
+import torch
+
+from torchcfm.cheb_approx import compute_chebychev_coeff_all, expm_multiply
+
+try:
+    import scanpy as sc
+except ImportError:
+    sc = None
+
+EPS_LOG = 1e-6
+EPS_HEAT = 1e-4
+
+
+def norm_sym_laplacian(A: torch.Tensor):
+    """Symmetrically normalized graph Laplacian I - D^{-1/2} A D^{-1/2}."""
+    deg = A.sum(dim=1)
+    deg_sqrt_inv = torch.diag(1.0 / torch.sqrt(deg + EPS_LOG))
+    eye = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
+    return eye - deg_sqrt_inv @ A @ deg_sqrt_inv
+
+
+def torch_knn_from_data(
+    data: torch.Tensor, k: int, projection: bool = False, proj_dim: int = 100
+):
+    if projection:
+        _, _, V = torch.pca_lowrank(data, q=proj_dim, center=True)
+        data = data @ V
+    dist = torch.cdist(data, data)
+    # k nearest neighbors; each point's own index is included at distance 0
+    _, indices = torch.topk(dist, k, largest=False)
+    affinity = torch.zeros(
+        data.shape[0], data.shape[0], device=data.device, dtype=data.dtype
+    )
+    affinity.scatter_(1, indices, 1.0)
+    # symmetrize the directed kNN graph so the Laplacian is symmetric
+    affinity = torch.maximum(affinity, affinity.T)
+    return norm_sym_laplacian(affinity)
+
+
+def scanpy_knn_from_data(
+    data: torch.Tensor, k: int, projection: bool = False, proj_dim: int = 100
+):
+    if sc is None:
+        raise ImportError("scanpy is required for graph_type='scanpy'")
+    adata = sc.AnnData(data.detach().cpu().numpy())
+    if projection:
+        sc.pp.pca(adata, n_comps=proj_dim)
+    sc.pp.neighbors(adata, n_neighbors=k, use_rep="X_pca" if projection else None)
+    return norm_sym_laplacian(
+        torch.tensor(
+            adata.obsp["connectivities"].toarray(),
+            device=data.device,
+            dtype=data.dtype,
+        )
+    )
+
+
+def var_fn(x, t):
+    """Log-kernel distance -t * log(h_t), with a volume correction built
+    from the kernel diagonal."""
+    outer = torch.outer(
+        torch.diag(x), torch.ones(x.shape[0], device=x.device, dtype=x.dtype)
+    )
+    vol_approx = (outer + outer.T) * 0.5
+    return -t * torch.log(x + EPS_LOG) + t * torch.log(vol_approx + EPS_LOG)
+
+
+class BaseHeatKernel:
+    def __init__(self, t: float = 1.0, order: int = 30):
+        self.t = t
+        self.order = order
+        self.dist_fn = {"var": var_fn}
+        self.graph_fn = None
+
+    def __call__(self, data: torch.Tensor):
+        if self.graph_fn is None:
+            raise NotImplementedError("graph_fn is not implemented")
+        L = self.graph_fn(data)
+        heat_kernel = self.compute_heat_from_laplacian(L)
+        heat_kernel = self.sym_clip(heat_kernel)
+        return heat_kernel
+
+    def compute_heat_from_laplacian(self, L: torch.Tensor):
+        n = L.shape[0]
+        # the largest eigenvalue sets the Chebyshev scaling: with
+        # eigval = max_eigval / 2, L / eigval - I has spectrum in [-1, 1]
+        max_eigval = torch.linalg.eigvalsh(L).max()
+        cheb_coeff = compute_chebychev_coeff_all(0.5 * max_eigval, self.t, self.order)
+        heat_kernel = expm_multiply(
+            L,
+            torch.eye(n, device=L.device, dtype=L.dtype),
+            cheb_coeff,
+            0.5 * max_eigval,
+        )
+        return heat_kernel
+
+    def sym_clip(self, heat_kernel: torch.Tensor):
+        heat_kernel = (heat_kernel + heat_kernel.T) / 2
+        # clip approximation noise; the small floor keeps downstream logs finite
+        heat_kernel[heat_kernel < 0] = EPS_HEAT
+        return heat_kernel
+
+    def fit(self, data: torch.Tensor, dist_type: str = "var"):
+        assert dist_type in self.dist_fn, f"dist_type must be in {list(self.dist_fn)}"
+        heat_kernel = self(data)
+        return self.dist_fn[dist_type](heat_kernel, self.t)
+
+
+class HeatKernelKNN(BaseHeatKernel):
+    """Heat kernel on a k-nearest-neighbors graph, approximated with
+    Chebyshev polynomials of the given order instead of a full eigendecomposition."""
+
+    _is_differentiable = False
+    _implemented_graph = {
+        "torch": torch_knn_from_data,
+        "scanpy": scanpy_knn_from_data,
+    }
+
+    def __init__(
+        self,
+        k: int = 10,
+        order: int = 30,
+        t: float = 1.0,
+        projection: bool = False,
+        proj_dim: int = 100,
+        graph_type: str = "scanpy",
+    ):
+        super().__init__(t=t, order=order)
+        assert (
+            graph_type in self._implemented_graph
+        ), f"graph_type must be in {list(self._implemented_graph)}"
+        self.k = k
+        self.projection = projection
+        self.proj_dim = proj_dim
+        self.graph_fn = lambda x: self._implemented_graph[graph_type](
+            x, self.k, projection=self.projection, proj_dim=self.proj_dim
+        )
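+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative values, not part of the public API):
+    # heat kernel on a kNN graph of random data, then a diffusion distance.
+    data = torch.randn(100, 5)
+    heat_op = HeatKernelKNN(k=10, t=1.0, order=30, graph_type="torch")
+    heat_kernel = heat_op(data)  # (100, 100), symmetric and non-negative
+    dist = heat_op.fit(data, dist_type="var")
+    print(heat_kernel.shape, dist.shape)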