Skip to content

Commit d5f8f76

Browse files
Allow exogenous regressors in BayesianVARMAX (#567)
* First pass on exogenous variables in VARMA * Adjust state names for API consistency * Allow exogenous variables in BayesianVARMAX * Eagerly simplify model where possible * Typo fix
1 parent 89c6bc0 commit d5f8f76

File tree

2 files changed

+603
-57
lines changed

2 files changed

+603
-57
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 248 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
ALL_STATE_AUX_DIM,
1515
ALL_STATE_DIM,
1616
AR_PARAM_DIM,
17+
EXOGENOUS_DIM,
1718
MA_PARAM_DIM,
1819
OBS_STATE_AUX_DIM,
1920
OBS_STATE_DIM,
2021
SHOCK_AUX_DIM,
2122
SHOCK_DIM,
23+
TIME_DIM,
2224
)
2325

2426
floatX = pytensor.config.floatX
@@ -28,60 +30,6 @@ class BayesianVARMAX(PyMCStateSpace):
2830
r"""
2931
Vector AutoRegressive Moving Average with eXogenous Regressors
3032
31-
Parameters
32-
----------
33-
order: tuple of (int, int)
34-
Number of autoregressive (AR) and moving average (MA) terms to include in the model. All terms up to the
35-
specified order are included. For restricted models, set zeros directly on the priors.
36-
37-
endog_names: list of str, optional
38-
Names of the endogenous variables being modeled. Used to generate names for the state and shock coords. If
39-
None, the state names will simply be numbered.
40-
41-
Exactly one of either ``endog_names`` or ``k_endog`` must be specified.
42-
43-
k_endog: int, optional
44-
Number of endogenous states to be modeled.
45-
46-
Exactly one of either ``endog_names`` or ``k_endog`` must be specified.
47-
48-
stationary_initialization: bool, default False
49-
If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
50-
state values will be used. If False, the user is responsible for setting priors on the initial state and
51-
initial covariance.
52-
53-
..warning :: This option is very sensitive to the priors placed on the AR and MA parameters. If the model dynamics
54-
for a given sample are not stationary, sampling will fail with a "covariance is not positive semi-definite"
55-
error.
56-
57-
filter_type: str, default "standard"
58-
The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
59-
and "cholesky". See the docs for kalman filters for more details.
60-
61-
state_structure: str, default "fast"
62-
How to represent the state-space system. When "interpretable", each element of the state vector will have a
63-
precise meaning as either lagged data, innovations, or lagged innovations. This comes at the cost of a larger
64-
state vector, which may hurt performance.
65-
66-
When "fast", states are combined to minimize the dimension of the state vector, but lags and innovations are
67-
mixed together as a result. Only the first state (the modeled timeseries) will have an obvious interpretation
68-
in this case.
69-
70-
measurement_error: bool, default True
71-
If true, a measurement error term is added to the model.
72-
73-
verbose: bool, default True
74-
If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
75-
76-
mode: str or Mode, optional
77-
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
78-
``forecast``. The mode does **not** effect calls to ``pm.sample``.
79-
80-
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
81-
to all sampling methods.
82-
83-
Notes
84-
-----
8533
The VARMA model is a multivariate extension of the SARIMAX model. Given a set of timeseries :math:`\{x_t\}_{t=0}^T`,
8634
with :math:`x_t = \begin{bmatrix} x_{1,t} & x_{2,t} & \cdots & x_{k,t} \end{bmatrix}^T`, a VARMA models each series
8735
as a function of the histories of all series. Specifically, denoting the AR-MA order as (p, q), a VARMA can be
@@ -152,23 +100,143 @@ def __init__(
152100
order: tuple[int, int],
153101
endog_names: list[str] | None = None,
154102
k_endog: int | None = None,
103+
exog_state_names: list[str] | dict[str, list[str]] | None = None,
104+
k_exog: int | dict[str, int] | None = None,
155105
stationary_initialization: bool = False,
156106
filter_type: str = "standard",
157107
measurement_error: bool = False,
158108
verbose: bool = True,
159109
mode: str | Mode | None = None,
160110
):
111+
"""
112+
Create a Bayesian VARMAX model.
113+
114+
Parameters
115+
----------
116+
order: tuple of (int, int)
117+
Number of autoregressive (AR) and moving average (MA) terms to include in the model. All terms up to the
118+
specified order are included. For restricted models, set zeros directly on the priors.
119+
120+
endog_names: list of str, optional
121+
Names of the endogenous variables being modeled. Used to generate names for the state and shock coords. If
122+
None, the state names will simply be numbered.
123+
124+
Exactly one of either ``endog_names`` or ``k_endog`` must be specified.
125+
126+
exog_state_names : list[str] or dict[str, list[str]], optional
127+
Names of the exogenous state variables. If a list, all endogenous variables will share the same exogenous
128+
variables. If a dict, keys should be the names of the endogenous variables, and values should be lists of the
129+
exogenous variable names for that endogenous variable. Endogenous variables not included in the dict will
130+
be assumed to have no exogenous variables. If None, no exogenous variables will be included.
131+
132+
k_exog : int or dict[str, int], optional
133+
Number of exogenous variables. If an int, all endogenous variables will share the same number of exogenous
134+
variables. If a dict, keys should be the names of the endogenous variables, and values should be the number of
135+
exogenous variables for that endogenous variable. Endogenous variables not included in the dict will be
136+
assumed to have no exogenous variables. If None, no exogenous variables will be included.
137+
138+
stationary_initialization: bool, default False
139+
If true, the initial state and initial state covariance will not be assigned priors. Instead, their steady
140+
state values will be used. If False, the user is responsible for setting priors on the initial state and
141+
initial covariance.
142+
143+
.. warning:: This option is very sensitive to the priors placed on the AR and MA parameters. If the model dynamics
144+
for a given sample are not stationary, sampling will fail with a "covariance is not positive semi-definite"
145+
error.
146+
147+
filter_type: str, default "standard"
148+
The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
149+
and "cholesky". See the docs for kalman filters for more details.
150+
151+
state_structure: str, default "fast"
152+
How to represent the state-space system. When "interpretable", each element of the state vector will have a
153+
precise meaning as either lagged data, innovations, or lagged innovations. This comes at the cost of a larger
154+
state vector, which may hurt performance.
155+
156+
When "fast", states are combined to minimize the dimension of the state vector, but lags and innovations are
157+
mixed together as a result. Only the first state (the modeled timeseries) will have an obvious interpretation
158+
in this case.
159+
160+
measurement_error: bool, default True
161+
If true, a measurement error term is added to the model.
162+
163+
verbose: bool, default True
164+
If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
165+
166+
mode: str or Mode, optional
167+
Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
168+
``forecast``. The mode does **not** affect calls to ``pm.sample``.
169+
170+
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
171+
to all sampling methods.
172+
173+
"""
161174
if (endog_names is None) and (k_endog is None):
162175
raise ValueError("Must specify either endog_names or k_endog")
163176
if (endog_names is not None) and (k_endog is None):
164177
k_endog = len(endog_names)
165178
if (endog_names is None) and (k_endog is not None):
166-
endog_names = [f"state.{i + 1}" for i in range(k_endog)]
179+
endog_names = [f"observed_{i}" for i in range(k_endog)]
167180
if (endog_names is not None) and (k_endog is not None):
168181
if len(endog_names) != k_endog:
169182
raise ValueError("Length of provided endog_names does not match provided k_endog")
170183

184+
if k_exog is not None and not isinstance(k_exog, int | dict):
185+
raise ValueError("If not None, k_exog must be either an int or a dict")
186+
if exog_state_names is not None and not isinstance(exog_state_names, list | dict):
187+
raise ValueError("If not None, exog_state_names must be either a list or a dict")
188+
189+
if k_exog is not None and exog_state_names is not None:
190+
if isinstance(k_exog, int) and isinstance(exog_state_names, list):
191+
if len(exog_state_names) != k_exog:
192+
raise ValueError("Length of exog_state_names does not match provided k_exog")
193+
elif isinstance(k_exog, int) and isinstance(exog_state_names, dict):
194+
raise ValueError(
195+
"If k_exog is an int, exog_state_names must be a list of the same length (or None)"
196+
)
197+
elif isinstance(k_exog, dict) and isinstance(exog_state_names, list):
198+
raise ValueError(
199+
"If k_exog is a dict, exog_state_names must be a dict as well (or None)"
200+
)
201+
elif isinstance(k_exog, dict) and isinstance(exog_state_names, dict):
202+
if set(k_exog.keys()) != set(exog_state_names.keys()):
203+
raise ValueError("Keys of k_exog and exog_state_names dicts must match")
204+
if not all(
205+
len(names) == k for names, k in zip(exog_state_names.values(), k_exog.values())
206+
):
207+
raise ValueError(
208+
"If both k_exog and exog_state_names are provided, lengths of exog_state_names "
209+
"lists must match corresponding values in k_exog"
210+
)
211+
212+
if k_exog is not None and exog_state_names is None:
213+
if isinstance(k_exog, int):
214+
exog_state_names = [f"exogenous_{i}" for i in range(k_exog)]
215+
elif isinstance(k_exog, dict):
216+
exog_state_names = {
217+
name: [f"{name}_exogenous_{i}" for i in range(k)] for name, k in k_exog.items()
218+
}
219+
220+
if k_exog is None and exog_state_names is not None:
221+
if isinstance(exog_state_names, list):
222+
k_exog = len(exog_state_names)
223+
elif isinstance(exog_state_names, dict):
224+
k_exog = {name: len(names) for name, names in exog_state_names.items()}
225+
226+
# If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
227+
# then we can drop back to the list case.
228+
if (
229+
isinstance(exog_state_names, dict)
230+
and set(exog_state_names.keys()) == set(endog_names)
231+
and len({frozenset(val) for val in exog_state_names.values()}) == 1
232+
):
233+
exog_state_names = exog_state_names[endog_names[0]]
234+
k_exog = len(exog_state_names)
235+
171236
self.endog_names = list(endog_names)
237+
self.exog_state_names = exog_state_names
238+
239+
self.k_exog = k_exog
172240
self.p, self.q = order
173241
self.stationary_initialization = stationary_initialization
174242

@@ -208,6 +276,14 @@ def param_names(self):
208276
names.remove("ar_params")
209277
if self.q == 0:
210278
names.remove("ma_params")
279+
280+
# Add exogenous regression coefficents rather than remove, since we might have to handle
281+
# several (if self.exog_state_names is a dict)
282+
if isinstance(self.exog_state_names, list):
283+
names.append("beta_exog")
284+
elif isinstance(self.exog_state_names, dict):
285+
names.extend([f"beta_{name}" for name in self.exog_state_names.keys()])
286+
211287
return names
212288

213289
@property
@@ -239,19 +315,65 @@ def param_info(self) -> dict[str, dict[str, Any]]:
239315
},
240316
}
241317

318+
if isinstance(self.exog_state_names, list):
319+
k_exog = len(self.exog_state_names)
320+
info["beta_exog"] = {
321+
"shape": (self.k_endog, k_exog),
322+
"constraints": "None",
323+
}
324+
325+
elif isinstance(self.exog_state_names, dict):
326+
for name, exog_names in self.exog_state_names.items():
327+
k_exog = len(exog_names)
328+
info[f"beta_{name}"] = {
329+
"shape": (k_exog,),
330+
"constraints": "None",
331+
}
332+
242333
for name in self.param_names:
243334
info[name]["dims"] = self.param_dims[name]
244335

245336
return {name: info[name] for name in self.param_names}
246337

338+
@property
339+
def data_info(self) -> dict[str, dict[str, Any]]:
340+
info = None
341+
342+
if isinstance(self.exog_state_names, list):
343+
info = {
344+
"exogenous_data": {
345+
"dims": (TIME_DIM, EXOGENOUS_DIM),
346+
"shape": (None, self.k_exog),
347+
}
348+
}
349+
350+
elif isinstance(self.exog_state_names, dict):
351+
info = {
352+
f"{endog_state}_exogenous_data": {
353+
"dims": (TIME_DIM, f"{EXOGENOUS_DIM}_{endog_state}"),
354+
"shape": (None, len(exog_names)),
355+
}
356+
for endog_state, exog_names in self.exog_state_names.items()
357+
}
358+
359+
return info
360+
361+
@property
362+
def data_names(self) -> list[str]:
363+
if isinstance(self.exog_state_names, list):
364+
return ["exogenous_data"]
365+
elif isinstance(self.exog_state_names, dict):
366+
return [f"{endog_state}_exogenous_data" for endog_state in self.exog_state_names.keys()]
367+
return []
368+
247369
@property
248370
def state_names(self):
249371
state_names = self.endog_names.copy()
250372
state_names += [
251-
f"L{i + 1}.{state}" for i in range(self.p - 1) for state in self.endog_names
373+
f"L{i + 1}_{state}" for i in range(self.p - 1) for state in self.endog_names
252374
]
253375
state_names += [
254-
f"L{i + 1}.{state}_innov" for i in range(self.q) for state in self.endog_names
376+
f"L{i + 1}_{state}_innov" for i in range(self.q) for state in self.endog_names
255377
]
256378

257379
return state_names
@@ -276,6 +398,12 @@ def coords(self) -> dict[str, Sequence]:
276398
if self.q > 0:
277399
coords.update({MA_PARAM_DIM: list(range(1, self.q + 1))})
278400

401+
if isinstance(self.exog_state_names, list):
402+
coords[EXOGENOUS_DIM] = self.exog_state_names
403+
elif isinstance(self.exog_state_names, dict):
404+
for name, exog_names in self.exog_state_names.items():
405+
coords[f"{EXOGENOUS_DIM}_{name}"] = exog_names
406+
279407
return coords
280408

281409
@property
@@ -299,6 +427,14 @@ def param_dims(self):
299427
del coord_map["P0"]
300428
del coord_map["x0"]
301429

430+
if isinstance(self.exog_state_names, list):
431+
coord_map["beta_exog"] = (OBS_STATE_DIM, EXOGENOUS_DIM)
432+
elif isinstance(self.exog_state_names, dict):
433+
# If each state has its own exogenous variables, each parameter needs its own dim, since we expect the
434+
# dim labels to all be different (otherwise we'd be in the list case).
435+
for name in self.exog_state_names.keys():
436+
coord_map[f"beta_{name}"] = (f"{EXOGENOUS_DIM}_{name}",)
437+
302438
return coord_map
303439

304440
def add_default_priors(self):
@@ -386,6 +522,61 @@ def make_symbolic_graph(self) -> None:
386522
)
387523
self.ssm["state_cov", :, :] = state_cov
388524

525+
if self.exog_state_names is not None:
526+
if isinstance(self.exog_state_names, list):
527+
beta_exog = self.make_and_register_variable(
528+
"beta_exog", shape=(self.k_posdef, self.k_exog), dtype=floatX
529+
)
530+
exog_data = self.make_and_register_data(
531+
"exogenous_data", shape=(None, self.k_exog), dtype=floatX
532+
)
533+
534+
obs_intercept = exog_data @ beta_exog.T
535+
536+
elif isinstance(self.exog_state_names, dict):
537+
obs_components = []
538+
for i, name in enumerate(self.endog_names):
539+
if name in self.exog_state_names:
540+
k_exog = len(self.exog_state_names[name])
541+
beta_exog = self.make_and_register_variable(
542+
f"beta_{name}", shape=(k_exog,), dtype=floatX
543+
)
544+
exog_data = self.make_and_register_data(
545+
f"{name}_exogenous_data", shape=(None, k_exog), dtype=floatX
546+
)
547+
obs_components.append(pt.expand_dims(exog_data @ beta_exog, axis=-1))
548+
else:
549+
obs_components.append(pt.zeros((1, 1), dtype=floatX))
550+
551+
# TODO: Replace all of this with pt.concat_with_broadcast once PyMC works with pytensor >= 2.32
552+
553+
# If there were any zeros, they need to be broadcast against the non-zeros.
554+
# Core shape is the last dim, the time dim is always broadcast
555+
non_concat_shape = [1, None]
556+
557+
# Look for the first non-zero component to get the shape from
558+
for tensor_inp in obs_components:
559+
for i, (bcast, sh) in enumerate(
560+
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
561+
):
562+
if bcast or i == 1:
563+
continue
564+
non_concat_shape[i] = sh
565+
566+
assert non_concat_shape.count(None) == 1
567+
568+
bcast_tensor_inputs = []
569+
for tensor_inp in obs_components:
570+
non_concat_shape[1] = tensor_inp.shape[1]
571+
bcast_tensor_inputs.append(pt.broadcast_to(tensor_inp, non_concat_shape))
572+
573+
obs_intercept = pt.join(1, *bcast_tensor_inputs)
574+
575+
else:
576+
raise NotImplementedError()
577+
578+
self.ssm["obs_intercept"] = obs_intercept
579+
389580
if self.stationary_initialization:
390581
# Solve for matrix quadratic for P0
391582
T = self.ssm["transition"]

0 commit comments

Comments
 (0)