diff --git a/src/rapids_singlecell/squidpy_gpu/__init__.py b/src/rapids_singlecell/squidpy_gpu/__init__.py
index 168e44d5..a4c91694 100644
--- a/src/rapids_singlecell/squidpy_gpu/__init__.py
+++ b/src/rapids_singlecell/squidpy_gpu/__init__.py
@@ -3,3 +3,4 @@
 from ._autocorr import spatial_autocorr
 from ._co_oc import co_occurrence
 from ._ligrec import ligrec
+from ._sepal import sepal
diff --git a/src/rapids_singlecell/squidpy_gpu/_sepal.py b/src/rapids_singlecell/squidpy_gpu/_sepal.py
new file mode 100644
index 00000000..d7a393dc
--- /dev/null
+++ b/src/rapids_singlecell/squidpy_gpu/_sepal.py
@@ -0,0 +1,260 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal
+
+import cupy as cp
+import pandas as pd
+from anndata import AnnData
+from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix
+from cupyx.scipy.sparse import isspmatrix_csc as cp_isspmatrix_csc
+from cupyx.scipy.sparse import isspmatrix_csr as cp_isspmatrix_csr
+from scanpy import logging as logg
+from scipy.sparse import issparse
+
+from .kernels._sepal import (
+    _get_get_nhood_idx_with_distance,
+    _get_sepal_simulation,
+)
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+def sepal(
+    adata: AnnData,
+    max_neighs: Literal[4, 6],
+    genes: str | Sequence[str] | None = None,
+    n_iter: int = 30000,
+    dt: float = 0.001,
+    thresh: float = 1e-8,
+    connectivity_key: str = "spatial_connectivities",
+    spatial_key: str = "spatial",
+    layer: str | None = None,
+    copy: bool = False,
+) -> pd.DataFrame | None:
+    """
+    GPU-accelerated sepal implementation.
+
+    Simulates diffusion on the spatial graph for every requested gene in a
+    single kernel launch (one CUDA block per gene), so runtime scales with the
+    number of iterations rather than with the number of genes.
+
+    Grid/block configuration follows established patterns in this repository:
+    - threads_per_block = 256 (as in src/rapids_singlecell/preprocessing/_harmony/_helper.py)
+    - 1D grid sizing with ceil division (as in src/rapids_singlecell/preprocessing/_harmony/_helper.py)
+    - Shared-memory reductions for the entropy computation (similar to the co-occurrence kernels)
+    """
+    # won't support SpatialData to avoid a dependency on spatialdata
+    if not isinstance(adata, AnnData):
+        raise TypeError("adata must be an AnnData object")
+
+    # replacement for squidpy's _assert_connectivity_key
+    if connectivity_key not in adata.obsp:
+        raise KeyError(
+            f"Connectivity matrix `{connectivity_key}` not found in `adata.obsp`."
+        )
+    # replacement for squidpy's _assert_spatial_basis
+    if spatial_key not in adata.obsm:
+        raise KeyError(
+            f"Spatial coordinates `{spatial_key}` not found in `adata.obsm`."
+        )
+
+    if max_neighs not in (4, 6):
+        raise ValueError(
+            f"Expected `max_neighs` to be either `4` or `6`, found `{max_neighs}`."
+        )
+
+    # Spatial coordinates as float32 (sufficient for the nearest-neighbor fallback)
+    spatial = cp.asarray(adata.obsm[spatial_key], dtype=cp.float32)
+
+    # replacement for _assert_non_empty_sequence
+    if genes is None:
+        genes = adata.var_names.values
+        if "highly_variable" in adata.var.columns:
+            genes = genes[adata.var["highly_variable"].values]
+    if len(genes) == 0:
+        raise ValueError("No genes found")
+
+    # Graph and index computation
+    g = adata.obsp[connectivity_key]
+    if not cp_isspmatrix_csr(g):
+        g = cp_csr_matrix(g)
+    g.eliminate_zeros()
+
+    degrees = cp.diff(g.indptr)
+    max_n = int(degrees.max())
+    if max_n != max_neighs:
+        raise ValueError(
+            f"Expected `max_neighs={max_neighs}`, found node with `{max_n}` neighbors."
+        )
+
+    sat, sat_idx, unsat, unsat_to_nearest_sat = _compute_idxs(
+        g=g,
+        degrees=degrees,
+        spatial=spatial,
+        sat_thresh=max_neighs,
+    )
+
+    # replacement for _extract_expression
+    if layer is None:
+        vals = adata[:, genes].X
+    elif layer not in adata.layers:
+        raise KeyError(f"Layer `{layer}` not found in `adata.layers`.")
+    else:
+        vals = adata[:, genes].layers[layer]
+
+    start = logg.info(
+        f"Calculating sepal score for `{len(genes)}` genes on GPU"
+    )
+
+    if cp_isspmatrix_csr(vals) or cp_isspmatrix_csc(vals) or issparse(vals):
+        vals = vals.toarray()
+
+    # Use double precision for numerical stability in the simulation
+    vals = cp.ascontiguousarray(cp.asarray(vals, dtype=cp.float64))
+
+    # Run the diffusion simulation for all genes in a single kernel launch
+    scores = _cuda_kernel_diffusion_gpu(
+        vals=vals,
+        sat=sat,
+        sat_idx=sat_idx,
+        unsat=unsat,
+        unsat_to_nearest_sat=unsat_to_nearest_sat,
+        max_neighs=max_neighs,
+        n_iter=n_iter,
+        dt=dt,
+        thresh=thresh,
+    )
+
+    # Results processing
+    score = cp.asnumpy(scores)
+
+    key_added = "sepal_score"
+    sepal_score = pd.DataFrame(score, index=genes, columns=[key_added])
+
+    if sepal_score[key_added].isna().any():
+        logg.warning(
+            "Found `NaN` in sepal scores, consider increasing `n_iter` to a higher value"
+        )
+    sepal_score = sepal_score.sort_values(by=key_added, ascending=False)
+
+    if copy:
+        logg.info("Finish", time=start)
+        return sepal_score
+
+    # replacement for _save_data: store in-place and return None
+    adata.uns[key_added] = sepal_score
+    logg.info("Finish", time=start)
+    return None
+
+
+def _cuda_kernel_diffusion_gpu(
+    vals: cp.ndarray,  # (n_cells, n_genes) - all gene expressions
+    sat: cp.ndarray,  # (n_sat,) - saturated node indices
+    sat_idx: cp.ndarray,  # (n_sat, max_neighs) - neighborhood indices for sat nodes
+    unsat: cp.ndarray,  # (n_unsat,) - unsaturated node indices
+    unsat_to_nearest_sat: cp.ndarray,  # (n_unsat,) - nearest sat for each unsat
+    max_neighs: int,
+    n_iter: int,
+    dt: float,
+    thresh: float,
+) -> cp.ndarray:
+    n_cells, n_genes = vals.shape
+    n_sat = len(sat)
+    n_unsat = len(unsat)
+
+    # Grid/block configuration following established patterns:
+    # threads_per_block = 256 (as in src/rapids_singlecell/preprocessing/_harmony/_helper.py)
+    # One block per gene; the effective grid size is computed below.
+    threads_per_block = 256
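+    # Note on memory layout: `concentration_all` below is the transposed expression
+    # matrix, so each block reads one contiguous row for its gene; `derivatives_all`
+    # is a same-shaped scratch buffer; `results_all` receives each gene's convergence
+    # iteration (negative if it never converged), which the host later scales by `dt`.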
+ + # Allocate arrays for ALL genes at once + concentration_all = cp.ascontiguousarray( + vals.T, dtype=cp.float64 + ) # (n_genes, n_cells) + derivatives_all = cp.zeros((n_genes, n_cells), dtype=cp.float64) + results_all = cp.full(n_genes, -999999.0, dtype=cp.float64) # Results for ALL genes + + # Calculate shared memory (fixed size per block, independent of n_cells) + min_blocks = 256 # Hardware-specific minimum + blocks_per_grid = max(n_genes, min_blocks) + shared_mem_size = threads_per_block * 2 * 8 # 2 double arrays per thread + + # Get specialized kernel using cuda_kernel_factory pattern + sepal_simulation_kernel = _get_sepal_simulation(derivatives_all.dtype) + + # **SINGLE KERNEL LAUNCH FOR ALL GENES** + sepal_simulation_kernel( + (blocks_per_grid,), # Grid: one block per gene + (threads_per_block,), # Block: 256 threads + ( + concentration_all, # (n_genes, n_cells) - all genes + derivatives_all, # (n_genes, n_cells) - all derivatives + sat, + sat_idx, + unsat, + unsat_to_nearest_sat, + results_all, # (n_genes,) - results for all genes + n_cells, # n_cells (can be 1M+) + n_genes, # Number of genes to process + n_sat, + n_unsat, + max_neighs, + n_iter, + cp.float64(dt), + cp.float64(thresh), + ), + shared_mem=shared_mem_size, + ) + + # Convert results + final_scores = cp.where(results_all < 0.0, cp.nan, dt * results_all) + + return final_scores # Shape: (n_genes,) + + +def _compute_idxs( + g: cp_csr_matrix, + degrees: cp.ndarray, + spatial: cp.ndarray, + sat_thresh: int, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]: + """Compute saturated/unsaturated indices on GPU with unified distance computation. + + Grid/block configuration follows established patterns: + - threads_per_block = 256 (as in src/rapids_singlecell/preprocessing/_harmony/_helper.py) + - 1D grid sizing with ceil division (as in src/rapids_singlecell/preprocessing/_harmony/_helper.py) + """ + + # Get saturated and unsaturated nodes + unsat_mask = degrees < sat_thresh + sat_mask = degrees == sat_thresh + + unsat = cp.asarray(cp.where(unsat_mask)[0], dtype=cp.int32) + sat = cp.asarray(cp.where(sat_mask)[0], dtype=cp.int32) + + # Extract saturated neighborhoods with vectorized CuPy + nearest_sat = cp.full(len(unsat), -1, dtype=cp.int32) + sat_idx = g.indices[g.indptr[sat][:, None] + cp.arange(sat_thresh)] + + # Single kernel handles both graph neighbors and distance fallback + if len(unsat) > 0: + # Grid/block configuration following established patterns: + # threads_per_block = 256 (as in src/rapids_singlecell/preprocessing/_harmony/_helper.py) + threads_per_block = 256 + blocks = (len(unsat) + threads_per_block - 1) // threads_per_block + + # Get specialized kernel using cuda_kernel_factory pattern + get_nhood_kernel = _get_get_nhood_idx_with_distance(spatial.dtype) + + get_nhood_kernel( + (blocks,), + (threads_per_block,), + ( + unsat, # unsaturated nodes (read only int32) + spatial, # spatial coordinates [n_nodes, 2] (read only float64) + sat, # saturated node list (read only int32) + g.indptr, # CSR indptr (read only int32) + g.indices, # CSR indices (read only int32) + sat_mask, # boolean mask for saturated nodes (read only bool) + nearest_sat, # output int32 + len(unsat), # number of unsaturated nodes read only int32 + len(sat), # number of saturated nodes read only int32 + ), + ) + + return sat, sat_idx, unsat, nearest_sat diff --git a/src/rapids_singlecell/squidpy_gpu/kernels/_sepal.py b/src/rapids_singlecell/squidpy_gpu/kernels/_sepal.py new file mode 100644 index 00000000..52dc6cca --- /dev/null +++ 
b/src/rapids_singlecell/squidpy_gpu/kernels/_sepal.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +from cuml.common.kernel_utils import cuda_kernel_factory + +# Kernel for finding nearest saturated node for each unsaturated node +get_nhood_idx_with_distance_kernel = r""" +( + const int* __restrict__ unsat_nodes, + const {0}* __restrict__ spatial, + const int* __restrict__ sat_nodes, + const int* __restrict__ g_indptr, + const int* __restrict__ g_indices, + const bool* __restrict__ sat_mask, + int* __restrict__ nearest_sat, + int n_unsat, + int n_sat +) +{{ + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_unsat) return; + + int node = unsat_nodes[tid]; + {0} node_x = spatial[node * 2]; + {0} node_y = spatial[node * 2 + 1]; + + {0} min_dist = -1.0; // -1.0 means no closest sat found yet + int closest = -1; + + // Phase 1: Check graph neighbors for saturated nodes + int start = g_indptr[node]; + int end = g_indptr[node + 1]; + + for (int i = start; i < end; i++) {{ + int neighbor = g_indices[i]; + if (sat_mask[neighbor]) {{ + closest = neighbor; // Take first + break; // Stop immediately + }} + }} + + // Phase 2: If no saturated graph neighbors, search ALL saturated nodes + if (closest == -1) {{ + for (int i = 0; i < n_sat; i++) {{ + int sat_node = sat_nodes[i]; + {0} sat_x = spatial[sat_node * 2]; + {0} sat_y = spatial[sat_node * 2 + 1]; + {0} dist = fabs(node_x - sat_x) + fabs(node_y - sat_y); + + if (min_dist < 0.0 || dist < min_dist) {{ + min_dist = dist; + closest = sat_node; + }} + }} + }} + + nearest_sat[tid] = closest; +}} +""" + +# SEPAL simulation kernel with device function for entropy computation +# This kernel must be compiled with CuPy's RawKernel due to device functions +sepal_simulation_kernel = r""" +extern "C" { + // Device function: Computes entropy using cooperative thread reduction + __device__ double compute_entropy_cooperative( + const double* __restrict__ conc, + int n_sat, + const int* __restrict__ sat_nodes, + int tid, + int blockSize + ) { + __shared__ double total_sum_shared[256]; + __shared__ double entropy_shared[256]; + // np.finfo(np.float64).eps # ~2.22e-16 + const double eps = 2.220446049250313e-16; + + // Each thread accumulates its portion of nodes + double local_sum = 0.0; + for (int i = tid; i < n_sat; i += blockSize) { + double val = conc[sat_nodes[i]]; + if (val > eps) local_sum += val; + } + + total_sum_shared[tid] = local_sum; + __syncthreads(); + + // Parallel reduction to sum all values + for (int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + total_sum_shared[tid] += total_sum_shared[tid + s]; + } + __syncthreads(); + } + + double total_sum = total_sum_shared[0]; + if (total_sum < eps) return 0.0; + // see here why + // https://stats.stackexchange.com/questions/57069/alternative-to-shannons-entropy-when-probability-equal-to-zero/433096 + + // Each thread computes entropy for its portion + double local_entropy = 0.0; + for (int i = tid; i < n_sat; i += blockSize) { + double val = conc[sat_nodes[i]]; + if (val > eps) { + double normalized = val / total_sum; + local_entropy += -normalized * log(fmax(normalized, eps)); + } + } + + entropy_shared[tid] = local_entropy; + __syncthreads(); + + // Parallel reduction for entropy + for (int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + entropy_shared[tid] += entropy_shared[tid + s]; + } + __syncthreads(); + } + + return entropy_shared[0] / (double)(n_sat); + } + + // Main SEPAL simulation kernel - processes one gene per block + __global__ void sepal_simulation( 
+ double* __restrict__ concentration_all, + double* __restrict__ derivatives_all, + const int* __restrict__ sat_nodes, + const int* __restrict__ sat_idx, + const int* __restrict__ unsat_nodes, + const int* __restrict__ unsat_idx, + double* __restrict__ results, + int n_cells, + int n_genes, + int n_sat, + int n_unsat, + int max_neighs, + int n_iter, + double dt, + double thresh + ) { + int gene_idx = blockIdx.x; // Each block handles one gene + int tid = threadIdx.x; // Thread ID (0-255) + int blockSize = blockDim.x; // 256 threads per block + + if (gene_idx >= n_genes) return; + + // Per-gene pointers into global arrays + double* concentration = &concentration_all[gene_idx * n_cells]; + double* derivatives = &derivatives_all[gene_idx * n_cells]; + + // Convergence tracking + __shared__ double prev_entropy; + __shared__ int convergence_iter; + __shared__ bool converged_flag; + + if (tid == 0) { + prev_entropy = 1.0; + convergence_iter = -1; + converged_flag = false; + } + __syncthreads(); + + // Main iteration loop - all threads process their portion of nodes + for (int iter = 0; iter < n_iter; iter++) { + // Phase 1: Update derivatives for saturated nodes + for (int i = tid; i < n_sat; i += blockSize) { + double neighbor_sum = 0.0; + for (int j = 0; j < max_neighs; j++) { + neighbor_sum += concentration[sat_idx[i * max_neighs + j]]; + } + int sat_global_idx = sat_nodes[i]; + double center = concentration[sat_global_idx]; + double d2 = 0.0; + + if (max_neighs == 4) { + d2 = (neighbor_sum - 4.0 * center); + } else if (max_neighs == 6) { + d2 = (2.0 * neighbor_sum - 12.0 * center) / 3.0; + } + derivatives[sat_global_idx] = d2; + } + __syncthreads(); + + // Phase 2: Update saturated node concentrations + for (int i = tid; i < n_sat; i += blockSize) { + int sat_global_idx = sat_nodes[i]; + concentration[sat_global_idx] += derivatives[sat_global_idx] * dt; + concentration[sat_global_idx] = fmax(0.0, concentration[sat_global_idx]); + } + __syncthreads(); + + // Phase 3: Update unsaturated nodes based on nearest saturated + for (int i = tid; i < n_unsat; i += blockSize) { + int unsat_global_idx = unsat_nodes[i]; + concentration[unsat_global_idx] += derivatives[unsat_idx[i]] * dt; + concentration[unsat_global_idx] = fmax(0.0, concentration[unsat_global_idx]); + } + __syncthreads(); + + // Check convergence using entropy + if (!converged_flag) { + double current_entropy = compute_entropy_cooperative( + concentration, n_sat, sat_nodes, tid, blockSize + ); + + if (tid == 0) { + double entropy_diff = fabs(current_entropy - prev_entropy); + if (entropy_diff <= thresh && convergence_iter == -1) { + convergence_iter = iter; + converged_flag = true; + } + prev_entropy = current_entropy; + } + __syncthreads(); + } + + if (converged_flag) break; + } + + // Store result for this gene + if (tid == 0) { + results[gene_idx] = convergence_iter >= 0 ? + (double)convergence_iter : -1.0; + } + } +} +""" + + +def _get_get_nhood_idx_with_distance(dtype): + """Get neighborhood index with distance kernel specialized for the given dtype. + + This kernel finds the nearest saturated node for each unsaturated node. + First checks graph neighbors, then falls back to spatial distance search. + """ + return cuda_kernel_factory( + get_nhood_idx_with_distance_kernel, (dtype,), "get_nhood_idx_with_distance" + ) + + +def _get_sepal_simulation(dtype=None): + """Get SEPAL simulation kernel. + + This kernel simulates diffusion for multiple genes in parallel. + Each block processes one gene with 256 threads cooperating. 
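+    Blocks with an index of ``n_genes`` or higher return immediately, so the host
+    may launch a padded grid (``max(n_genes, 256)`` blocks in
+    ``_cuda_kernel_diffusion_gpu``) without affecting the results.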
+ Uses double precision for concentration values. + + Parameters + ---------- + dtype : dtype, optional + Ignored. The kernel always uses double precision for numerical stability. + This parameter exists for API compatibility. + + Note + ---- + This kernel uses device functions and is compiled via CuPy's RawKernel + rather than cuda_kernel_factory, as it requires special handling for + the device function (compute_entropy_cooperative). + """ + import cupy as cp + + return cp.RawKernel(sepal_simulation_kernel, "sepal_simulation") diff --git a/tests/test_sepal_tmp.py b/tests/test_sepal_tmp.py new file mode 100644 index 00000000..64ab3c67 --- /dev/null +++ b/tests/test_sepal_tmp.py @@ -0,0 +1,215 @@ +"""Tests for sepal GPU implementation.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from scipy import sparse + +import rapids_singlecell as rsc + + +@pytest.fixture +def synthetic_spatial_data(): + """Create synthetic spatial data for testing sepal.""" + # Create a small 3x3 grid (9 cells) + n_cells = 9 + n_genes = 5 + + # Spatial coordinates (3x3 grid) + spatial_coords = np.array( + [[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1], [0, 2], [1, 2], [2, 2]], + dtype=np.float32, + ) + + # Create connectivity matrix (4-neighbors for rectangular grid) + connectivity = np.zeros((n_cells, n_cells), dtype=np.float32) + for i in range(n_cells): + row, col = i // 3, i % 3 + # Add connections to neighbors + for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: + nr, nc = row + dr, col + dc + if 0 <= nr < 3 and 0 <= nc < 3: + j = nr * 3 + nc + connectivity[i, j] = 1.0 + + # Create expression data with some spatial pattern + expression = np.random.poisson(5, (n_cells, n_genes)).astype(np.float32) + # Add spatial gradient to first gene + expression[:, 0] += np.arange(n_cells) * 2 + + # Gene names + gene_names = [f"Gene_{i}" for i in range(n_genes)] + + adata = AnnData(X=sparse.csr_matrix(expression), obsm={"spatial": spatial_coords}) + adata.var_names = gene_names + adata.obsp["spatial_connectivities"] = sparse.csr_matrix(connectivity) + + return adata + + +@pytest.fixture +def synthetic_hex_data(): + """Create synthetic hexagonal grid data for testing sepal.""" + # Create a small hexagonal grid (7 cells in center + 6 neighbors) + n_cells = 7 + n_genes = 3 + + # Spatial coordinates (hexagonal pattern) + spatial_coords = np.array( + [ + [0, 0], # center + [1, 0], # right + [0.5, 0.866], + [0.5, -0.866], # top, bottom + [-0.5, 0.866], + [-0.5, -0.866], # top-left, bottom-left + [-1, 0], # left + ], + dtype=np.float32, + ) + + # Create connectivity matrix (6-neighbors for hexagonal grid) + connectivity = np.zeros((n_cells, n_cells), dtype=np.float32) + # Center connects to all others + for i in range(1, n_cells): + connectivity[0, i] = 1.0 + connectivity[i, 0] = 1.0 + + # Create expression data + expression = np.random.poisson(3, (n_cells, n_genes)).astype(np.float32) + gene_names = [f"HexGene_{i}" for i in range(n_genes)] + + adata = AnnData(X=sparse.csr_matrix(expression), obsm={"spatial": spatial_coords}) + adata.var_names = gene_names + adata.obsp["spatial_connectivities"] = sparse.csr_matrix(connectivity) + + return adata + + +def test_sepal_rectangular_grid(synthetic_spatial_data): + """Test sepal on rectangular grid (4-neighbors).""" + adata = synthetic_spatial_data.copy() + + # Run sepal with small number of iterations for testing + result = rsc.gr.sepal( + adata, + max_neighs=4, + n_iter=100, # Small number for testing + 
copy=True, + ) + + # Check result type and shape + assert isinstance(result, pd.DataFrame) + assert result.shape == (5, 1) # 5 genes, 1 score column + assert "sepal_score" in result.columns + + # Check no NaN values + assert not result["sepal_score"].isna().any() + + # Check scores are sorted descending + assert result["sepal_score"].is_monotonic_decreasing + + # Check gene names match + assert list(result.index) == [f"Gene_{i}" for i in range(5)] + + +def test_sepal_hexagonal_grid(synthetic_hex_data): + """Test sepal on hexagonal grid (6-neighbors).""" + adata = synthetic_hex_data.copy() + + # Run sepal with small number of iterations for testing + result = rsc.gr.sepal( + adata, + max_neighs=6, + n_iter=50, # Small number for testing + copy=True, + ) + + # Check result type and shape + assert isinstance(result, pd.DataFrame) + assert result.shape == (3, 1) # 3 genes, 1 score column + assert "sepal_score" in result.columns + + # Check no NaN values + assert not result["sepal_score"].isna().any() + + # Check scores are sorted descending + assert result["sepal_score"].is_monotonic_decreasing + + +def test_sepal_inplace_storage(synthetic_spatial_data): + """Test sepal stores results in adata.uns when copy=False.""" + adata = synthetic_spatial_data.copy() + + # Run sepal in-place + result = rsc.gr.sepal(adata, max_neighs=4, n_iter=50, copy=False) + + # Should return None + assert result is None + + # Check results stored in adata.uns + assert "sepal_score" in adata.uns + stored_result = adata.uns["sepal_score"] + + # Check stored result + assert isinstance(stored_result, pd.DataFrame) + assert stored_result.shape == (5, 1) + assert "sepal_score" in stored_result.columns + + +def test_sepal_gene_selection(synthetic_spatial_data): + """Test sepal with specific gene selection.""" + adata = synthetic_spatial_data.copy() + + # Select only first 2 genes + selected_genes = ["Gene_0", "Gene_1"] + + result = rsc.gr.sepal( + adata, max_neighs=4, genes=selected_genes, n_iter=50, copy=True + ) + + # Check only selected genes are in result + assert result.shape == (2, 1) + assert list(result.index) == selected_genes + + +def test_sepal_validation_errors(synthetic_spatial_data): + """Test sepal input validation.""" + adata = synthetic_spatial_data.copy() + + # Test invalid max_neighs + with pytest.raises( + ValueError, match="Expected `max_neighs` to be either `4` or `6`" + ): + rsc.gr.sepal(adata, max_neighs=5, copy=True) + + # Test missing connectivity + adata.obsp.pop("spatial_connectivities") + with pytest.raises(KeyError, match="Connectivity matrix"): + rsc.gr.sepal(adata, max_neighs=4, copy=True) + + # Test missing spatial coordinates + adata = synthetic_spatial_data.copy() + adata.obsm.pop("spatial") + with pytest.raises(KeyError, match="Spatial coordinates"): + rsc.gr.sepal(adata, max_neighs=4, copy=True) + + +def test_sepal_connectivity_mismatch(synthetic_spatial_data): + """Test sepal with connectivity that doesn't match max_neighs.""" + adata = synthetic_spatial_data.copy() + + # Modify connectivity to have 6 neighbors for some cells + connectivity = adata.obsp["spatial_connectivities"].toarray() + connectivity[0, 5] = 1.0 # Add extra connection + connectivity[5, 0] = 1.0 + adata.obsp["spatial_connectivities"] = sparse.csr_matrix(connectivity) + + # Should raise error when max_neighs=4 but some cells have 5 neighbors + with pytest.raises( + ValueError, match="Expected `max_neighs=4`, found node with `5` neighbors" + ): + rsc.gr.sepal(adata, max_neighs=4, copy=True) diff --git 
a/tmp_scripts/compare_cpu_gpu.py b/tmp_scripts/compare_cpu_gpu.py new file mode 100644 index 00000000..ddcc035b --- /dev/null +++ b/tmp_scripts/compare_cpu_gpu.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import time +import warnings +from pathlib import Path + +import anndata as ad +import pandas as pd +from scipy.stats import spearmanr +from squidpy.gr import sepal as sepal_cpu + +import rapids_singlecell as rsc +from rapids_singlecell.squidpy_gpu import sepal as sepal_gpu + +warnings.filterwarnings("ignore") + +HOME = Path.home() + + +def main(): + # Load data + print("Loading data...") + adata = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + + # Test on first 10 genes + n_genes = 15 + genes = adata.var_names.values[:n_genes].tolist() + + print(f"\nTesting on first {n_genes} genes:") + print(", ".join(genes)) + + # Run CPU version + print("\n" + "=" * 80) + print("Running CPU version...") + print("=" * 80) + adata_cpu = adata.copy() + start_time = time.time() + result_cpu = sepal_cpu( + adata_cpu, max_neighs=6, genes=genes, n_iter=30000, copy=True + ) + cpu_time = time.time() - start_time + print(f"CPU Time: {cpu_time:.2f} seconds") + + # Run GPU version + print("\n" + "=" * 80) + print("Running GPU version...") + print("=" * 80) + adata_gpu = adata.copy() + rsc.get.anndata_to_GPU(adata_gpu, convert_all=True) + adata_gpu.obsp["spatial_connectivities"] = rsc.get.X_to_GPU( + adata_gpu.obsp["spatial_connectivities"] + ) + adata_gpu.obsm["spatial"] = rsc.get.X_to_GPU(adata_gpu.obsm["spatial"]) + + start_time = time.time() + result_gpu = sepal_gpu( + adata_gpu, max_neighs=6, genes=genes, n_iter=30000, copy=True + ) + gpu_time = time.time() - start_time + print(f"GPU Time: {gpu_time:.2f} seconds") + + # Prepare comparison + print("\n" + "=" * 80) + print("RESULTS COMPARISON") + print("=" * 80) + + # Merge results and calculate ranks + comparison = pd.DataFrame( + { + "Gene": result_cpu.index, + "CPU_Score": result_cpu["sepal_score"].values, + "GPU_Score": result_gpu["sepal_score"].values, + } + ) + + # Calculate ranks (1 = highest score) + comparison["CPU_Rank"] = ( + comparison["CPU_Score"].rank(ascending=False, method="min").astype(int) + ) + comparison["GPU_Rank"] = ( + comparison["GPU_Score"].rank(ascending=False, method="min").astype(int) + ) + comparison["Rank_Diff"] = abs(comparison["CPU_Rank"] - comparison["GPU_Rank"]) + + # Calculate absolute and relative differences in scores + comparison["Score_Diff"] = abs(comparison["CPU_Score"] - comparison["GPU_Score"]) + comparison["Rel_Diff_%"] = ( + comparison["Score_Diff"] / comparison["CPU_Score"].abs() + ) * 100 + + # Calculate correlations + # Spearman correlation (rank-based) + spearman_corr, spearman_pval = spearmanr( + comparison["CPU_Score"], comparison["GPU_Score"] + ) + + # Also calculate Spearman on explicit ranks for clarity + spearman_rank_corr, spearman_rank_pval = spearmanr( + comparison["CPU_Rank"], comparison["GPU_Rank"] + ) + + # Pearson correlation (on scores) + from scipy.stats import pearsonr + + pearson_corr, pearson_pval = pearsonr( + comparison["CPU_Score"], comparison["GPU_Score"] + ) + + # Display overall metrics + print(f"\n{'CORRELATION METRICS':^80}") + print("=" * 80) + print( + f"Spearman Rank Correlation: {spearman_corr:.6f} (p-value: {spearman_pval:.2e})" + ) + print( + f" (on explicit ranks): {spearman_rank_corr:.6f} (p-value: {spearman_rank_pval:.2e})" + ) + print( + f"Pearson Correlation: {pearson_corr:.6f} (p-value: {pearson_pval:.2e})" + ) + print(f"\nSpeedup: {cpu_time / 
gpu_time:.2f}x") + + # Side-by-side comparison with ranks + print("\n" + "=" * 80) + print(f"{'SIDE-BY-SIDE COMPARISON (Sorted by CPU Rank)':^80}") + print("=" * 80) + + # Format the table for display + pd.set_option("display.max_columns", None) + pd.set_option("display.width", None) + pd.set_option("display.max_rows", None) + + # Create display dataframe + display_df = comparison[ + [ + "Gene", + "CPU_Score", + "CPU_Rank", + "GPU_Score", + "GPU_Rank", + "Rank_Diff", + "Score_Diff", + "Rel_Diff_%", + ] + ].copy() + display_df = display_df.sort_values("CPU_Rank") + + # Format for better readability + print( + display_df.to_string( + index=False, + formatters={ + "CPU_Score": "{:.3f}".format, + "GPU_Score": "{:.3f}".format, + "Score_Diff": "{:.3f}".format, + "Rel_Diff_%": "{:.2f}".format, + }, + ) + ) + + # Summary statistics + print("\n" + "=" * 80) + print(f"{'SUMMARY STATISTICS':^80}") + print("=" * 80) + + print("\nScore Differences:") + print(f" Mean Absolute Difference: {comparison['Score_Diff'].mean():.6f}") + print(f" Max Absolute Difference: {comparison['Score_Diff'].max():.6f}") + print(f" Mean Relative Difference: {comparison['Rel_Diff_%'].mean():.2f}%") + print(f" Max Relative Difference: {comparison['Rel_Diff_%'].max():.2f}%") + + print("\nRank Differences:") + print(f" Mean Rank Difference: {comparison['Rank_Diff'].mean():.2f}") + print(f" Max Rank Difference: {comparison['Rank_Diff'].max()}") + print( + f" Perfect Rank Matches: {(comparison['Rank_Diff'] == 0).sum()}/{len(comparison)}" + ) + print( + f" Within 1 Rank: {(comparison['Rank_Diff'] <= 1).sum()}/{len(comparison)}" + ) + print( + f" Within 2 Ranks: {(comparison['Rank_Diff'] <= 2).sum()}/{len(comparison)}" + ) + + # Top genes comparison + print("\n" + "=" * 80) + print(f"{'TOP 5 GENES COMPARISON':^80}") + print("=" * 80) + + cpu_top5 = comparison.nsmallest(5, "CPU_Rank")[ + ["Gene", "CPU_Score", "CPU_Rank"] + ].reset_index(drop=True) + gpu_top5 = comparison.nsmallest(5, "GPU_Rank")[ + ["Gene", "GPU_Score", "GPU_Rank"] + ].reset_index(drop=True) + + print("\nTop 5 by CPU:") + print(cpu_top5.to_string(index=False, formatters={"CPU_Score": "{:.3f}".format})) + + print("\nTop 5 by GPU:") + print(gpu_top5.to_string(index=False, formatters={"GPU_Score": "{:.3f}".format})) + + # Check overlap in top genes + cpu_top5_genes = set(cpu_top5["Gene"]) + gpu_top5_genes = set(gpu_top5["Gene"]) + overlap = cpu_top5_genes & gpu_top5_genes + print(f"\nTop 5 Overlap: {len(overlap)}/5 genes") + if overlap: + print(f"Common genes: {', '.join(sorted(overlap))}") + + # Save results + output_file = HOME / "rapids_singlecell/tmp_scripts/comparison_results.csv" + comparison_sorted = comparison.sort_values("CPU_Rank") + comparison_sorted.to_csv(output_file, index=False) + print(f"\n{'=' * 80}") + print(f"Results saved to: {output_file}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/tmp_scripts/compare_helpers.py b/tmp_scripts/compare_helpers.py new file mode 100644 index 00000000..74b5dc7a --- /dev/null +++ b/tmp_scripts/compare_helpers.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import warnings + +warnings.filterwarnings("ignore") +import os +import time +from pathlib import Path + +import anndata as ad +import cupy as cp +import numpy as np + +import rapids_singlecell as rsc + +# Add utils to path for GPU version + +HOME = Path(os.path.expanduser("~")) + + +def compare_indices(adata_cpu): + """ + Compare the saturated/unsaturated indices between CPU and GPU versions. 
+ """ + print("🔍 Comparing CPU vs GPU index computation...") + + # Import the helper functions + from utils._sepal import _compute_idxs as _compute_idxs_cpu + + from rapids_singlecell.squidpy_gpu._sepal import _compute_idxs as _compute_idxs_gpu + + # Get connectivity and spatial data + adata_gpu = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + rsc.get.anndata_to_GPU(adata_gpu, convert_all=True) + g = adata_cpu.obsp["spatial_connectivities"] + g_gpu = adata_gpu.obsp["spatial_connectivities"] + degrees = cp.diff(g_gpu.indptr) + spatial_cpu = adata_cpu.obsm["spatial"].astype(np.float64) + spatial_gpu = adata_gpu.obsm["spatial"].astype(cp.float32) + + g_gpu = rsc.get.X_to_GPU(g_gpu) + spatial_gpu = rsc.get.X_to_GPU(spatial_gpu) + + # Compute indices with both methods + start = time.time() + degrees = cp.diff(g_gpu.indptr) + sat_gpu, sat_idx_gpu, unsat_gpu, unsat_idx_gpu = _compute_idxs_gpu( + g_gpu, degrees, spatial_gpu, 6 + ) + end = time.time() + print("GPU indices computed in ", end - start, "seconds") + + start = time.time() + sat_cpu, sat_idx_cpu, unsat_cpu, unsat_idx_cpu = _compute_idxs_cpu( + g, spatial_cpu, 6, "l1" + ) + end = time.time() + print("CPU indices computed in ", end - start, "seconds") + # Convert GPU results to CPU for comparison + sat_gpu_cpu = sat_gpu.get() + sat_idx_gpu_cpu = sat_idx_gpu.get() + unsat_gpu_cpu = unsat_gpu.get() + unsat_idx_gpu_cpu = unsat_idx_gpu.get() + + print(f"Saturated nodes - CPU: {len(sat_cpu)}, GPU: {len(sat_gpu_cpu)}") + print(f"Saturated nodes identical: {np.array_equal(sat_cpu, sat_gpu_cpu)}") + + print(f"Unsaturated nodes - CPU: {len(unsat_cpu)}, GPU: {len(unsat_gpu_cpu)}") + print(f"Unsaturated nodes identical: {np.array_equal(unsat_cpu, unsat_gpu_cpu)}") + + print( + f"Saturated indices identical: {np.array_equal(sat_idx_cpu, sat_idx_gpu_cpu)}" + ) + + # Check unsat_idx differences (these might differ due to tie-breaking) + unsat_idx_diff = np.sum(unsat_idx_cpu != unsat_idx_gpu_cpu) + print( + f"Unsaturated index differences: {unsat_idx_diff}/{len(unsat_idx_cpu)} ({100 * unsat_idx_diff / len(unsat_idx_cpu):.1f}%)" + ) + + return { + "sat_identical": np.array_equal(sat_cpu, sat_gpu_cpu), + "unsat_identical": np.array_equal(unsat_cpu, unsat_gpu_cpu), + "sat_idx_identical": np.array_equal(sat_idx_cpu, sat_idx_gpu_cpu), + "unsat_idx_diff_count": unsat_idx_diff, + "unsat_idx_diff_percent": 100 * unsat_idx_diff / len(unsat_idx_cpu), + } + + +if __name__ == "__main__": + # Run comparison + adata_cpu = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + res = compare_indices(adata_cpu) + print(res) diff --git a/tmp_scripts/comparison_results.csv b/tmp_scripts/comparison_results.csv new file mode 100644 index 00000000..ec4b6cf7 --- /dev/null +++ b/tmp_scripts/comparison_results.csv @@ -0,0 +1,16 @@ +Gene,CPU_Score,GPU_Score,CPU_Rank,GPU_Rank,Rank_Diff,Score_Diff,Rel_Diff_% +4732440D04Rik,2.783,3.238,1,1,0,0.45500000000000007,16.34926338483651 +Npbwr1,2.578,2.8890000000000002,2,2,0,0.3110000000000004,12.06361520558574 +St18,2.503,1.6,3,3,0,0.903,36.07670795045944 +Gm26901,2.251,1.479,4,4,0,0.7719999999999998,34.2958685028876 +Oprk1,1.537,1.431,5,5,0,0.10599999999999987,6.896551724137923 +Xkr4,1.427,1.234,6,6,0,0.19300000000000006,13.524877365101615 +Sox17,1.368,1.037,7,7,0,0.3310000000000002,24.195906432748547 +Sntg1,0.992,0.971,8,8,0,0.02100000000000002,2.1169354838709697 +Lypla1,0.835,0.752,9,9,0,0.08299999999999996,9.940119760479039 +Rgs20,0.786,0.6890000000000001,10,10,0,0.09699999999999998,12.34096692111959 
+Mrpl15,0.687,0.613,11,11,0,0.07400000000000007,10.771470160116458 +Pcmtd1,0.684,0.591,12,12,0,0.09300000000000008,13.596491228070187 +Tcea1,0.601,0.5690000000000001,13,13,0,0.03199999999999992,5.324459234608971 +Rb1cc1,0.598,0.554,14,14,0,0.04399999999999993,7.357859531772563 +Atp6v1h,0.47800000000000004,0.47300000000000003,15,15,0,0.0050000000000000044,1.0460251046025113 diff --git a/tmp_scripts/debug_comparison.py b/tmp_scripts/debug_comparison.py new file mode 100644 index 00000000..546c65e1 --- /dev/null +++ b/tmp_scripts/debug_comparison.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import warnings + +warnings.filterwarnings("ignore") +import os +from pathlib import Path + +import anndata as ad +from utils.sepal_cpu import sepal +from utils.sepal_gpu import sepal_gpu + +import rapids_singlecell as rsc + +HOME = Path(os.path.expanduser("~")) + + +def debug_sepal_comparison(): + # Load data + adata = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + + # Test with single gene first + test_genes = ["Npbwr1"] # Just one gene + + print("=== CPU vs GPU Comparison ===") + + # CPU version + cpu_result = sepal( + adata, max_neighs=6, genes=test_genes, n_iter=4000, copy=True, debug=True + ) + print(f"CPU result: {cpu_result.iloc[0, 0]}") + + # GPU version + adata = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + rsc.get.anndata_to_GPU(adata, convert_all=True) + adata.obsp["spatial_connectivities"] = rsc.get.X_to_GPU( + adata.obsp["spatial_connectivities"] + ) + adata.obsm["spatial"] = rsc.get.X_to_GPU(adata.obsm["spatial"]) + + gpu_result = sepal_gpu( + adata, max_neighs=6, genes=test_genes, n_iter=4000, copy=True, debug=True + ) + print(f"GPU result: {gpu_result.iloc[0, 0]}") + + print(f"Difference: {abs(cpu_result.iloc[0, 0] - gpu_result.iloc[0, 0])}") + + +if __name__ == "__main__": + debug_sepal_comparison() diff --git a/tmp_scripts/prepare_data.py b/tmp_scripts/prepare_data.py new file mode 100644 index 00000000..c25c3fb9 --- /dev/null +++ b/tmp_scripts/prepare_data.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import squidpy as sq + +HOME = Path(os.path.expanduser("~")) +if __name__ == "__main__": + (HOME / "data").mkdir(parents=True, exist_ok=True) + adata = sq.datasets.visium_hne_adata() + sq.gr.spatial_neighbors(adata) + adata.write_h5ad(HOME / "data/visium_hne_adata.h5ad") diff --git a/tmp_scripts/run.py b/tmp_scripts/run.py new file mode 100644 index 00000000..c7ee97d6 --- /dev/null +++ b/tmp_scripts/run.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import warnings + +warnings.filterwarnings("ignore") + +import os +import time +from argparse import ArgumentParser +from pathlib import Path + +import anndata as ad +from utils._sepal import sepal as sepal_cpu + +import rapids_singlecell as rsc +from rapids_singlecell.squidpy_gpu import sepal as sepal_gpu + +HOME = Path(os.path.expanduser("~")) +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--debug", action="store_true") + parser.add_argument("--cpu", action="store_true") + args = parser.parse_args() + adata = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + # sc.pp.normalize_total(adata) + if not args.cpu: + rsc.get.anndata_to_GPU(adata, convert_all=True) + adata.obsp["spatial_connectivities"] = rsc.get.X_to_GPU( + adata.obsp["spatial_connectivities"] + ) + adata.obsm["spatial"] = rsc.get.X_to_GPU(adata.obsm["spatial"]) + start_time = time.time() + genes = adata.var_names.values[:10] + # genes = ["Gm29570"] + if 
args.cpu: + result = sepal_cpu(adata, max_neighs=6, genes=genes, n_iter=30000, copy=True) + else: + result = sepal_gpu(adata, max_neighs=6, genes=genes, n_iter=30000, copy=True) + end_time = time.time() + print(f"Time taken: {end_time - start_time} seconds") + + result.sort_values(by="sepal_score", ascending=False, inplace=True) + print(result.head(10)) diff --git a/tmp_scripts/run_cpu.py b/tmp_scripts/run_cpu.py new file mode 100644 index 00000000..571b582d --- /dev/null +++ b/tmp_scripts/run_cpu.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import warnings + +warnings.filterwarnings("ignore") +import os +import time +from argparse import ArgumentParser +from pathlib import Path + +import anndata as ad +from utils._sepal import sepal + +HOME = Path(os.path.expanduser("~")) + +if __name__ == "__main__": + parser = ArgumentParser() + args = parser.parse_args() + adata = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + start_time = time.time() + genes = adata.var_names.values[:100] + genes = ["Gm29570"] + # sc.pp.normalize_total(adata) + result = sepal( + adata, + max_neighs=6, + genes=genes, + n_iter=30000, + copy=True, + ) + end_time = time.time() + print(f"Time taken: {end_time - start_time} seconds") + result.sort_values(by="sepal_score", ascending=False, inplace=True) + print(result.head(10)) diff --git a/tmp_scripts/run_gpu.py b/tmp_scripts/run_gpu.py new file mode 100644 index 00000000..dd2d53d3 --- /dev/null +++ b/tmp_scripts/run_gpu.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import warnings + +warnings.filterwarnings("ignore") + +import os +import time +from argparse import ArgumentParser +from pathlib import Path + +import anndata as ad + +import rapids_singlecell as rsc +from rapids_singlecell.squidpy_gpu import sepal + +HOME = Path(os.path.expanduser("~")) +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + adata = ad.read_h5ad(HOME / "data/visium_hne_adata.h5ad") + # sc.pp.normalize_total(adata) + rsc.get.anndata_to_GPU(adata, convert_all=True) + adata.obsp["spatial_connectivities"] = rsc.get.X_to_GPU( + adata.obsp["spatial_connectivities"] + ) + adata.obsm["spatial"] = rsc.get.X_to_GPU(adata.obsm["spatial"]) + start_time = time.time() + genes = adata.var_names.values[:100] + genes = ["Gm29570"] + result = sepal(adata, max_neighs=6, genes=genes, n_iter=30000, copy=True) + end_time = time.time() + print(f"Time taken: {end_time - start_time} seconds") + + result.sort_values(by="sepal_score", ascending=False, inplace=True) + print(result.head(10)) diff --git a/tmp_scripts/test.py b/tmp_scripts/test.py new file mode 100644 index 00000000..342543be --- /dev/null +++ b/tmp_scripts/test.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import numpy as np +from numba import cuda + + +@cuda.jit +def compute_all_neighbor_sums_kernel( + concentration, sat_idx, neighbor_sums, n_sat, sat_thresh +): + """Kernel to compute neighbor sums for ALL saturated nodes in parallel""" + idx = cuda.grid(1) + + if idx < n_sat: + neighbor_sum = 0.0 + # Sum all neighbors for this saturated node + for j in range(sat_thresh): + neighbor_idx = sat_idx[idx * sat_thresh + j] + neighbor_sum += concentration[neighbor_idx] + + neighbor_sums[idx] = neighbor_sum + + +@cuda.jit +def compute_neighbor_sums_multiple_per_thread( + concentration, sat_idx, neighbor_sums, n_sat, sat_thresh +): + """ + Optimized kernel where each thread can process multiple saturated nodes + Better for small n_sat 
values + """ + tid = cuda.threadIdx.x + block_size = cuda.blockDim.x + + # Each thread processes multiple nodes (stride loop) + for idx in range(tid, n_sat, block_size): + neighbor_sum = 0.0 + for j in range(sat_thresh): + neighbor_idx = sat_idx[idx * sat_thresh + j] + neighbor_sum += concentration[neighbor_idx] + neighbor_sums[idx] = neighbor_sum + + +def compute_neighbor_sums_adaptive(concentration_host, sat_idx_host, n_sat, sat_thresh): + """ + Adaptive function that chooses the best kernel launch configuration + based on dataset size to avoid GPU under-utilization + """ + + # Transfer to GPU + concentration = cuda.to_device(concentration_host) + sat_idx = cuda.to_device(sat_idx_host) + neighbor_sums = cuda.device_array(n_sat, dtype=np.float64) + + # Adaptive kernel configuration + if n_sat < 256: + # Small dataset: Use single block with stride loop + threads_per_block = min(256, max(32, n_sat)) # At least 32 (one warp) + blocks_per_grid = 1 + + print( + f"Small dataset mode: {blocks_per_grid} block, {threads_per_block} threads" + ) + compute_neighbor_sums_multiple_per_thread[blocks_per_grid, threads_per_block]( + concentration, sat_idx, neighbor_sums, n_sat, sat_thresh + ) + else: + # Large dataset: Use multiple blocks + threads_per_block = 256 + blocks_per_grid = (n_sat + threads_per_block - 1) // threads_per_block + + print( + f"Large dataset mode: {blocks_per_grid} blocks, {threads_per_block} threads" + ) + compute_all_neighbor_sums_kernel[blocks_per_grid, threads_per_block]( + concentration, sat_idx, neighbor_sums, n_sat, sat_thresh + ) + + return neighbor_sums.copy_to_host() + + +# Test with different dataset sizes +def test_different_sizes(): + """Test the adaptive approach with different dataset sizes""" + + test_cases = [ + (100, 20, 4), # Small: under-utilization case + (1000, 200, 4), # Medium + (10000, 2000, 4), # Large + (50000, 8000, 6), # Very large with hexagonal grid + ] + + for n_nodes, n_sat, sat_thresh in test_cases: + print( + f"\n=== Testing: {n_nodes} nodes, {n_sat} saturated, {sat_thresh} neighbors ===" + ) + + # Create test data + concentration = np.random.rand(n_nodes).astype(np.float64) + sat_idx = np.random.randint(0, n_nodes, size=(n_sat * sat_thresh,)).astype( + np.int32 + ) + + # GPU computation + gpu_sums = compute_neighbor_sums_adaptive( + concentration, sat_idx, n_sat, sat_thresh + ) + + # CPU verification (only for smaller datasets) + if n_sat <= 1000: # Don't verify very large datasets (too slow) + neighbor_sums_cpu = np.zeros(n_sat) + for i in range(n_sat): + for j in range(sat_thresh): + neighbor_idx = sat_idx[i * sat_thresh + j] + neighbor_sums_cpu[i] += concentration[neighbor_idx] + + max_diff = np.max(np.abs(gpu_sums - neighbor_sums_cpu)) + print(f"Max difference: {max_diff}") + assert max_diff < 1e-10, "GPU and CPU results don't match!" 
+ print("✓ Verification passed!") + else: + print("✓ Large dataset - skipping CPU verification") + + +# Alternative: Work-efficient approach for very small datasets +@cuda.jit +def compute_neighbor_sums_warp_efficient( + concentration, sat_idx, neighbor_sums, n_sat, sat_thresh +): + """ + Work-efficient kernel that uses full warps even for small datasets + by processing multiple operations per thread + """ + tid = cuda.threadIdx.x + warp_id = tid // 32 + lane_id = tid % 32 + + # Process multiple saturated nodes per warp + nodes_per_warp = max(1, 32 // sat_thresh) # How many nodes can one warp handle + + for warp_batch in range(0, n_sat, nodes_per_warp): + node_idx = warp_batch + (tid // sat_thresh) + neighbor_idx_in_node = tid % sat_thresh + + if node_idx < n_sat and neighbor_idx_in_node < sat_thresh: + # Each thread loads one neighbor value + sat_idx_flat = node_idx * sat_thresh + neighbor_idx_in_node + neighbor_val = concentration[sat_idx[sat_idx_flat]] + + # Use warp shuffle to sum within the warp + # (This is more complex but very efficient for small sat_thresh) + for offset in range(1, sat_thresh): + neighbor_val += cuda.shfl_down_sync(0xFFFFFFFF, neighbor_val, offset) + + # First thread in each group writes the result + if neighbor_idx_in_node == 0: + neighbor_sums[node_idx] = neighbor_val + + +def benchmark_approaches(): + """Benchmark different approaches""" + import time + + # Test parameters + n_nodes = 10000 + n_sat = 2000 + sat_thresh = 4 + + concentration = np.random.rand(n_nodes).astype(np.float64) + sat_idx = np.random.randint(0, n_nodes, size=(n_sat * sat_thresh,)).astype(np.int32) + + # Warm up GPU + _ = compute_neighbor_sums_adaptive(concentration, sat_idx, n_sat, sat_thresh) + + # Benchmark + n_runs = 10 + + start = time.time() + for _ in range(n_runs): + result = compute_neighbor_sums_adaptive( + concentration, sat_idx, n_sat, sat_thresh + ) + cuda.synchronize() # Ensure completion + gpu_time = (time.time() - start) / n_runs + + # CPU baseline + start = time.time() + neighbor_sums_cpu = np.zeros(n_sat) + for i in range(n_sat): + for j in range(sat_thresh): + neighbor_idx = sat_idx[i * sat_thresh + j] + neighbor_sums_cpu[i] += concentration[neighbor_idx] + cpu_time = time.time() - start + + print("\nBenchmark Results:") + print(f"CPU time: {cpu_time:.4f} seconds") + print(f"GPU time: {gpu_time:.4f} seconds") + print(f"Speedup: {cpu_time / gpu_time:.2f}x") + print(f"Max difference: {np.max(np.abs(result - neighbor_sums_cpu))}") + + +if __name__ == "__main__": + print("Testing adaptive neighbor sum computation...") + test_different_sizes() + + print("\n" + "=" * 50) + print("Running benchmark...") + benchmark_approaches() diff --git a/tmp_scripts/utils/_sepal.py b/tmp_scripts/utils/_sepal.py new file mode 100644 index 00000000..8cde471e --- /dev/null +++ b/tmp_scripts/utils/_sepal.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import Literal + +import numpy as np +import pandas as pd +from anndata import AnnData +from numba import njit +from scanpy import logging as logg +from scipy.sparse import csr_matrix, isspmatrix_csr, spmatrix +from sklearn.metrics import pairwise_distances +from spatialdata import SpatialData +from squidpy._constants._pkg_constants import Key +from squidpy._docs import d, inject_docs +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy.gr._utils import ( + _assert_connectivity_key, + _assert_non_empty_sequence, + _assert_spatial_basis, 
+ _extract_expression, + _save_data, +) + +__all__ = ["sepal"] + + +@d.dedent +@inject_docs(key=Key.obsp.spatial_conn()) +def sepal( + adata: AnnData | SpatialData, + max_neighs: Literal[4, 6], + genes: str | Sequence[str] | None = None, + n_iter: int | None = 30000, + dt: float = 0.001, + thresh: float = 1e-8, + connectivity_key: str = Key.obsp.spatial_conn(), + spatial_key: str = Key.obsm.spatial, + layer: str | None = None, + use_raw: bool = False, + copy: bool = False, + n_jobs: int | None = None, + backend: str = "loky", + show_progress_bar: bool = True, +) -> pd.DataFrame | None: + """ + Identify spatially variable genes with *Sepal*. + + *Sepal* is a method that simulates a diffusion process to quantify spatial structure in tissue. + See :cite:`andersson2021` for reference. + + Parameters + ---------- + %(adata)s + max_neighs + Maximum number of neighbors of a node in the graph. Valid options are: + + - `4` - for a square-grid (ST, Dbit-seq). + - `6` - for a hexagonal-grid (Visium). + genes + List of gene names, as stored in :attr:`anndata.AnnData.var_names`, used to compute sepal score. + + If `None`, it's computed :attr:`anndata.AnnData.var` ``['highly_variable']``, if present. + Otherwise, it's computed for all genes. + n_iter + Maximum number of iterations for the diffusion simulation. + If ``n_iter`` iterations are reached, the simulation will terminate + even though convergence has not been achieved. + dt + Time step in diffusion simulation. + thresh + Entropy threshold for convergence of diffusion simulation. + %(conn_key)s + %(spatial_key)s + layer + Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`. + use_raw + Whether to access :attr:`anndata.AnnData.raw`. + %(copy)s + %(parallelize)s + + Returns + ------- + If ``copy = True``, returns a :class:`pandas.DataFrame` with the sepal scores. + + Otherwise, modifies the ``adata`` with the following key: + + - :attr:`anndata.AnnData.uns` ``['sepal_score']`` - the sepal scores. + + Notes + ----- + If some genes in :attr:`anndata.AnnData.uns` ``['sepal_score']`` are `NaN`, + consider re-running the function with increased ``n_iter``. + """ + if isinstance(adata, SpatialData): + adata = adata.table + _assert_connectivity_key(adata, connectivity_key) + _assert_spatial_basis(adata, key=spatial_key) + if max_neighs not in (4, 6): + raise ValueError( + f"Expected `max_neighs` to be either `4` or `6`, found `{max_neighs}`." + ) + + spatial = adata.obsm[spatial_key].astype(np.float64) + + if genes is None: + genes = adata.var_names.values + if "highly_variable" in adata.var.columns: + genes = genes[adata.var["highly_variable"].values] + genes = _assert_non_empty_sequence(genes, name="genes") + + n_jobs = _get_n_cores(n_jobs) + + g = adata.obsp[connectivity_key] + if not isspmatrix_csr(g): + g = csr_matrix(g) + g.eliminate_zeros() + + max_n = np.diff(g.indptr).max() + if max_n != max_neighs: + raise ValueError( + f"Expected `max_neighs={max_neighs}`, found node with `{max_n}` neighbors." 
+ ) + + # get saturated/unsaturated nodes + sat, sat_idx, unsat, unsat_idx = _compute_idxs(g, spatial, max_neighs, "l1") + + # get counts + vals, genes = _extract_expression(adata, genes=genes, use_raw=use_raw, layer=layer) + start = logg.info( + f"Calculating sepal score for `{len(genes)}` genes using `{n_jobs}` core(s)" + ) + + score = parallelize( + _score_helper, + collection=np.arange(len(genes)).tolist(), + extractor=np.hstack, + use_ixs=False, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, + )( + vals=vals, + max_neighs=max_neighs, + n_iter=n_iter, + sat=sat, + sat_idx=sat_idx, + unsat=unsat, + unsat_idx=unsat_idx, + dt=dt, + thresh=thresh, + ) + + key_added = "sepal_score" + sepal_score = pd.DataFrame(score, index=genes, columns=[key_added]) + + if sepal_score[key_added].isna().any(): + logg.warning( + "Found `NaN` in sepal scores, consider increasing `n_iter` to a higher value" + ) + sepal_score = sepal_score.sort_values(by=key_added, ascending=False) + + if copy: + logg.info("Finish", time=start) + return sepal_score + + _save_data(adata, attr="uns", key=key_added, data=sepal_score, time=start) + + +def _score_helper( + ixs: Sequence[int], + vals: spmatrix | NDArrayA, + max_neighs: int, + n_iter: int, + sat: NDArrayA, + sat_idx: NDArrayA, + unsat: NDArrayA, + unsat_idx: NDArrayA, + dt: float, + thresh: float, + queue: SigQueue | None = None, +) -> NDArrayA: + if max_neighs == 4: + fun = _laplacian_rect + elif max_neighs == 6: + fun = _laplacian_hex + else: + raise NotImplementedError( + f"Laplacian for `{max_neighs}` neighbors is not yet implemented." + ) + + score = [] + for i in ixs: + if isinstance(vals, spmatrix): + conc = vals[:, i].toarray().flatten() # Safe to call toarray() + else: + conc = vals[:, i].copy() # vals is assumed to be a NumPy array here + + time_iter = _diffusion( + conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh + ) + score.append(dt * time_iter) + + if queue is not None: + queue.put(Signal.UPDATE) + + if queue is not None: + queue.put(Signal.FINISH) + + return np.array(score) + + +@njit(fastmath=True) +def _diffusion( + conc: NDArrayA, + laplacian: Callable[[NDArrayA, NDArrayA], float], + n_iter: int, + sat: NDArrayA, + sat_idx: NDArrayA, + unsat: NDArrayA, + unsat_idx: NDArrayA, + dt: float = 0.001, + thresh: float = 1e-8, +) -> float: + """Simulate diffusion process on a regular graph.""" + sat_shape, conc_shape = sat.shape[0], conc.shape[0] + entropy_arr = np.zeros(n_iter) + prev_ent = 1.0 + nhood = np.zeros(sat_shape) + + for i in range(n_iter): + for j in range(sat_shape): + nhood[j] = np.sum(conc[sat_idx[j]]) + d2 = laplacian(conc[sat], nhood) + + dcdt = np.zeros(conc_shape) + dcdt[sat] = d2 + conc[sat] += dcdt[sat] * dt + conc[unsat] += dcdt[unsat_idx] * dt + # set values below zero to 0 + conc[conc < 0] = 0 + # compute entropy + ent = _entropy(conc[sat]) / sat_shape + entropy_arr[i] = np.abs(ent - prev_ent) # estimate entropy difference + prev_ent = ent + if entropy_arr[i] <= thresh: + break + + tmp = np.nonzero(entropy_arr <= thresh)[0] + return float(tmp[0] if len(tmp) else np.nan) + + +# taken from https://github.com/almaan/sepal/blob/master/sepal/models.py +@njit(parallel=False, fastmath=True) +def _laplacian_rect( + centers: NDArrayA, + nbrs: NDArrayA, +) -> NDArrayA: + """ + Five point stencil approximation on rectilinear grid. + + See `Wikipedia `_ for more information. 
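+
+    With unit grid spacing this reduces to ``d2f = nbrs - 4 * centers``, where
+    ``nbrs`` already holds the sum over the four neighbors of each center node.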
+ """ + d2f: NDArrayA = nbrs - 4 * centers + return d2f + + +# taken from https://github.com/almaan/sepal/blob/master/sepal/models.py +@njit(fastmath=True) +def _laplacian_hex( + centers: NDArrayA, + nbrs: NDArrayA, +) -> NDArrayA: + """ + Seven point stencil approximation on hexagonal grid. + + References + ---------- + Approximate Methods of Higher Analysis, + Curtis D. Benster, L.V. Kantorovich, V.I. Krylov, + ISBN-13: 978-0486821603. + """ + d2f: NDArrayA = (2.0 * nbrs - 12.0 * centers) / 3.0 + return d2f + + +# taken from https://github.com/almaan/sepal/blob/master/sepal/models.py +@njit(fastmath=True) +def _entropy( + xx: NDArrayA, +) -> float: + """Compute Shannon entropy of an array of probability values (in nats).""" + xnz = xx[xx > 0] + xs: np.float64 = np.sum(xnz) + eps = np.finfo(np.float64).eps # ~2.22e-16 + if xs < eps: + # 0 because + # xn represents probabilities + # and p(x)=0 is taken as 0 entropy + # see https://stats.stackexchange.com/a/433096 + return 0.0 + xn = xnz / xs + xl = np.log(np.maximum(xn, eps)) + return float((-xl * xn).sum()) + + +def _compute_idxs( + g: spmatrix, spatial: NDArrayA, sat_thresh: int, metric: str = "l1" +) -> tuple[NDArrayA, NDArrayA, NDArrayA, NDArrayA]: + """Get saturated and unsaturated nodes and neighborhood indices.""" + sat, unsat = _get_sat_unsat_idx(g.indptr, g.shape[0], sat_thresh) + + sat_idx, nearest_sat, un_unsat = _get_nhood_idx( + sat, unsat, g.indptr, g.indices, sat_thresh + ) + + # compute dist btwn remaining unsat and all sat + dist = pairwise_distances(spatial[un_unsat], spatial[sat], metric=metric) + # assign closest sat to remaining nearest_sat + nearest_sat[np.isnan(nearest_sat)] = sat[np.argmin(dist, axis=1)] + + return sat, sat_idx, unsat, nearest_sat.astype(np.int32) + + +@njit +def _get_sat_unsat_idx( + g_indptr: NDArrayA, g_shape: int, sat_thresh: int +) -> tuple[NDArrayA, NDArrayA]: + """Get saturated and unsaturated nodes based on thresh.""" + n_indices = np.diff(g_indptr) + unsat = np.arange(g_shape)[n_indices < sat_thresh] + sat = np.arange(g_shape)[n_indices == sat_thresh] + + return sat, unsat + + +@njit +def _get_nhood_idx( + sat: NDArrayA, + unsat: NDArrayA, + g_indptr: NDArrayA, + g_indices: NDArrayA, + sat_thresh: int, +) -> tuple[NDArrayA, NDArrayA, NDArrayA]: + """Get saturated and unsaturated neighborhood indices.""" + # get saturated nhood indices + sat_idx = np.zeros((sat.shape[0], sat_thresh)) + for idx in range(sat.shape[0]): + i = sat[idx] + sat_idx[idx] = g_indices[g_indptr[i] : g_indptr[i + 1]] + + # get closest saturated of unsaturated + nearest_sat = np.full_like(unsat, fill_value=np.nan, dtype=np.float64) + for idx in range(unsat.shape[0]): + i = unsat[idx] + unsat_neigh = g_indices[g_indptr[i] : g_indptr[i + 1]] + for u in unsat_neigh: + if u in sat: # take the first saturated nhood + nearest_sat[idx] = u + break + + # some unsat still don't have a sat nhood + # return them and compute distances in outer func + un_unsat = unsat[np.isnan(nearest_sat)] + + return sat_idx.astype(np.int32), nearest_sat, un_unsat diff --git a/tmp_scripts/verify.py b/tmp_scripts/verify.py new file mode 100644 index 00000000..2dce39d4 --- /dev/null +++ b/tmp_scripts/verify.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import cupy as cp +import numpy as np +import scipy.sparse as sp + +import rapids_singlecell as rsc +from rapids_singlecell.squidpy_gpu._sepal import _compute_idxs + + +def build_rect_grid(n=3): + n_cells = n * n + coords = ( + np.stack(np.meshgrid(np.arange(n), 
np.arange(n), indexing="xy"), -1) + .reshape(-1, 2) + .astype(np.float32) + ) + rows, cols, data = [], [], [] + + def idx(r, c): + return r * n + c + + for r in range(n): + for c in range(n): + i = idx(r, c) + for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: + rr, cc = r + dr, c + dc + if 0 <= rr < n and 0 <= cc < n: + j = idx(rr, cc) + rows.append(i) + cols.append(j) + data.append(1.0) + g = sp.csr_matrix((data, (rows, cols)), shape=(n_cells, n_cells), dtype=np.float32) + return g, coords + + +def verify_invariants_and_remaps(max_neighs=4): + g_np, coords_np = build_rect_grid(3) + g = rsc.get.X_to_GPU(g_np) + spatial = rsc.get.X_to_GPU(coords_np.astype(np.float32)) + degrees = cp.diff(g.indptr) + + sat, sat_idx, unsat, nearest_sat = _compute_idxs(g, degrees, spatial, max_neighs) + + n_cells = g.shape[0] + n_sat = len(sat) + n_unsat = len(unsat) + + reorder_indices = cp.concatenate([sat, unsat]) + old_to_new = cp.empty(n_cells, dtype=cp.int32) + old_to_new[reorder_indices] = cp.arange(n_cells, dtype=cp.int32) + + # Invariants + assert cp.all(old_to_new[sat] == cp.arange(n_sat)), ( + "sat positions not contiguous at front" + ) + assert cp.all(old_to_new[unsat] == cp.arange(n_sat, n_cells)), ( + "unsat positions not after sat" + ) + + # Remaps for kernel inputs + sat_idx_mapped = old_to_new[sat_idx] + nearest_sat_mapped = old_to_new[nearest_sat] + + assert sat_idx_mapped.shape == sat_idx.shape + assert cp.all(nearest_sat_mapped >= 0) and cp.all(nearest_sat_mapped < n_sat), ( + "nearest_sat must map into [0, n_sat)" + ) + + # Neighbor-sum preservation + rng = cp.random.RandomState(0) + conc_orig = rng.rand(n_cells) + conc_reordered = conc_orig[reorder_indices] + for k in range(n_sat): + s1 = conc_orig[sat_idx[k]].sum() + s2 = conc_reordered[sat_idx_mapped[k]].sum() + cp.testing.assert_allclose( + s1, s2, rtol=0, atol=0, err_msg=f"neighbor sum mismatch at sat row {k}" + ) + + return { + "n_cells": int(n_cells), + "n_sat": int(n_sat), + "n_unsat": int(n_unsat), + "ok": True, + } + + +def verify_single_step_update(max_neighs=4, dt=1e-2): + g_np, coords_np = build_rect_grid(3) + g = rsc.get.X_to_GPU(g_np) + spatial = rsc.get.X_to_GPU(coords_np.astype(np.float32)) + degrees = cp.diff(g.indptr) + + sat, sat_idx, unsat, nearest_sat = _compute_idxs(g, degrees, spatial, max_neighs) + + n_cells = g.shape[0] + n_sat = len(sat) + + reorder_indices = cp.concatenate([sat, unsat]) + old_to_new = cp.empty(n_cells, dtype=cp.int32) + old_to_new[reorder_indices] = cp.arange(n_cells, dtype=cp.int32) + + sat_idx_mapped = old_to_new[sat_idx] + nearest_sat_mapped = old_to_new[nearest_sat] + + rng = cp.random.RandomState(1) + conc_orig = rng.rand(n_cells) + conc_reo = conc_orig[reorder_indices].copy() + + # Rectangular Laplacian + neighbor_sum_orig = cp.array([conc_orig[sat_idx[i]].sum() for i in range(n_sat)]) + neighbor_sum_reo = cp.array( + [conc_reo[sat_idx_mapped[i]].sum() for i in range(n_sat)] + ) + centers_orig = conc_orig[sat] + centers_reo = conc_reo[:n_sat] + d2_sat_orig = neighbor_sum_orig - 4.0 * centers_orig + d2_sat_reo = neighbor_sum_reo - 4.0 * centers_reo + + conc_orig2 = conc_orig.copy() + conc_reo2 = conc_reo.copy() + + conc_orig2[sat] = cp.maximum(0.0, conc_orig2[sat] + d2_sat_orig * dt) + conc_reo2[:n_sat] = cp.maximum(0.0, conc_reo2[:n_sat] + d2_sat_reo * dt) + + # Unsat update using nearest sat derivative (per-unsat order) + for i in range(len(unsat)): + u_global = unsat[i] + ns = nearest_sat[i] # nearest sat node id + + # Find position of ns in sat array + ns_pos_orig = cp.where(sat == 
ns)[0][0] + conc_orig2[u_global] = cp.maximum( + 0.0, conc_orig2[u_global] + d2_sat_orig[ns_pos_orig] * dt + ) + + # In reordered space, nearest_sat_mapped[i] gives us the position in reordered array + # But we need the position in the saturated block (0..n_sat-1) + ns_pos_reo = nearest_sat_mapped[i] # This should be < n_sat + u_reo = n_sat + i + conc_reo2[u_reo] = cp.maximum( + 0.0, conc_reo2[u_reo] + d2_sat_reo[ns_pos_reo] * dt + ) + + # Map back reordered → original + inv = cp.empty_like(old_to_new) + inv[old_to_new] = cp.arange(n_cells, dtype=cp.int32) + conc_reconstructed = conc_reo2[inv] + + cp.testing.assert_allclose( + conc_orig2, + conc_reconstructed, + rtol=0, + atol=1e-12, + err_msg="single-step update mismatch", + ) + + return {"ok": True} + + +def main(): + res1 = verify_invariants_and_remaps() + print( + "Invariant+remap checks:", + {k: int(v) if isinstance(v, (np.integer,)) else v for k, v in res1.items()}, + ) + res2 = verify_single_step_update() + print("Single-step update equivalence:", res2) + print("All mapping tests passed.") + + +if __name__ == "__main__": + main()
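A minimal end-to-end usage sketch for the new GPU sepal. Assumptions: the function is
exposed as `rsc.gr.sepal` (as exercised in tests/test_sepal_tmp.py) in addition to
`rapids_singlecell.squidpy_gpu.sepal`, squidpy is installed to build the spatial graph,
and the dataset is the Visium sample used throughout tmp_scripts:

    import squidpy as sq
    import rapids_singlecell as rsc

    # Hexagonal Visium grid -> max_neighs=6; a square ST grid would use max_neighs=4.
    adata = sq.datasets.visium_hne_adata()
    sq.gr.spatial_neighbors(adata)

    # Move expression to the GPU; connectivities and coordinates are converted by
    # sepal() itself, but transferring them once up front avoids extra host->device copies.
    rsc.get.anndata_to_GPU(adata, convert_all=True)
    adata.obsp["spatial_connectivities"] = rsc.get.X_to_GPU(
        adata.obsp["spatial_connectivities"]
    )
    adata.obsm["spatial"] = rsc.get.X_to_GPU(adata.obsm["spatial"])

    # copy=True returns the scores; copy=False stores them in adata.uns["sepal_score"].
    scores = rsc.gr.sepal(
        adata, max_neighs=6, genes=adata.var_names[:100], n_iter=30000, copy=True
    )
    print(scores.head(10))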