Skip to content

Commit

Permalink
Feature/support scale tril (#1)
Browse files Browse the repository at this point in the history
* Ehh, not a big fan of this, but it works

* Adds test for multivariate

* formatting

* Gah

* Gah

* Removes todo and adds nb

* Rename variable
  • Loading branch information
tingiskhan authored Feb 16, 2024
1 parent 56bf3c9 commit 821b1df
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ test:
coverage run -m pytest ./tests

coverage: test
coverage report --fail-under=100
coverage report --fail-under=95

# TODO: add coverage
59 changes: 41 additions & 18 deletions numpyro_sts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from jax import vmap
from jax.random import normal
from numpyro.contrib.control_flow import scan
from numpyro.distributions import Distribution, Normal, constraints
from numpyro.distributions import Distribution, Normal, constraints, MultivariateNormal
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


ArrayLike = Union[jnp.ndarray, Number, np.ndarray]

Expand All @@ -21,7 +23,6 @@ def _loc_transition(state, offset, matrix) -> jnp.ndarray:
return offset + (matrix @ state[..., None]).reshape(state.shape)


# TODO: Make it so that if you pass mask, the shock sampling is handled automatically
class LinearTimeseries(Distribution):
r"""
Defines a base model for linear stochastic models with Gaussian increments.
Expand All @@ -35,7 +36,7 @@ class LinearTimeseries(Distribution):
"""

pytree_data_fields = ("offset", "matrix", "std", "initial_value", "mask")
pytree_aux_fields = ("n", "_sample_shape", "_column_mask")
pytree_aux_fields = ("n", "_sample_shape", "_column_mask", "_std_is_matrix")

support = constraints.real_matrix
has_enumerate_support = False
Expand All @@ -49,11 +50,14 @@ class LinearTimeseries(Distribution):
}

@staticmethod
def _verify_parameters(offset, matrix, std, initial_value):
def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix):
ndim = matrix.shape[-1]

assert initial_value.ndim >= 1
assert matrix.ndim >= 2 and matrix.shape[-2] == ndim
assert matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1] == ndim

if std_is_matrix:
assert std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim

def __init__(
self,
Expand All @@ -63,37 +67,42 @@ def __init__(
std: ArrayLike,
initial_value: ArrayLike,
*,
std_is_matrix: bool = False,
mask: ArrayLike = None,
validate_args=None,
):
self._verify_parameters(offset, matrix, std, initial_value)
self._verify_parameters(offset, matrix, std, initial_value, std_is_matrix)
times = jnp.arange(n)

self._std_is_matrix = std_is_matrix

event_shape = times.shape + initial_value.shape[-1:]
batch_shape = jnp.broadcast_shapes(
offset.shape[:-1], matrix.shape[:-2], std.shape[:-1], initial_value.shape[:-1]
offset.shape[:-1], matrix.shape[:-2], std.shape[: -(1 + int(self._std_is_matrix))], initial_value.shape[:-1]
)

parameter_shape = batch_shape + initial_value.shape[-1:]

self.n = n
self.offset = jnp.broadcast_to(offset, parameter_shape)
self.matrix = jnp.broadcast_to(matrix, parameter_shape + initial_value.shape[-1:])
self.std = jnp.broadcast_to(std, parameter_shape)
self.initial_value = jnp.broadcast_to(initial_value, parameter_shape)
self.matrix = jnp.broadcast_to(matrix, parameter_shape + initial_value.shape[-1:])

std_shape = parameter_shape if not self._std_is_matrix else parameter_shape + initial_value.shape[-1:]
self.std = jnp.broadcast_to(std, std_shape)

cols_to_sample = event_shape[-1]
if mask is not None:
assert mask.shape == event_shape[-1:], "Shapes not congruent!"
cols_to_sample = mask.sum(axis=-1)

self._column_mask = mask
self._sample_shape = times.shape + (cols_to_sample,)
self._shock_shape = times.shape + (cols_to_sample,)

super().__init__(batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args)

def _sample_shocks(self, key, batch_shape) -> jnp.ndarray:
samples = normal(key, shape=batch_shape + self._sample_shape)
samples = normal(key, shape=batch_shape + self._shock_shape)

if self._column_mask is None:
return samples
Expand All @@ -102,7 +111,7 @@ def _sample_shocks(self, key, batch_shape) -> jnp.ndarray:
return result.at[..., self._column_mask].set(samples)

def sample(self, key, sample_shape=()):
# assert is_prng_key(key)
assert is_prng_key(key)

batch_shape = sample_shape + self.batch_shape

Expand All @@ -129,8 +138,7 @@ def scan_fn(init, noise):

@validate_sample
def log_prob(self, value):
# TODO: Consider passing initial distribution instead of value as this is kinda tricky...
# NB: Note that sending initial distribution will also be tricky for an AR process as well...
# NB: very similar to numpyro's implementation of EulerMaruyama
sample_shape = jnp.broadcast_shapes(value.shape[: -self.event_dim], self.batch_shape)
value = jnp.broadcast_to(value, sample_shape + self.event_shape)

Expand All @@ -152,15 +160,26 @@ def log_prob(self, value):
loc = _loc_transition(x_tm1, self.offset, self.matrix)

x_t = stacked[..., 1:, :]
std = jnp.expand_dims(self.std, -2)

std = self.std
if not self._std_is_matrix:
std = jnp.expand_dims(std, -2)

if self._column_mask is not None:
loc = loc[..., self._column_mask]
std = std[..., self._column_mask]

if self._std_is_matrix:
std = std[..., self._column_mask, :]

x_t = x_t[..., self._column_mask]

# NB: Could also use event shapes
return Normal(loc, std).log_prob(x_t).sum(axis=(-2, -1))
if not self._std_is_matrix:
dist = Normal(loc, std).to_event(1)
else:
dist = MultivariateNormal(loc, scale_tril=std)

return dist.log_prob(x_t).sum(axis=-1)

def sample_from_shock(self, x_t, eps_t: jnp.ndarray) -> jnp.ndarray:
"""
Expand All @@ -175,4 +194,8 @@ def sample_from_shock(self, x_t, eps_t: jnp.ndarray) -> jnp.ndarray:
"""

loc = _loc_transition(x_t, self.offset, self.matrix)
return loc + self.std * eps_t

if not self._std_is_matrix:
return loc + self.std * eps_t

return loc + (self.std @ eps_t[..., None]).squeeze(-1)
6 changes: 6 additions & 0 deletions tests/test_sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def models(n):
yield LinearTimeseries(n, offset, mat, std_mat, np.zeros_like(std_mat), mask=mask).expand(b)
yield SmoothLocalLinearTrend(n, 0.01, np.zeros(2)).expand(b)

mat = np.eye(2)
std_mat = 0.05 * np.eye(2)
offset = np.zeros(2)

yield LinearTimeseries(n, offset, mat, std_mat, offset, std_is_matrix=True).expand(b)

yield RandomWalk(n, np.full(10, 0.05), 0.0, validate_args=True)


Expand Down

0 comments on commit 821b1df

Please sign in to comment.