graphgp/covariance.py (15 changes: 1 addition & 14 deletions)
@@ -88,17 +88,4 @@ def cov_lookup(r, cov_bins, cov_vals):
    If ``r`` is below the first bin, the first value is returned (the first bin should always be 0.0).
    If ``r`` is above the last bin, the last value is returned; arguably the last value should be zero instead.
    """
    # interpolate between bins
    idx = jnp.searchsorted(cov_bins, r)
    # return cov_vals[idx]
    r0 = cov_bins[idx - 1]
    r1 = cov_bins[idx]
    c0 = cov_vals[idx - 1]
    c1 = cov_vals[idx]
    c = c0 + (c1 - c0) * (r - r0) / (r1 - r0)

    # handle edge cases
    c = jnp.where(idx == 0, c1, c)
    c = jnp.where(idx == len(cov_bins), c0, c)
    c = jnp.where(r0 == r1, c0, c)
    return c
    return jax.numpy.interp(r, cov_bins, cov_vals)
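Side note on the replacement: `jnp.interp` reproduces the removed manual logic, including the edge-case clamping. A quick sketch with hypothetical bins:

```python
import jax.numpy as jnp

cov_bins = jnp.array([0.0, 1.0, 2.0])  # hypothetical binned covariance
cov_vals = jnp.array([1.0, 0.5, 0.1])

jnp.interp(0.5, cov_bins, cov_vals)   # 0.75: linear interpolation between bins
jnp.interp(-1.0, cov_bins, cov_vals)  # 1.0: clamps to the first value below the range
jnp.interp(3.0, cov_bins, cov_vals)   # 0.1: clamps to the last value above the range
```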
graphgp/refine.py (132 changes: 110 additions & 22 deletions)
@@ -4,6 +4,9 @@
import jax.numpy as jnp
from jax.tree_util import Partial
from jax import Array
from jax import lax

import numpy as np

from .covariance import compute_cov_matrix, CovarianceType
from .graph import Graph
@@ -21,6 +24,8 @@ def generate(
    xi: Array,
    *,
    cuda: bool = False,
    fast_jit: bool = True,
    use_cholesky: bool = True,
) -> Array:
    """
    Generate a GP with dense Cholesky for the first layer followed by conditional refinement.
@@ -32,21 +37,44 @@
        xi: Unit normal distributed parameters of shape ``(N,)``.
        reorder: Whether to reorder parameters and values according to the original order of the points. Default is ``True``.
        cuda: Whether to use the optional CUDA extension, if installed. JAX will still use a CUDA GPU if available. Default is ``False``, but ``True`` is recommended for performance where possible.
        fast_jit: Whether to use a version of refinement that compiles faster when ``cuda=False``. Default is ``True``, though runtime performance and memory usage may suffer.
        use_cholesky: Whether to factor the initial dense covariance with a Cholesky decomposition; if ``False``, a symmetric matrix square root is used instead. Default is ``True``.

    Returns:
        The generated values of shape ``(N,)``.
    """
    n0 = len(graph.points) - len(graph.neighbors)
    if graph.indices is not None:
        xi = xi[graph.indices]
    initial_values = generate_dense(graph.points[:n0], covariance, xi[:n0])
    values = refine(graph.points, graph.neighbors, graph.offsets, covariance, initial_values, xi[n0:], cuda=cuda)
    initial_values = generate_dense(graph.points[:n0], covariance, xi[:n0], use_cholesky=use_cholesky)
    values = refine(graph.points, graph.neighbors, graph.offsets, covariance, initial_values, xi[n0:], cuda=cuda, fast_jit=fast_jit)
    if graph.indices is not None:
        values = jnp.empty_like(values).at[graph.indices].set(values)
    return values


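A usage sketch for `generate`, assuming a prebuilt `graph` and a binned covariance pair (the `Graph` construction API is not part of this diff):

```python
import jax
from graphgp.refine import generate

covariance = (cov_bins, cov_vals)  # hypothetical (bins, values) pair
xi = jax.random.normal(jax.random.PRNGKey(0), (len(graph.points),))
values = generate(graph, covariance, xi, cuda=False, fast_jit=True)
```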
def generate_dense(points: Array, covariance: CovarianceType, xi: Array) -> Array:
def my_compute_cov_matrix(covariance, points_a, points_b):
    # FIXME: Temporary hack to handle non-stationary covariance functions.
    if isinstance(covariance, tuple):
        if len(covariance) == 2:
            return compute_cov_matrix(covariance, points_a, points_b)
        elif len(covariance) == 3:
            # Product kernel over coordinate blocks: (K_1 + I) * (K_2 + I) * ... - I,
            # where block i spans ndims[i] coordinates and * is elementwise.
            ndims, cov_bins, cov_vals = covariance
            nn = ndims[0]
            res = compute_cov_matrix((cov_bins[0], cov_vals[0]),
                                     points_a[..., :nn], points_b[..., :nn])
            identity = jnp.eye(res.shape[0])
            res += identity
            for i in range(1, len(ndims)):
                cv = compute_cov_matrix((cov_bins[i], cov_vals[i]),
                                        points_a[..., nn:nn + ndims[i]], points_b[..., nn:nn + ndims[i]])
                res *= (cv + identity)
                nn += ndims[i]
            res -= identity
            return res
    raise ValueError("Invalid covariance specification.")
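For reference, the 3-tuple branch computes an elementwise product kernel over coordinate blocks; a minimal sketch with two hypothetical blocks:

```python
import jax.numpy as jnp

K1 = jnp.array([[1.0, 0.5], [0.5, 1.0]])  # covariance on the first coordinate block
K2 = jnp.array([[1.0, 0.2], [0.2, 1.0]])  # covariance on the second coordinate block
I = jnp.eye(2)
K = (K1 + I) * (K2 + I) - I  # elementwise product, matching the loop above for two blocks
```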


def generate_dense(points: Array, covariance: CovarianceType, xi: Array, *, use_cholesky: bool = True) -> Array:
"""
Generate a GP with a dense Cholesky decomposition. Note that to compare with the GraphGP values,
the points must be provided in tree order.
@@ -58,11 +86,23 @@
    Returns:
        The generated values of shape ``(N,)``.
    """
    K = compute_cov_matrix(covariance, points, points)
    L = jnp.linalg.cholesky(K)
    K = my_compute_cov_matrix(covariance, points, points)
    if use_cholesky:
        L = jnp.linalg.cholesky(K)
    else:
        from .utils import _sqrtm

        L = _sqrtm(K)
    values = L @ xi
    return values

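A minimal sketch of why `use_cholesky=False` can help, assuming a PSD but singular covariance (JAX typically fills the Cholesky result with NaNs when the factorization fails, while `_sqrtm` clips small eigenvalues and stays finite):

```python
import jax.numpy as jnp
from graphgp.utils import _sqrtm

K = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # rank-1: PSD but singular
jnp.linalg.cholesky(K)  # typically NaNs: the matrix is not positive definite
_sqrtm(K)               # finite symmetric square root with clipped eigenvalues
```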
def _conditional_mean_std_vec(covariance, coarse_points, fine_point):
    # Like _conditional_mean_std, but independent of the coarse values: returns
    # the weight vector w (conditional mean = w @ coarse_values) and the
    # conditional standard deviation L[k, k].
    k = len(coarse_points)
    joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0)
    K = my_compute_cov_matrix(covariance, joint_points, joint_points)
    L = jnp.linalg.cholesky(K)
    mean = jnp.linalg.solve(L[:k, :k].T, L[k, :k])
    std = L[k, k]
    return mean, std
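Relation to `_conditional_mean_std`, as a sketch with hypothetical arrays: the vector returned here is a set of weights, so the conditional mean is recovered by a dot product with the coarse values (this is exactly what `single` in `refine` does below):

```python
w, s = _conditional_mean_std_vec(covariance, coarse_points, fine_point)
m, s2 = _conditional_mean_std(covariance, coarse_points, coarse_values, fine_point)
# m == w @ coarse_values and s == s2, up to floating-point error
```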

def refine(
    points: Array,
@@ -73,6 +113,7 @@
    xi: Array,
    *,
    cuda: bool = False,
    fast_jit: bool = True,
) -> Array:
    """
    Conditionally generate using initial values according to the GraphGP algorithm. Most users can use ``generate``, which
@@ -89,7 +130,8 @@
        initial_values: Initial values of shape ``(offsets[0],)``.
        xi: Unit normal distributed parameters of shape ``(N - offsets[0],)``.
        cuda: Whether to use the optional CUDA extension, if installed. JAX will still use a CUDA GPU if available. Default is ``False``, but ``True`` is recommended for performance where possible.
        fast_jit: Whether to use a version of refinement that compiles faster when ``cuda=False``. Default is ``True``, though runtime performance and memory usage may suffer.

    Returns:
        The refined values of shape ``(N,)``.

@@ -98,20 +140,66 @@
    if cuda:
        if not has_cuda:
            raise ImportError("CUDA extension not installed, cannot use cuda=True.")
        # Process the covariance once, so the float32 cast below is not re-processed.
        cov = _cuda_process_covariance(covariance)
        if jax.config.jax_enable_x64:
            # TODO build generic float64 support
            points = points.astype(jnp.float32)
            neighbors = neighbors.astype(jnp.int32)
            offsets = jnp.asarray(offsets, dtype=jnp.int32)
            initial_values = initial_values.astype(jnp.float32)
            xi = xi.astype(jnp.float32)
            cov = tuple(cc.astype(jnp.float32) for cc in cov)
        values = graphgp_cuda.refine(
            points, neighbors, jnp.asarray(offsets), *cov, initial_values, xi
        )
        if jax.config.jax_enable_x64:
            values = values.astype(jnp.float64)
    else:
        values = initial_values
        for i in range(1, len(offsets)):
            start = offsets[i - 1]
            end = offsets[i]
            coarse_points = jnp.take(points, neighbors[start - n0 : end - n0], axis=0)
            coarse_values = jnp.take(values, neighbors[start - n0 : end - n0], axis=0)
            fine_point = points[start:end]
            fine_xi = xi[start - n0 : end - n0]
            mean, std = jax.vmap(Partial(_conditional_mean_std, covariance))(coarse_points, coarse_values, fine_point)
            values = jnp.concatenate([values, mean + std * fine_xi], axis=0)
        if fast_jit:
            k = neighbors.shape[1]
            max_batch = np.max(np.diff(np.array(offsets)))
            values = jnp.zeros(len(points))
            values = values.at[:n0].set(initial_values)

            # Precompute the conditional weights and noise for all fine points at once.
            coarse_points = points[neighbors]
            joint_points = jnp.concatenate([coarse_points, points[n0:, None]], axis=1)
            K = jax.vmap(my_compute_cov_matrix, in_axes=(None, 0, 0))(covariance, joint_points, joint_points)
            L = jnp.linalg.cholesky(K)
            mean_vec = jnp.linalg.solve(L[:, :k, :k].transpose(0, 2, 1), L[:, k, :k][..., None]).squeeze(-1)
            noise = L[:, k, k] * xi

            # For each batch defined by offsets, dot neighbor values with mean_vec and add noise.
            def step(values, start):
                # Slice sizes must be static under jit, so every level is padded to
                # max_batch; lax.dynamic_slice clamps starts that would overrun.
                neighbor_values = values[lax.dynamic_slice(neighbors, (start - n0, 0), (max_batch, k))]
                mean_slice = jnp.sum(
                    lax.dynamic_slice(mean_vec, (start - n0, 0), (max_batch, k)) * neighbor_values, axis=1
                )
                noise_slice = lax.dynamic_slice(noise, (start - n0,), (max_batch,))
                values = lax.dynamic_update_slice(values, mean_slice + noise_slice, (start,))
                return values, None

            values, _ = lax.scan(step, values, jnp.array(offsets[:-1]))

        else:
            coarse_points = points[neighbors]
            mean, std = jax.vmap(Partial(_conditional_mean_std_vec, covariance))(coarse_points, points[n0:])
            mean = jax.block_until_ready(mean)
            std = jax.block_until_ready(std)

            @jax.vmap
            def single(mean, std, xi, values):
                return jnp.vdot(mean, values) + std * xi

            values = initial_values
            for i in range(1, len(offsets)):
                start = offsets[i - 1]
                end = offsets[i]
                means = mean[start - n0 : end - n0]
                stds = std[start - n0 : end - n0]
                coarse_values = values[neighbors[start - n0 : end - n0]]
                fine_xi = xi[start - n0 : end - n0]
                res = single(means, stds, fine_xi, coarse_values)
                values = jnp.concatenate([values, res], axis=0)
    return values
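The `fast_jit` branch trades per-level compilation for fixed-size slices inside a single `lax.scan`; a standalone sketch of that pattern with hypothetical sizes:

```python
import jax.numpy as jnp
from jax import lax

data = jnp.arange(10.0)
starts = jnp.array([0, 3, 6])  # level boundaries, padded to a fixed slice size of 4

def step(carry, start):
    # Slice sizes must be static under jit; dynamic_slice clamps starts that
    # would run past the end of the array.
    chunk = lax.dynamic_slice(data, (start,), (4,))
    return carry + chunk.sum(), None

total, _ = lax.scan(step, 0.0, starts)
```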


@@ -140,7 +228,7 @@ def generate_dense_inv(points: Array, covariance: CovarianceType, values: Array)
"""
Inverse of ``generate_dense``.
"""
K = compute_cov_matrix(covariance, points, points)
K = my_compute_cov_matrix(covariance, points, points)
L = jnp.linalg.cholesky(K)
xi = jnp.linalg.solve(L, values)
return xi
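Round-trip sanity sketch (same `points`, `covariance`, and `xi` as above; both directions share the same Cholesky factor):

```python
xi_rec = generate_dense_inv(points, covariance, generate_dense(points, covariance, xi))
# xi_rec == xi up to floating-point error, since L @ xi is undone by solve(L, values)
```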
@@ -193,7 +281,7 @@ def generate_dense_logdet(points: Array, covariance: CovarianceType) -> Array:
"""
Log determinant of ``generate_dense``.
"""
K = compute_cov_matrix(covariance, points, points)
K = my_compute_cov_matrix(covariance, points, points)
return jnp.linalg.slogdet(K)[1] / 2
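Equivalently via the Cholesky factor, since `slogdet(K)[1] == 2 * sum(log(diag(L)))` for `K = L @ L.T`:

```python
L = jnp.linalg.cholesky(my_compute_cov_matrix(covariance, points, points))
half_logdet = jnp.sum(jnp.log(jnp.diag(L)))  # equals generate_dense_logdet(points, covariance)
```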


Expand Down Expand Up @@ -230,7 +318,7 @@ def refine_logdet(
def _conditional_mean_std(covariance, coarse_points, coarse_values, fine_point):
    k = len(coarse_points)
    joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0)
    K = compute_cov_matrix(covariance, joint_points, joint_points)
    K = my_compute_cov_matrix(covariance, joint_points, joint_points)
    L = jnp.linalg.cholesky(K)
    mean = L[k, :k] @ jnp.linalg.solve(L[:k, :k], coarse_values)
    std = L[k, k]
@@ -241,7 +329,7 @@ def _conditional_mean_std(covariance, coarse_points, coarse_values, fine_point):
def _conditional_std(covariance, coarse_points, fine_point):
    k = len(coarse_points)
    joint_points = jnp.concatenate([coarse_points, fine_point[jnp.newaxis]], axis=0)
    K = compute_cov_matrix(covariance, joint_points, joint_points)
    K = my_compute_cov_matrix(covariance, joint_points, joint_points)
    L = jnp.linalg.cholesky(K)
    return L[k, k]

graphgp/utils.py (33 changes: 33 additions & 0 deletions)
@@ -0,0 +1,33 @@
import jax
import jax.numpy as jnp


def _clip_v(v, A):
    # Clip eigenvalues from below at a scale-aware tolerance, so that near-zero
    # or slightly negative eigenvalues of a PSD matrix do not produce NaNs.
    info = jnp.finfo(A.dtype)
    tol = A.shape[0] * info.eps * jnp.linalg.norm(A, ord=2)
    return jnp.clip(v, tol, None)


def _get_sqrt(v, U):
    vsq = jnp.sqrt(v)
    return U @ (vsq[:, jnp.newaxis] * U.T)


@jax.custom_jvp
def _sqrtm(M):
    # Symmetric matrix square root via eigendecomposition.
    v, U = jnp.linalg.eigh(M)
    v = _clip_v(v, M)
    return _get_sqrt(v, U)


@_sqrtm.defjvp
def _sqrtm_jvp(M, dM):
    # Note: Only stable 1st derivative!
    M, dM = M[0], dM[0]
    v, U = jnp.linalg.eigh(M)
    v = _clip_v(v, M)

    # Solve the Sylvester equation S dS + dS S = dM in the eigenbasis of M,
    # where S = sqrtm(M): (sqrt(v_i) + sqrt(v_j)) * dS_ij = dM_ij.
    dM = U.T @ dM @ U
    vsq = jnp.sqrt(v)
    dres = dM / (vsq[:, jnp.newaxis] + vsq[jnp.newaxis, :])
    return _get_sqrt(v, U), U @ dres @ U.T
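A quick check sketch for `_sqrtm`: the result squares back to `M`, and the custom JVP makes forward-mode differentiation work (first derivative only, per the note above):

```python
import jax
import jax.numpy as jnp

M = jnp.array([[2.0, 0.5], [0.5, 1.0]])
S = _sqrtm(M)
jnp.allclose(S @ S, M, atol=1e-5)  # True
jax.jacfwd(_sqrtm)(M).shape        # (2, 2, 2, 2)
```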