diff --git a/graphgp/covariance.py b/graphgp/covariance.py
index 60985ff..9169a35 100644
--- a/graphgp/covariance.py
+++ b/graphgp/covariance.py
@@ -88,17 +88,4 @@ def cov_lookup(r, cov_bins, cov_vals):
     If `r` is below the first bin, the first value is returned. But really the first bin should always be 0.0.
     If `r` is above the last bin, the last value is returned. Maybe the last value should be zero.
     """
-    # interpolate between bins
-    idx = jnp.searchsorted(cov_bins, r)
-    # return cov_vals[idx]
-    r0 = cov_bins[idx - 1]
-    r1 = cov_bins[idx]
-    c0 = cov_vals[idx - 1]
-    c1 = cov_vals[idx]
-    c = c0 + (c1 - c0) * (r - r0) / (r1 - r0)
-
-    # handle edge cases
-    c = jnp.where(idx == 0, c1, c)
-    c = jnp.where(idx == len(cov_bins), c0, c)
-    c = jnp.where(r0 == r1, c0, c)
-    return c
+    return jax.numpy.interp(r, cov_bins, cov_vals)
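
Note: `jax.numpy.interp` clamps to the first/last value outside the bin range, so it covers the edge cases the removed `jnp.where` lines handled by hand, and it is jittable and differentiable. A quick sanity check, with hypothetical bin values that are not from the library:

    import jax.numpy as jnp

    cov_bins = jnp.array([0.0, 1.0, 2.0])  # hypothetical lookup table
    cov_vals = jnp.array([1.0, 0.5, 0.1])

    # interior radii are linearly interpolated
    assert jnp.isclose(jnp.interp(0.5, cov_bins, cov_vals), 0.75)
    # out-of-range radii clamp to the first/last value, matching the
    # edge cases the removed code handled explicitly
    assert jnp.interp(-1.0, cov_bins, cov_vals) == 1.0
    assert jnp.interp(5.0, cov_bins, cov_vals) == 0.1
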
diff --git a/graphgp/refine.py b/graphgp/refine.py
index 7a6b1fb..cd62bcc 100644
--- a/graphgp/refine.py
+++ b/graphgp/refine.py
@@ -4,6 +4,9 @@ import jax.numpy as jnp
 from jax.tree_util import Partial
 from jax import Array
+from jax import lax
+
+import numpy as np
 
 from .covariance import compute_cov_matrix, CovarianceType
 from .graph import Graph
@@ -21,6 +24,8 @@ def generate(
     xi: Array,
     *,
     cuda: bool = False,
+    fast_jit: bool = True,
+    use_cholesky: bool = True,
 ) -> Array:
     """
     Generate a GP with dense Cholesky for the first layer followed by conditional refinement.
@@ -32,21 +37,45 @@
         xi: Unit normal distributed parameters of shape ``(N,).``
         reorder: Whether to reorder parameters and values according to the original order of the points. Default is ``True``.
         cuda: Whether to use optional CUDA extension, if installed. Will still use CUDA GPU via JAX if available. Default is ``False`` but recommended if possible for performance.
-
+        fast_jit: Whether to use a version of refinement that compiles faster, if ``cuda=False``. Default is ``True``, but runtime performance and memory usage will suffer.
+        use_cholesky: Whether to factor the dense first layer with a Cholesky decomposition. If ``False``, a symmetric matrix square root is used instead. Default is ``True``.
+
     Returns:
         The generated values of shape ``(N,).``
     """
     n0 = len(graph.points) - len(graph.neighbors)
     if graph.indices is not None:
         xi = xi[graph.indices]
-    initial_values = generate_dense(graph.points[:n0], covariance, xi[:n0])
-    values = refine(graph.points, graph.neighbors, graph.offsets, covariance, initial_values, xi[n0:], cuda=cuda)
+    initial_values = generate_dense(graph.points[:n0], covariance, xi[:n0], use_cholesky=use_cholesky)
+    values = refine(graph.points, graph.neighbors, graph.offsets, covariance, initial_values, xi[n0:], cuda=cuda, fast_jit=fast_jit)
     if graph.indices is not None:
         values = jnp.empty_like(values).at[graph.indices].set(values)
     return values
-
-def generate_dense(points: Array, covariance: CovarianceType, xi: Array) -> Array:
+def my_compute_cov_matrix(covariance, points_a, points_b):
+    # FIXME: Temporary hack to handle non-stationary covariance functions.
+    # A 3-tuple spec builds K as prod_i (K_i + I) - I over per-block lookup kernels.
+    if isinstance(covariance, tuple) and len(covariance) == 2:
+        return compute_cov_matrix(covariance, points_a, points_b)
+    elif isinstance(covariance, tuple) and len(covariance) == 3:
+        ndims, cov_bins, cov_vals = covariance
+        nn = ndims[0]
+        res = compute_cov_matrix((cov_bins[0], cov_vals[0]),
+                                 points_a[..., :nn], points_b[..., :nn])
+        identity = jnp.eye(res.shape[0])
+        res += identity
+        for i in range(1, len(ndims)):
+            cv = compute_cov_matrix((cov_bins[i], cov_vals[i]),
+                                    points_a[..., nn:nn + ndims[i]], points_b[..., nn:nn + ndims[i]])
+            res *= (cv + identity)
+            nn += ndims[i]
+        res -= identity
+        return res
+    else:
+        raise ValueError("Invalid covariance specification.")
+
+
+def generate_dense(points: Array, covariance: CovarianceType, xi: Array, *, use_cholesky: bool = True) -> Array:
     """
     Generate a GP with a dense Cholesky decomposition.
 
     Note that to compare with the GraphGP values, the points must be provided in tree order.
@@ -58,11 +86,23 @@ def generate_dense(points: Array, covariance: CovarianceType, xi: Array) -> Arra
     Returns:
         The generated values of shape ``(N,).``
     """
-    K = compute_cov_matrix(covariance, points, points)
-    L = jnp.linalg.cholesky(K)
+    K = my_compute_cov_matrix(covariance, points, points)
+    if use_cholesky:
+        L = jnp.linalg.cholesky(K)
+    else:
+        from .utils import _sqrtm
+        L = _sqrtm(K)
     values = L @ xi
     return values
 
+def _conditional_mean_std_vec(covariance, coarse_points, fine_point):
+    k = len(coarse_points)
+    joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0)
+    K = my_compute_cov_matrix(covariance, joint_points, joint_points)
+    L = jnp.linalg.cholesky(K)
+    mean = jnp.linalg.solve(L[:k, :k].T, L[k, :k].T).T
+    std = L[k, k]
+    return mean, std
 
 def refine(
     points: Array,
@@ -73,6 +113,7 @@
     xi: Array,
     *,
     cuda: bool = False,
+    fast_jit: bool = True,
 ) -> Array:
     """
     Conditionally generate using initial values according to GraphGP algorithm. Most users can use ``generate``, which
@@ -89,7 +130,8 @@
         initial_values: Initial values of shape ``(offsets[0],).``
         xi: Unit normal distributed parameters of shape ``(N - offsets[0],).``
         cuda: Whether to use optional CUDA extension, if installed. Will still use CUDA GPU via JAX if available. Default is ``False`` but recommended if possible for performance.
-
+        fast_jit: Whether to use a version of refinement that compiles faster, if ``cuda=False``. Default is ``True``, but runtime performance and memory usage will suffer.
+
     Returns:
         The refined values of shape ``(N,).``
 
@@ -98,20 +140,66 @@
     if cuda:
         if not has_cuda:
             raise ImportError("CUDA extension not installed, cannot use cuda=True.")
+        if jax.config.jax_enable_x64:
+            # TODO build generic float64 support
+            points = points.astype(jnp.float32)
+            neighbors = neighbors.astype(jnp.int32)
+            offsets = jnp.asarray(offsets, dtype=jnp.int32)
+            initial_values = initial_values.astype(jnp.float32)
+            xi = xi.astype(jnp.float32)
+            covariance = tuple(cc.astype(jnp.float32) for cc in _cuda_process_covariance(covariance))
         values = graphgp_cuda.refine(
             points, neighbors, jnp.asarray(offsets), *_cuda_process_covariance(covariance), initial_values, xi
         )
+        if jax.config.jax_enable_x64:
+            values = values.astype(jnp.float64)
     else:
-        values = initial_values
-        for i in range(1, len(offsets)):
-            start = offsets[i - 1]
-            end = offsets[i]
-            coarse_points = jnp.take(points, neighbors[start - n0 : end - n0], axis=0)
-            coarse_values = jnp.take(values, neighbors[start - n0 : end - n0], axis=0)
-            fine_point = points[start:end]
-            fine_xi = xi[start - n0 : end - n0]
-            mean, std = jax.vmap(Partial(_conditional_mean_std, covariance))(coarse_points, coarse_values, fine_point)
-            values = jnp.concatenate([values, mean + std * fine_xi], axis=0)
+        if fast_jit:
+            k = neighbors.shape[1]
+            max_batch = np.max(np.diff(np.array(offsets)))
+            values = jnp.zeros(len(points))
+            values = values.at[:n0].set(initial_values)
+
+            # Precompute matrix factorizations for all points
+            coarse_points = points[neighbors]
+            joint_points = jnp.concatenate([coarse_points, points[n0:, None]], axis=1)
+            K = jax.vmap(my_compute_cov_matrix, in_axes=(None, 0, 0))(covariance, joint_points, joint_points)
+            L = jnp.linalg.cholesky(K)
+            mean_vec = jnp.linalg.solve(L[:, :k, :k].transpose(0, 2, 1), L[:, k, :k][..., None]).squeeze(-1)
+            noise = L[:, k, k] * xi
+
+            # For each batch defined by offsets, dot neighbor values with mean_vec and add noise
+            def step(values, start):
+                neighbor_values = values[lax.dynamic_slice(neighbors, (start - n0, 0), (max_batch, k))]
+                mean_slice = jnp.sum(
+                    lax.dynamic_slice(mean_vec, (start - n0, 0), (max_batch, k)) * neighbor_values, axis=1
+                )
+                noise_slice = lax.dynamic_slice(noise, (start - n0,), (max_batch,))
+                values = lax.dynamic_update_slice(values, mean_slice + noise_slice, (start,))
+                return values, None
+
+            values, _ = lax.scan(step, values, jnp.array(offsets[:-1]))
+
+        else:
+            coarse_points = points[neighbors]
+            mean, std = jax.vmap(Partial(_conditional_mean_std_vec, covariance))(coarse_points, points[n0:])
+            mean = jax.block_until_ready(mean)
+            std = jax.block_until_ready(std)
+
+            @jax.vmap
+            def single(mean, std, xi, values):
+                return jnp.vdot(mean, values) + std * xi
+
+            values = initial_values
+            for i in range(1, len(offsets)):
+                start = offsets[i - 1]
+                end = offsets[i]
+                means = mean[start - n0 : end - n0]
+                stds = std[start - n0 : end - n0]
+                coarse_values = values[neighbors[start - n0 : end - n0]]
+                fine_xi = xi[start - n0 : end - n0]
+                res = single(means, stds, fine_xi, coarse_values)
+                values = jnp.concatenate([values, res], axis=0)
     return values
 
 
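
Note on the conditional update used above: for a fine point conditioned on its k coarse neighbors, the weights (`mean_vec`) and standard deviation fall out of the Cholesky factor L of the joint covariance, because L[k, :k] = K_fc @ inv(L_cc).T and L[k, k]^2 = K_ff - K_fc @ inv(K_cc) @ K_cf. A self-contained sketch, with a hypothetical squared-exponential kernel standing in for `compute_cov_matrix`:

    import jax.numpy as jnp

    def kernel(a, b):  # hypothetical stand-in for compute_cov_matrix
        d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return jnp.exp(-0.5 * d2)

    coarse = jnp.array([[0.0], [1.0], [2.0]])
    fine = jnp.array([[0.5]])
    k = len(coarse)

    # Cholesky route, as in _conditional_mean_std_vec
    joint = jnp.concatenate([coarse, fine], axis=0)
    L = jnp.linalg.cholesky(kernel(joint, joint))
    w = jnp.linalg.solve(L[:k, :k].T, L[k, :k])  # weights on the coarse values
    std = L[k, k]

    # textbook GP conditional formulas give the same result
    w_ref = jnp.linalg.solve(kernel(coarse, coarse), kernel(coarse, fine)[:, 0])
    var_ref = kernel(fine, fine)[0, 0] - kernel(fine, coarse)[0] @ w_ref
    assert jnp.allclose(w, w_ref, atol=1e-5)
    assert jnp.allclose(std**2, var_ref, atol=1e-5)
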
@@ -140,7 +228,7 @@ def generate_dense_inv(points: Array, covariance: CovarianceType, values: Array)
     """
     Inverse of ``generate_dense``.
     """
-    K = compute_cov_matrix(covariance, points, points)
+    K = my_compute_cov_matrix(covariance, points, points)
     L = jnp.linalg.cholesky(K)
     xi = jnp.linalg.solve(L, values)
     return xi
@@ -193,7 +281,7 @@ def generate_dense_logdet(points: Array, covariance: CovarianceType) -> Array:
     """
     Log determinant of ``generate_dense``.
     """
-    K = compute_cov_matrix(covariance, points, points)
+    K = my_compute_cov_matrix(covariance, points, points)
     return jnp.linalg.slogdet(K)[1] / 2
 
 
@@ -230,7 +318,7 @@ def refine_logdet(
 def _conditional_mean_std(covariance, coarse_points, coarse_values, fine_point):
     k = len(coarse_points)
     joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0)
-    K = compute_cov_matrix(covariance, joint_points, joint_points)
+    K = my_compute_cov_matrix(covariance, joint_points, joint_points)
     L = jnp.linalg.cholesky(K)
     mean = L[k, :k] @ jnp.linalg.solve(L[:k, :k], coarse_values)
     std = L[k, k]
@@ -241,7 +329,7 @@ def _conditional_std(covariance, coarse_points, fine_point):
     k = len(coarse_points)
     joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0)
-    K = compute_cov_matrix(covariance, joint_points, joint_points)
+    K = my_compute_cov_matrix(covariance, joint_points, joint_points)
     L = jnp.linalg.cholesky(K)
     return L[k, k]
 
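
These hunks route every dense covariance through `my_compute_cov_matrix`. For a 3-tuple spec it assembles K = prod_i (K_i + I) - I over per-block lookup kernels, which expands to the sum of the blocks plus their elementwise products; if each per-block kernel is positive semidefinite, so is the result (Schur product theorem). A toy sketch of that algebra, with a hypothetical per-block lookup kernel and made-up tables:

    import jax.numpy as jnp

    def lookup_kernel(bins, vals, a, b):  # hypothetical per-block kernel
        r = jnp.abs(a[:, None, 0] - b[None, :, 0])
        return jnp.interp(r, bins, vals)

    bins = jnp.array([0.0, 1.0, 2.0])
    vals = jnp.array([1.0, 0.5, 0.1])
    pts = jnp.array([[0.0, 0.0], [1.0, 0.5], [2.0, 1.5]])  # two 1-D blocks

    K1 = lookup_kernel(bins, vals, pts[:, :1], pts[:, :1])
    K2 = lookup_kernel(bins, vals, pts[:, 1:], pts[:, 1:])
    I = jnp.eye(3)
    K = (K1 + I) * (K2 + I) - I
    # expands to K1 + K2 + K1*K2 (elementwise products)
    assert jnp.allclose(K, K1 + K2 + K1 * K2)
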
""" - K = compute_cov_matrix(covariance, points, points) + K = my_compute_cov_matrix(covariance, points, points) L = jnp.linalg.cholesky(K) xi = jnp.linalg.solve(L, values) return xi @@ -193,7 +281,7 @@ def generate_dense_logdet(points: Array, covariance: CovarianceType) -> Array: """ Log determinant of ``generate_dense``. """ - K = compute_cov_matrix(covariance, points, points) + K = my_compute_cov_matrix(covariance, points, points) return jnp.linalg.slogdet(K)[1] / 2 @@ -230,7 +318,7 @@ def refine_logdet( def _conditional_mean_std(covariance, coarse_points, coarse_values, fine_point): k = len(coarse_points) joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0) - K = compute_cov_matrix(covariance, joint_points, joint_points) + K = my_compute_cov_matrix(covariance, joint_points, joint_points) L = jnp.linalg.cholesky(K) mean = L[k, :k] @ jnp.linalg.solve(L[:k, :k], coarse_values) std = L[k, k] @@ -241,7 +329,7 @@ def _conditional_mean_std(covariance, coarse_points, coarse_values, fine_point): def _conditional_std(covariance, coarse_points, fine_point): k = len(coarse_points) joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0) - K = compute_cov_matrix(covariance, joint_points, joint_points) + K = my_compute_cov_matrix(covariance, joint_points, joint_points) L = jnp.linalg.cholesky(K) return L[k, k] diff --git a/graphgp/utils.py b/graphgp/utils.py new file mode 100644 index 0000000..dee06f1 --- /dev/null +++ b/graphgp/utils.py @@ -0,0 +1,33 @@ +import jax +import jax.numpy as jnp + + +def _clip_v(v, A): + info = jnp.finfo(A.dtype) + tol = A.shape[0] * info.eps * jnp.linalg.norm(A, ord=2) + return jnp.clip(v, tol, None) + + +def _get_sqrt(v, U): + vsq = jnp.sqrt(v) + return U @ (vsq[:, jnp.newaxis] * U.T) + + +@jax.custom_jvp +def _sqrtm(M): + v, U = jnp.linalg.eigh(M) + v = _clip_v(v, M) + return _get_sqrt(v, U) + + +@_sqrtm.defjvp +def _sqrtm_jvp(M, dM): + # Note: Only stable 1st derivative! + M, dM = M[0], dM[0] + v, U = jnp.linalg.eigh(M) + v = _clip_v(v, M) + + dM = U.T @ dM @ U + vsq = jnp.sqrt(v) + dres = dM / (vsq[:, jnp.newaxis] + vsq[jnp.newaxis, :]) + return _get_sqrt(v, U), U @ dres @ U.T