diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 875214fc..4f1079b3 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -3,5 +3,7 @@ nav: - basic_renewal_model.md - custom_randomvariables.md - hospital_admissions_model.md + - observation_processes_counts.md + - observation_processes_measurements.md - day_of_the_week.md - periodic_effects.md diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd new file mode 100644 index 00000000..b3b415c7 --- /dev/null +++ b/docs/tutorials/observation_processes_counts.qmd @@ -0,0 +1,425 @@ +--- +title: "Observation processes for count data" +format: gfm +engine: jupyter +--- + +This tutorial demonstrates how to use the `Counts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. + +```{python} +# | label: setup +# | output: false +import jax.numpy as jnp +import numpy as np +import numpyro +import matplotlib.pyplot as plt + +from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +from pyrenew import datasets +``` + +## Overview + +Count observation processes model the lag between infections and an observed outcome such as hospital admissions, emergency department visits, confirmed cases, or deaths. +Observed data can be aggregated or available as subpopulation-level counts, which are modeled by classes `Counts` and `CountsBySubpop`, respectively. + +Count observation processes transform infections into expected observed counts by applying an ascertainment rate and convolving with a delay distribution. 
+ +The expected observations on day $t$ are: + +$$\lambda_t = \alpha \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ + +where: + +- $I_{t-d}$ is the number of incident (new) infections on day $t-d$ (i.e., $d$ days before day $t$) +- $\alpha$ is the ascertainment rate (e.g., infection-hospitalization ratio) +- $p_d$ is the delay distribution from infection to observation, conditional on an infection leading to an observation +- $D$ is the maximum delay + +Discrete observations are generated by sampling from a noise distribution—either Poisson or negative binomial—to model reporting variability. +Poisson assumes variance equals the mean; negative binomial accommodates the overdispersion common in surveillance data. + +**Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. + +## Hospital admissions example + +For hospital admissions data, we construct a `Counts` observation process. +The delay is the key mechanism: infections from $d$ days ago ($I_{t-d}$) contribute to today's hospital admissions ($\lambda_t$) weighted by the probability ($p_d$) that an infection leads to hospitalization after exactly $d$ days. The convolution sums these contributions across all past days. + +The process generates hospital admissions by sampling from a negative binomial distribution: +$$Y_t \sim \text{NegativeBinomial}(\mu = \lambda_t, \text{concentration} = \phi)$$ + +The concentration parameter $\phi$ (sometimes called $k$ or the dispersion parameter) controls overdispersion: as $\phi \to \infty$, the distribution approaches Poisson; smaller values allow greater overdispersion. + +We use the negative binomial distribution because real-world hospital admission counts exhibit overdispersion—the variance exceeds the mean. 
+The Poisson distribution assumes variance equals the mean, which is too restrictive. The negative binomial adds an overdispersion term: +$$\text{Var}[Y_t] = \mu + \frac{\mu^2}{\phi}$$ + +In this example, we use fixed parameter values for illustration; in practice, these parameters would be estimated from data using weakly informative priors. + +## Infection-to-hospitalization delay distribution + +The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. +For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. + +We load a delay distribution from PyRenew's datasets: + +```{python} +# | label: load-delay +inf_hosp_int = datasets.load_infection_admission_interval() +hosp_delay_pmf = jnp.array(inf_hosp_int["probability_mass"].to_numpy()) + +delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) + +# Summary statistics +days = np.arange(len(hosp_delay_pmf)) +mean_delay = float(np.sum(days * hosp_delay_pmf)) +mode_delay = int(np.argmax(hosp_delay_pmf)) +print(f"Mode delay: {mode_delay} days, Mean delay: {mean_delay:.1f} days") +``` + +```{python} +# | label: fig-delay-distribution +# | fig-cap: Infection-to-hospitalization delay distribution +fig, ax = plt.subplots(figsize=(8, 4)) +ax.bar(days, hosp_delay_pmf, color="steelblue", alpha=0.7, edgecolor="black") +ax.axvline( + mode_delay, + color="purple", + linestyle="-", + linewidth=2, + label=f"Mode: {mode_delay}", +) +ax.axvline( + mean_delay, + color="red", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_delay:.1f}", +) +ax.set_xlabel("Days from infection to hospitalization") +ax.set_ylabel("Probability") +ax.set_title("Infection-to-Hospitalization Delay Distribution") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Creating a Counts observation process + +A `Counts` object takes 
the following arguments: + +- **`ascertainment_rate_rv`**: the probability an infection results in an observation (e.g., IHR) +- **`delay_distribution_rv`**: delay distribution from infection to observation (PMF) +- **`noise`**: noise model (`PoissonNoise()` or `NegativeBinomialNoise(concentration_rv)`) + +```{python} +# | label: create-counts-process +# Infection-hospitalization ratio (1% of infections lead to hospitalization) +ihr_rv = DeterministicVariable("ihr", 0.01) + +# Overdispersion parameter for negative binomial +concentration_rv = DeterministicVariable("concentration", 10.0) + +# Create the observation process +hosp_process = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), +) + +print(f"Required lookback: {hosp_process.lookback_days()} days") +``` + +### Timeline alignment and lookback period + +The observation process convolves infections with a delay distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. + +Hospital admissions on a given day depend on infections from that day and from the preceding days, going back one day fewer than the length of the delay distribution. The method `lookback_days()` returns the length of the delay distribution, so the first valid observation day is at index `lookback_days() - 1`. Earlier days are marked invalid. + +```{python} +# | label: helper-function +def first_valid_observation_day(obs_process) -> int: + """Return the first day index with complete infection history for convolution.""" + return obs_process.lookback_days() - 1 +``` + +## Simulating hospital admissions from infections + +To demonstrate how the observation process works, we simulate admissions from a spike of infections on a single day. 
+ +```{python} +# | label: simulate-spike +n_days = 100 +lookback = hosp_process.lookback_days() + +# First valid observation day +day_one = lookback - 1 + +# Create infections with a spike +infection_spike_day = day_one + 10 +infections = jnp.zeros(n_days) +infections = infections.at[infection_spike_day].set(2000) + +# Sample hospital admissions +with numpyro.handlers.seed(rng_seed=42): + hosp_admissions = hosp_process.sample( + infections=infections, + counts=None, # Sample from prior (no observed data) + ) +``` + +```{python} +# | label: fig-spike-infections +# | fig-cap: Input infections with a single-day spike +# Plot relative to first valid observation day +spike_day = infection_spike_day - day_one +n_plot_days = n_days - day_one + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot( + np.arange(n_plot_days), + np.array(infections[day_one:]), + "-o", + color="darkblue", +) +ax.axvline(spike_day, color="darkred", linestyle="--", alpha=0.7) +ax.annotate( + f"Infection spike\n(day {spike_day})", + xy=(spike_day, 1800), + xytext=(spike_day + 5, 1800), + fontsize=10, + color="darkred", +) +ax.set_xlabel("Day") +ax.set_ylabel("Daily Infections") +ax.set_title("Infections (Input)") +plt.tight_layout() +plt.show() +``` + +Because all infections occur on a single day, we can see how they spread into hospital admissions over subsequent days according to the delay distribution: + +```{python} +# | label: fig-spike-admissions +# | fig-cap: Hospital admissions from a single-day infection spike +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot( + np.arange(n_plot_days), + np.array(hosp_admissions.observed[day_one:]), + "-o", + color="purple", +) +ax.axvline( + spike_day, + color="darkred", + linestyle="--", + alpha=0.5, + label="Infection spike", +) +ax.axvline( + spike_day + mode_delay, + color="purple", + linestyle="--", + alpha=0.5, + label="Peak admissions", +) +ax.set_xlabel("Day") +ax.set_ylabel("Hospital Admissions") +ax.set_title("Hospital Admissions (Output)") 
+ax.legend() +plt.tight_layout() +plt.show() +``` + +The expected admissions peak occurs `{python} mode_delay` days after the infection spike, matching the mode of the delay distribution; because of observation noise, the peak of a single sampled realization may fall on a nearby day. + +## Observation noise + +The negative binomial distribution adds stochastic variation. Sampling multiple times from the same infections shows the range of possible observations: + +```{python} +# | label: sample-realizations +n_samples = 50 +samples = [] + +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + result = hosp_process.sample(infections=infections, counts=None) + samples.append(np.array(result.observed[day_one:])) + +samples = np.array(samples) +sample_mean = samples.mean(axis=0) +``` + +```{python} +# | label: fig-sampled-admissions +# | fig-cap: Multiple realizations showing observation noise +fig, ax = plt.subplots(figsize=(8, 4)) + +# Plot the remaining samples in light orange (sample 0 is highlighted below) +for i in range(1, n_samples): + ax.plot( + np.arange(n_plot_days), + samples[i], + color="orange", + alpha=0.15, + linewidth=0.5, + ) + +# Highlight one sample +ax.plot( + np.arange(n_plot_days), + samples[0], + color="steelblue", + linewidth=1, + label="One realization", +) + +# Sample mean +ax.plot( + np.arange(n_plot_days), + sample_mean, + color="darkred", + linewidth=1.2, + label="Sample mean", +) + +ax.axvline(spike_day, color="darkblue", linestyle="--", alpha=0.5) +ax.set_xlabel("Day") +ax.set_ylabel("Hospital Admissions") +ax.set_title(f"Observation Noise: {n_samples} Samples from Same Infections") +ax.legend() +plt.tight_layout() +plt.show() +``` + +```{python} +# | label: timeline-stats +# Print timeline statistics +print("Timeline Analysis:") +print( + f" Infection spike on day {spike_day}: {infections[infection_spike_day]:.0f} people" +) +print(f" Mode delay from infection to hospitalization: {mode_delay} days") +print( + f" Expected hospitalization peak: day {spike_day + mode_delay} (= {spike_day} + {mode_delay})" +) +``` + +## Effect of the ascertainment rate + +The 
ascertainment rate directly scales expected hospital admissions. We compare IHR values of 0.5% and 2.5%: + +```{python} +# | label: compare-ihr +ihr_values = [0.005, 0.025] +infections_decay = 3000 * jnp.exp(-jnp.arange(n_days) / 20.0) + +results = {} +for ihr_val in ihr_values: + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", ihr_val), + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), + ) + # Average over multiple samples to show the effect on the mean + samples = [] + for seed in range(20): + with numpyro.handlers.seed(rng_seed=seed): + result = process.sample(infections=infections_decay, counts=None) + samples.append(np.array(result.observed[day_one:])) + results[ihr_val] = np.mean(samples, axis=0) +``` + +```{python} +# | label: fig-ihr-comparison +# | fig-cap: Effect of different infection-hospitalization rates +fig, ax = plt.subplots(figsize=(8, 4)) +colors = ["steelblue", "darkred"] + +for (ihr_val, mean_sample), color in zip(results.items(), colors): + ax.plot( + np.arange(n_plot_days), + mean_sample, + color=color, + linewidth=1.5, + label=f"IHR = {ihr_val:.1%}", + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Hospital Admissions (mean of samples)") +ax.set_title("Effect of IHR on Hospital Admissions") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Negative binomial concentration parameter + +The concentration parameter $\phi$ controls overdispersion: + +- Higher $\phi$ → less overdispersion (approaches Poisson) +- Lower $\phi$ → more overdispersion (noisier data) + +```{python} +# | label: fig-concentration-comparison +# | fig-cap: Effect of concentration parameter on variability +infections_constant = 2000 * jnp.ones(n_days) +concentration_values = [1.0, 10.0, 100.0] +n_replicates = 10 + +fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) + +for ax, conc_val in zip(axes, concentration_values): + process = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + 
noise=NegativeBinomialNoise(DeterministicVariable("conc", conc_val)), + ) + + for seed in range(n_replicates): + with numpyro.handlers.seed(rng_seed=seed): + result = process.sample( + infections=infections_constant, counts=None + ) + ax.plot( + np.arange(n_plot_days), + np.array(result.observed[day_one:]), + color="steelblue", + alpha=0.5, + linewidth=0.8, + ) + + ax.set_xlabel("Day") + ax.set_title(f"φ = {int(conc_val)}") + +axes[0].set_ylabel("Hospital Admissions") +fig.suptitle("Effect of Concentration Parameter on Variability") +plt.tight_layout() +plt.show() +``` + +## Swapping noise models + +To use Poisson noise instead of negative binomial, change the noise model: + +```{python} +# | label: poisson-noise +hosp_process_poisson = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=PoissonNoise(), +) + +with numpyro.handlers.seed(rng_seed=42): + poisson_result = hosp_process_poisson.sample( + infections=infections, + counts=None, + ) + +print( + f"Sampled {len(poisson_result.observed)} days of hospital admissions with Poisson noise" +) +``` diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd new file mode 100644 index 00000000..21dbff21 --- /dev/null +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -0,0 +1,573 @@ +--- +title: "Observation processes for continuous measurements" +format: gfm +engine: jupyter +--- + +This tutorial demonstrates how to use the `Measurements` observation process to model continuous measurement data such as wastewater viral concentrations. 
+ +```{python} +# | label: setup +# | output: false +import jax +import jax.numpy as jnp +import numpy as np +import numpyro +import matplotlib.pyplot as plt + +from pyrenew.observation import Measurements, HierarchicalNormalNoise +from pyrenew.randomvariable import HierarchicalNormalPrior, GammaGroupSdPrior +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +``` + +## Overview + +Measurement observation processes model continuous signals derived from infections, such as viral RNA concentrations in wastewater. Unlike count observations (hospital admissions, deaths), measurements are continuous values that can span several orders of magnitude. + +The expected measurement on day $t$ is: + +$$\lambda_t = \frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ + +where: + +- $I_{t-d}$ is the number of incident (new) infections on day $t-d$ +- $G$ is a scaling factor (e.g., genome copies shed per infection) +- $V$ is a normalization factor (e.g., wastewater volume per person per day) +- $p_d$ is the temporal distribution (e.g., viral shedding kinetics) +- $D$ is the maximum duration + +Observed log-concentrations are generated by sampling from a normal distribution: + +$$\log(C_t) \sim \text{Normal}(\mu = \log(\lambda_t), \sigma)$$ + +Modeling $\log(C_t)$ as normal is equivalent to a log-normal distribution for $C_t$, which is appropriate for concentration data because concentrations are strictly positive and often span several orders of magnitude. + +### Comparison with count observations + +The core convolution $\sum_{d=0}^{D} I_{t-d} \cdot p_d$ is the same as for count observations. 
The key differences are: + +| Aspect | Counts | Measurements | +|--------|--------|--------------| +| Scaling factor | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific (e.g., $G/V$) | +| Temporal PMF | Delay distribution | Shedding/decay kinetics | +| Output space | Expected counts (linear) | Log-concentrations | +| Noise model | Poisson or Negative Binomial | Normal on log scale | +| Subpop structure | Optional (`CountsBySubpop`) | Inherent (hierarchical effects) | + +**Key features of measurement data:** + +- **Multiple sites**: Each jurisdiction has several measurement sites (e.g., wastewater treatment plants). +- **Irregular sampling**: Sites measure on different schedules. +- **Site-level variability**: Lab protocols and sampling methods vary across sites. +- **Temporal lag**: Signal peaks several days after infection. + +**Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce measurements through convolution and sampling. + +## Subclassing Measurements for wastewater + +The `Measurements` class is abstract—you must subclass it and implement `_expected_signal()` for your specific signal type. Here we create a `Wastewater` class for viral concentration measurements: + +```{python} +# | label: wastewater-class +from jax.typing import ArrayLike +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.noise import MeasurementNoise + + +class Wastewater(Measurements): + """ + Wastewater viral concentration observation process. + + Transforms site-level infections into expected log-concentrations + via shedding kinetics convolution and genome/volume scaling. 
+ """ + + def __init__( + self, + shedding_kinetics_rv: RandomVariable, + log10_genome_per_infection_rv: RandomVariable, + ml_per_person_per_day: float, + noise: MeasurementNoise, + ) -> None: + """ + Initialize wastewater observation process. + + Parameters + ---------- + shedding_kinetics_rv : RandomVariable + Viral shedding PMF (fraction shed each day post-infection). + log10_genome_per_infection_rv : RandomVariable + Log10 genome copies shed per infection. + ml_per_person_per_day : float + Wastewater volume per person per day (mL). + noise : MeasurementNoise + Noise model (e.g., HierarchicalNormalNoise). + """ + super().__init__(temporal_pmf_rv=shedding_kinetics_rv, noise=noise) + self.log10_genome_per_infection_rv = log10_genome_per_infection_rv + self.ml_per_person_per_day = ml_per_person_per_day + + def validate(self) -> None: + """Validate the shedding PMF and the noise model.""" + shedding_pmf = self.temporal_pmf_rv() + self._validate_pmf(shedding_pmf, "shedding_kinetics_rv") + self.noise.validate() + + def lookback_days(self) -> int: + """Return the shedding PMF length (the lookback window in days).""" + return len(self.temporal_pmf_rv()) + + def _expected_signal(self, infections: ArrayLike) -> ArrayLike: + """ + Compute expected log-concentration from infections. + + Applies shedding kinetics convolution, then scales by + genome copies and volume to get concentration (natural log returned). 
+ """ + shedding_pmf = self.temporal_pmf_rv() + log10_genome = self.log10_genome_per_infection_rv() + + # Convolve each site's infections with shedding kinetics + def convolve_site(site_infections): + convolved, _ = self._convolve_with_alignment( + site_infections, shedding_pmf, p_observed=1.0 + ) + return convolved + + # Apply to all subpops (infections shape: n_days x n_subpops) + shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)( + infections + ) + + # Convert to concentration: genomes per mL + genome_copies = 10**log10_genome + concentration = ( + shedding_signal * genome_copies / self.ml_per_person_per_day + ) + + # Return natural-log concentration (what we model); NOTE(review): a zero concentration gives log(0) = -inf, confirm this is intended for days with no shedding signal + return jnp.log(concentration) +``` + +## Viral shedding kinetics + +The shedding PMF describes what fraction of total viral shedding occurs on each day after infection: + +```{python} +# | label: shedding-pmf +# Peak shedding ~3 days after infection, continues for ~10 days +shedding_pmf = jnp.array( + [0.0, 0.05, 0.15, 0.25, 0.20, 0.15, 0.10, 0.05, 0.03, 0.02] +) +print(f"PMF sums to: {shedding_pmf.sum():.2f}") + +shedding_rv = DeterministicPMF("viral_shedding", shedding_pmf) + +# Summary statistics +days = np.arange(len(shedding_pmf)) +mean_shedding_day = float(np.sum(days * shedding_pmf)) +mode_shedding_day = int(np.argmax(shedding_pmf)) +print(f"Mode: {mode_shedding_day} days, Mean: {mean_shedding_day:.1f} days") +``` + +```{python} +# | label: fig-shedding +# | fig-cap: Viral shedding kinetics distribution +fig, ax = plt.subplots(figsize=(8, 4)) +ax.bar(days, shedding_pmf, color="steelblue", alpha=0.7, edgecolor="black") +ax.axvline( + mode_shedding_day, + color="purple", + linestyle="-", + linewidth=2, + label=f"Mode: {mode_shedding_day}", +) +ax.axvline( + mean_shedding_day, + color="red", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_shedding_day:.1f}", +) +ax.set_xlabel("Days after infection") +ax.set_ylabel("Fraction of total shedding") +ax.set_title("Viral Shedding 
Kinetics") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Genome copies and wastewater volume + +```{python} +# | label: scaling-params +# Log10 genome copies shed per infection (typical: 8-10) +log10_genome_rv = DeterministicVariable("log10_genome", 9.0) + +# Wastewater volume per person per day (mL) +ml_per_person_per_day = 1000.0 +``` + +## Noise model with sensor-level effects + +A measurement depends on both the thing being measured (e.g., wastewater from a treatment plant) and the calibration of the lab/instruments used to obtain that measurement. We call this combination a "sensor"—the WWTP/lab pair that determines the measurement characteristics. Different sensors have systematic biases and variabilities that we model with hierarchical effects. + +```{python} +# | label: noise-model +# Sensor-level mode: systematic differences between sensors +sensor_mode_prior = HierarchicalNormalPrior( + name="ww_sensor_mode", + sd_rv=DeterministicVariable("mode_sd", 0.5), +) + +# Sensor-level SD: measurement variability within each sensor +sensor_sd_prior = GammaGroupSdPrior( + name="ww_sensor_sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), + sd_min=0.10, +) + +# Create the noise model +ww_noise = HierarchicalNormalNoise( + sensor_mode_prior_rv=sensor_mode_prior, + sensor_sd_prior_rv=sensor_sd_prior, +) +``` + +## Creating the observation process + +```{python} +# | label: create-process +ww_process = Wastewater( + shedding_kinetics_rv=shedding_rv, + log10_genome_per_infection_rv=log10_genome_rv, + ml_per_person_per_day=ml_per_person_per_day, + noise=ww_noise, +) + +print(f"Required lookback: {ww_process.lookback_days()} days") +``` + +### Timeline alignment and lookback period + +The observation process convolves infections with a shedding distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. 
+ +Wastewater concentrations on a given day depend on infections from that day and from the preceding days, going back one day fewer than the length of the shedding distribution. The method `lookback_days()` returns the length of the shedding distribution, so the first valid observation day is at index `lookback_days() - 1`. Earlier days are marked invalid with NaN. + +```{python} +# | label: helper-function +def first_valid_observation_day(obs_process) -> int: + """Return the first day index with complete infection history for convolution.""" + return obs_process.lookback_days() - 1 +``` + +## Simulating wastewater observations from infections + +To demonstrate how the observation process works, we simulate concentrations from a spike of infections on a single day. + +```{python} +# | label: simulate-spike +n_days = 50 + +# First valid observation day +day_one = first_valid_observation_day(ww_process) + +# Create infections with a spike (shape: n_days x n_subpops) +infection_spike_day = day_one + 10 +infections = jnp.zeros((n_days, 1)) # 1 subpopulation +infections = infections.at[infection_spike_day, 0].set(2000.0) + +# For plotting +spike_day = infection_spike_day - day_one +n_plot_days = n_days - day_one + +# Observation times and indices +observation_days = jnp.arange(day_one, 40, dtype=jnp.int32) +n_obs = len(observation_days) + +with numpyro.handlers.seed(rng_seed=42): + ww_obs = ww_process.sample( + infections=infections, + subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), + times=observation_days, + concentrations=None, # Sample from prior + n_sensors=1, + ) +``` + +```{python} +# | label: fig-spike-infections +# | fig-cap: Input infections with a single-day spike +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot( + np.arange(n_plot_days), + np.array(infections[day_one:, 0]), + "-o", + color="darkblue", +) +ax.axvline(spike_day, color="darkred", linestyle="--", alpha=0.7) +ax.annotate( + f"Infection spike\n(day {spike_day})", + xy=(spike_day, 1800), + xytext=(spike_day + 3, 1800), + fontsize=10, + 
color="darkred", +) +ax.set_xlabel("Day") +ax.set_ylabel("Daily Infections") +ax.set_title("Infections (Input)") +plt.tight_layout() +plt.show() +``` + +Because all infections occur on a single day, we can see how they spread into wastewater concentrations over subsequent days according to the shedding kinetics. + +## Observation noise + +The log-normal noise model adds stochastic variation. Sampling multiple times from the same infections shows the range of possible observations: + +```{python} +# | label: sample-realizations +n_samples = 50 +samples_by_day = {int(d - day_one): [] for d in observation_days} + +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + ww_result = ww_process.sample( + infections=infections, + subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), + times=observation_days, + concentrations=None, + n_sensors=1, + ) + for day_idx, conc in zip(observation_days, ww_result.observed): + samples_by_day[int(day_idx) - day_one].append(float(conc)) + +# Convert to arrays for plotting +plot_days = sorted(samples_by_day.keys()) +all_samples = np.array( + [[samples_by_day[d][i] for d in plot_days] for i in range(n_samples)] +) +sample_mean = all_samples.mean(axis=0) +``` + +```{python} +# | label: fig-sampled-concentrations +# | fig-cap: Multiple realizations showing observation noise +fig, ax = plt.subplots(figsize=(8, 4)) + +# Plot all samples +for i in range(1, n_samples): + ax.plot( + plot_days, all_samples[i], color="orange", alpha=0.15, linewidth=0.5 + ) + +# Highlight one sample +ax.plot( + plot_days, + all_samples[0], + color="steelblue", + linewidth=1, + label="One realization", +) + +# Sample mean +ax.plot( + plot_days, sample_mean, color="darkred", linewidth=1.2, label="Sample mean" +) + +ax.axvline(spike_day, color="darkblue", linestyle="--", alpha=0.5) +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title(f"Observation Noise: {n_samples} Samples 
from Same Infections") +ax.legend() +plt.tight_layout() +plt.show() +``` + +```{python} +# | label: timeline-stats +print("Timeline Analysis:") +print( + f" Infection spike on day {spike_day}: {infections[infection_spike_day, 0]:.0f} people" +) +print( + f" Mode delay from infection to concentration peak: {mode_shedding_day} days" +) +print( + f" Expected concentration peak: day {spike_day + mode_shedding_day} (= {spike_day} + {mode_shedding_day})" +) +``` + +## Sensor-level variability + +Different sensors measuring the same underlying infections will show systematic differences: + +```{python} +# | label: multi-sensor +num_sensors = 4 +infections_constant = jnp.ones((40, 1)) * 1500.0 + +# Each sensor samples at multiple time points +times_per_sensor = 10 +observation_times = jnp.tile( + jnp.arange(10, 30, 2, dtype=jnp.int32), num_sensors +) +sensor_ids = jnp.repeat( + jnp.arange(num_sensors, dtype=jnp.int32), times_per_sensor +) +subpop_ids = jnp.zeros(num_sensors * times_per_sensor, dtype=jnp.int32) + +with numpyro.handlers.seed(rng_seed=42): + ww_multi_sensor = ww_process.sample( + infections=infections_constant, + subpop_indices=subpop_ids, + sensor_indices=sensor_ids, + times=observation_times, + concentrations=None, + n_sensors=num_sensors, + ) +``` + +```{python} +# | label: fig-multi-sensor +# | fig-cap: Multiple sensors observing the same infections +fig, ax = plt.subplots(figsize=(8, 4)) + +colors = ["steelblue", "coral", "green", "purple"] +for sensor in range(num_sensors): + mask = np.array(sensor_ids) == sensor + ax.plot( + np.array(observation_times)[mask], + np.array(ww_multi_sensor.observed)[mask], + "-o", + color=colors[sensor], + label=f"Sensor {sensor}", + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title("Multiple Sensors Observing Same Infections") +ax.legend() +plt.tight_layout() +plt.show() +``` + +Each sensor has a different baseline concentration (sensor mode) and different measurement variability (sensor 
SD). The hierarchical model learns these sensor-specific effects during inference. + +## Effect of genome shedding parameter + +Higher genome shedding means higher observed concentrations: + +```{python} +# | label: compare-genome +genome_values = [8.5, 9.0, 9.5, 10.0] +infections_decay = 2000.0 * jnp.exp(-jnp.arange(40) / 15.0) +infections_decay = infections_decay.reshape(-1, 1) + +observation_days_cmp = jnp.arange(10, 30, dtype=jnp.int32) +n_obs_cmp = len(observation_days_cmp) + +results = {} +for genome_val in genome_values: + process_temp = Wastewater( + shedding_kinetics_rv=shedding_rv, + log10_genome_per_infection_rv=DeterministicVariable( + "log10_genome", genome_val + ), + ml_per_person_per_day=ml_per_person_per_day, + noise=ww_noise, + ) + + with numpyro.handlers.seed(rng_seed=42): + ww_temp = process_temp.sample( + infections=infections_decay, + subpop_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), + times=observation_days_cmp, + concentrations=None, + n_sensors=1, + ) + results[genome_val] = np.array(ww_temp.observed) +``` + +```{python} +# | label: fig-genome-effect +# | fig-cap: Effect of genome shedding parameter on concentrations +fig, ax = plt.subplots(figsize=(8, 4)) + +colors = ["steelblue", "coral", "green", "purple"] +for (genome_val, conc), color in zip(results.items(), colors): + ax.plot( + np.array(observation_days_cmp), + conc, + "-o", + color=color, + label=f"log10 = {genome_val}", + markersize=4, + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title("Effect of Genome Shedding Parameter") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Multiple subpopulations + +In hierarchical models, each sensor serves a distinct subpopulation (catchment area). 
Different subpopulations can have different infection levels: + +```{python} +# | label: multi-subpop +# Two subpopulations with different infection patterns +n_days_mp = 40 +infections_subpop1 = 1000.0 * jnp.exp( + -jnp.arange(n_days_mp) / 20.0 +) # Slow decay +infections_subpop2 = 2000.0 * jnp.exp( + -jnp.arange(n_days_mp) / 10.0 +) # Fast decay +infections_multi = jnp.stack([infections_subpop1, infections_subpop2], axis=1) + +# Two sensors, each observing a different subpopulation +obs_days_mp = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), 2) +subpop_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) +sensor_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) + +with numpyro.handlers.seed(rng_seed=42): + ww_multi_subpop = ww_process.sample( + infections=infections_multi, + subpop_indices=subpop_ids_mp, + sensor_indices=sensor_ids_mp, + times=obs_days_mp, + concentrations=None, + n_sensors=2, + ) +``` + +```{python} +# | label: fig-multi-subpop +# | fig-cap: Sensors observing different subpopulations with different infection dynamics +fig, ax = plt.subplots(figsize=(8, 4)) + +for subpop in [0, 1]: + mask = np.array(subpop_ids_mp) == subpop + ax.plot( + np.array(obs_days_mp)[mask], + np.array(ww_multi_subpop.observed)[mask], + "-o", + label=f"Subpop {subpop}", + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title("Multiple Subpopulations with Different Infection Dynamics") +ax.legend() +plt.tight_layout() +plt.show() +``` diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index b0e04e69..8a0cdeab 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -1,9 +1,54 @@ # numpydoc ignore=GL08 +""" +Observation processes for connecting infections to observed data. +Architecture +------------ +``BaseObservationProcess`` is the abstract base. 
Concrete subclasses: + +- ``Counts``: Aggregate counts (admissions, deaths) +- ``CountsBySubpop``: Subpopulation-level counts +- ``Measurements``: Continuous subpopulation-level signals (e.g., wastewater) + +All observation processes implement: + +- ``sample()``: Sample observations given infections +- ``infection_resolution()``: returns ``"aggregate"`` or ``"subpop"`` +- ``lookback_days()``: returns required infection history length + +Noise models (``CountNoise``, ``MeasurementNoise``) are composable—pass them +to observation constructors to control the output distribution. +""" + +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.count_observations import Counts, CountsBySubpop +from pyrenew.observation.measurements import Measurements from pyrenew.observation.negativebinomial import NegativeBinomialObservation +from pyrenew.observation.noise import ( + CountNoise, + HierarchicalNormalNoise, + MeasurementNoise, + NegativeBinomialNoise, + PoissonNoise, +) from pyrenew.observation.poisson import PoissonObservation +from pyrenew.observation.types import ObservationSample __all__ = [ + # Existing (kept for backward compatibility) "NegativeBinomialObservation", "PoissonObservation", + # Base classes and types + "BaseObservationProcess", + "ObservationSample", + # Noise models + "CountNoise", + "PoissonNoise", + "NegativeBinomialNoise", + "MeasurementNoise", + "HierarchicalNormalNoise", + # Observation processes + "Counts", + "CountsBySubpop", + "Measurements", ] diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py new file mode 100644 index 00000000..635d386f --- /dev/null +++ b/pyrenew/observation/base.py @@ -0,0 +1,334 @@ +# numpydoc ignore=GL08 +""" +Abstract base class for observation processes. + +Provides common functionality for observation processes that use convolution +with temporal distributions to connect infections to observed data. 
+""" + +from __future__ import annotations + +from abc import abstractmethod + +import jax.numpy as jnp +import numpyro +from jax.typing import ArrayLike + +from pyrenew.convolve import compute_delay_ascertained_incidence +from pyrenew.metaclass import RandomVariable + + +class BaseObservationProcess(RandomVariable): + """ + Abstract base class for observation processes that use convolution + with temporal distributions. + + This class provides common functionality for connecting infections + to observed data (e.g., hospital admissions, wastewater concentrations) + through temporal convolution operations. + + Key features provided: + + - PMF validation (sum to 1, non-negative) + - Minimum observation day calculation + - Convolution wrapper with timeline alignment + - Deterministic quantity tracking + + Subclasses must implement: + + - ``validate()``: Validate parameters (call ``_validate_pmf()`` for PMFs) + - ``lookback_days()``: Return PMF length for initialization + - ``infection_resolution()``: Return ``"aggregate"`` or ``"subpop"`` + - ``_expected_signal()``: Transform infections to expected values + - ``sample()``: Apply noise model to expected signal + + Notes + ----- + Computing expected observations on day t requires infection history + from previous days (determined by the temporal PMF length). + The first ``len(pmf) - 1`` days have insufficient history and return NaN. + + See Also + -------- + pyrenew.convolve.compute_delay_ascertained_incidence : + Underlying convolution function + pyrenew.metaclass.RandomVariable : + Base class for all random variables + """ + + def __init__(self, temporal_pmf_rv: RandomVariable) -> None: + """ + Initialize base observation process. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + The temporal distribution PMF (e.g., delay or shedding distribution). + Must sample to a 1D array that sums to ~1.0 with non-negative values. + Subclasses may have additional parameters. 
+ + Notes + ----- + Subclasses should call ``super().__init__(temporal_pmf_rv)`` + in their constructors and may add additional parameters. + """ + self.temporal_pmf_rv = temporal_pmf_rv + + @abstractmethod + def validate(self) -> None: + """ + Validate observation process parameters. + + Subclasses must implement this method to validate all parameters. + Typically this involves calling ``_validate_pmf()`` for the PMF + and adding any additional parameter-specific validation. + + Raises + ------ + ValueError + If any parameters fail validation. + """ + pass # pragma: no cover + + @abstractmethod + def lookback_days(self) -> int: + """ + Return the number of days this observation process needs to look back. + + This determines the minimum n_initialization_points required by the + latent process when this observation is included in a multi-signal model. + + Returns + ------- + int + Number of days of infection history required. + Typically the length of the delay or shedding PMF. + + Notes + ----- + This is used by model builders to automatically compute + n_initialization_points as: + ``max(gen_int_length, max(all lookbacks)) - 1`` + """ + pass # pragma: no cover + + @abstractmethod + def infection_resolution(self) -> str: + """ + Return whether this observation uses aggregate or subpop infections. + + Returns one of: + + - ``"aggregate"``: Uses a single aggregated infection trajectory. + Shape: ``(n_days,)`` + - ``"subpop"``: Uses subpopulation-level infection trajectories. + Shape: ``(n_days, n_subpops)``, indexed via ``subpop_indices``. 
+ + Returns + ------- + str + Either ``"aggregate"`` or ``"subpop"`` + + Examples + -------- + >>> # Aggregated count observations + >>> hosp_obs.infection_resolution() # Returns "aggregate" + >>> + >>> # Subpopulation-level observations (wastewater, subpop-specific counts) + >>> ww_obs.infection_resolution() # Returns "subpop" + + Notes + ----- + This is used by multi-signal models to route the correct infection + output to each observation process. + """ + pass # pragma: no cover + + def _validate_pmf( + self, + pmf: ArrayLike, + param_name: str, + atol: float = 1e-6, + ) -> None: + """ + Validate that an array is a valid probability mass function. + + Checks: + + - Non-empty array + - Sums to 1.0 (within tolerance) + - All non-negative values + + Parameters + ---------- + pmf : ArrayLike + The PMF array to validate + param_name : str + Name of the parameter (for error messages) + atol : float, default 1e-6 + Absolute tolerance for sum-to-one check + + Raises + ------ + ValueError + If PMF is empty, doesn't sum to 1.0 (within tolerance), + or contains negative values. + """ + if pmf.size == 0: + raise ValueError(f"{param_name} must return non-empty array") + + pmf_sum = jnp.sum(pmf) + if not jnp.isclose(pmf_sum, 1.0, atol=atol): + raise ValueError( + f"{param_name} must sum to 1.0 (±{atol}), got {float(pmf_sum):.6f}" + ) + + if jnp.any(pmf < 0): + raise ValueError(f"{param_name} must have non-negative values") + + def get_minimum_observation_day(self) -> int: + """ + Get the first day with valid (non-NaN) convolution results. + + Due to the convolution operation requiring a history window, + the first ``len(pmf) - 1`` days will have NaN values in the + output. This method returns the index of the first valid day. + + Returns + ------- + int + Day index (0-based) of first valid observation. + Equal to ``len(pmf) - 1``. 
+ """ + pmf = self.temporal_pmf_rv() + return int(len(pmf) - 1) + + def _convolve_with_alignment( + self, + latent_incidence: ArrayLike, + pmf: ArrayLike, + p_observed: float = 1.0, + ) -> tuple[ArrayLike, int]: + """ + Convolve latent incidence with PMF while maintaining timeline alignment. + + This is a wrapper around ``compute_delay_ascertained_incidence`` that + always uses ``pad=True`` to ensure day t in the output corresponds to + day t in the input. The first ``len(pmf) - 1`` days will be NaN. + + Parameters + ---------- + latent_incidence : ArrayLike + Latent incidence time series (infections, prevalence, etc.). + Shape: (n_days,) + pmf : ArrayLike + Delay or shedding PMF. Shape: (n_pmf,) + p_observed : float, default 1.0 + Observation probability multiplier. Scales the convolution result. + + Returns + ------- + tuple[ArrayLike, int] + - convolved_array : ArrayLike + Convolved time series with same length as input. + First ``len(pmf) - 1`` days are NaN. + Shape: (n_days,) + - offset : int + Always 0 when pad=True (maintained for API compatibility) + + Notes + ----- + For t < len(pmf)-1, there is insufficient history, so output[t] = NaN. + + See Also + -------- + pyrenew.convolve.compute_delay_ascertained_incidence : + Underlying function + """ + return compute_delay_ascertained_incidence( + latent_incidence=latent_incidence, + delay_incidence_to_observation_pmf=pmf, + p_observed_given_incident=p_observed, + pad=True, # Maintains timeline alignment + ) + + def _deterministic(self, name: str, value: ArrayLike) -> None: + """ + Track a deterministic quantity in the numpyro execution trace. + + This is a convenience wrapper around ``numpyro.deterministic`` for + tracking intermediate quantities (e.g., latent admissions, expected + concentrations) that are useful for diagnostics and model checking. + These quantities are stored in MCMC samples and can be used for + model diagnostics and posterior predictive checks. 
+ + Parameters + ---------- + name : str + Name for the tracked quantity. Will appear in MCMC samples. + value : ArrayLike + Value to track. Can be any shape. + """ + numpyro.deterministic(name, value) + + @abstractmethod + def _expected_signal( + self, + infections: ArrayLike, + ) -> ArrayLike: + """ + Transform infections to expected observation values. + + This is the core transformation that each observation process must + implement. It converts infections (from the infection process) + to expected values for the observation model. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days,) for aggregate observations + Shape: (n_days, n_subpops) for subpop-level observations + + Returns + ------- + ArrayLike + Expected observation values (counts, log-concentrations, etc.). + Same shape as input, with first len(pmf)-1 days as NaN. + + Notes + ----- + The transformation is observation-specific: + + - Count observations: ascertainment x delay convolution -> expected counts + - Wastewater: shedding convolution -> genome scaling -> dilution -> log + + See Also + -------- + sample : Uses this method then applies noise model + """ + pass # pragma: no cover + + @abstractmethod + def sample(self, **kwargs) -> ArrayLike: + """ + Sample from the observation process. + + Subclasses must implement this method to define the specific + observation model. Typically calls ``_expected_signal`` first, + then applies the noise model. + + Parameters + ---------- + **kwargs + Subclass-specific parameters. At minimum, should include: + + - infections from the infection process + - Observed data (or None for prior predictive sampling) + + Returns + ------- + ArrayLike + Observed or sampled values from the observation process. 
+ """ + pass # pragma: no cover diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py new file mode 100644 index 00000000..7a883932 --- /dev/null +++ b/pyrenew/observation/count_observations.py @@ -0,0 +1,348 @@ +# numpydoc ignore=GL08 +""" +Count observations with composable noise models. + +Ascertainment x delay convolution with pluggable noise (Poisson, Negative Binomial, etc.). +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.noise import CountNoise +from pyrenew.observation.types import ObservationSample + + +class _CountBase(BaseObservationProcess): + """ + Internal base for count observation processes. + + Implements ascertainment x delay convolution with pluggable noise model. + """ + + def __init__( + self, + ascertainment_rate_rv: RandomVariable, + delay_distribution_rv: RandomVariable, + noise: CountNoise, + ) -> None: + """ + Initialize count observation base. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR, IER). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model for count observations (Poisson, NegBin, etc.). + """ + super().__init__(temporal_pmf_rv=delay_distribution_rv) + self.ascertainment_rate_rv = ascertainment_rate_rv + self.noise = noise + + def validate(self) -> None: + """ + Validate observation parameters. + + Raises + ------ + ValueError + If delay PMF invalid, ascertainment rate outside [0,1], + or noise params invalid. 
+ """ + delay_pmf = self.temporal_pmf_rv() + self._validate_pmf(delay_pmf, "delay_distribution_rv") + + ascertainment_rate = self.ascertainment_rate_rv() + if jnp.any(ascertainment_rate < 0) or jnp.any(ascertainment_rate > 1): + raise ValueError( + "ascertainment_rate_rv must be in [0, 1], " + "got value(s) outside this range" + ) + + self.noise.validate() + + def lookback_days(self) -> int: + """ + Return delay PMF length. + + Returns + ------- + int + Length of delay distribution PMF. + """ + return len(self.temporal_pmf_rv()) + + def infection_resolution(self) -> str: + """ + Return required infection resolution. + + Returns + ------- + str + "aggregate" or "subpop". + """ + raise NotImplementedError("Subclasses must implement infection_resolution()") + + def _expected_signal( + self, + infections: ArrayLike, + ) -> ArrayLike: + """ + Compute expected counts via ascertainment x delay convolution. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days,) for aggregate + Shape: (n_days, n_subpops) for subpop-level + + Returns + ------- + ArrayLike + Expected counts with timeline alignment. + Same shape as input. + First len(delay_pmf)-1 days are NaN. + """ + delay_pmf = self.temporal_pmf_rv() + ascertainment_rate = self.ascertainment_rate_rv() + + is_1d = infections.ndim == 1 + if is_1d: + infections = infections[:, jnp.newaxis] + + def convolve_col(col): # numpydoc ignore=GL08 + return self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0] + + expected_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + + return expected_counts[:, 0] if is_1d else expected_counts + + +class Counts(_CountBase): + """ + Aggregated count observation. + + Maps aggregate infections to counts through ascertainment x delay + convolution with composable noise model. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR, IER). 
+ delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model (PoissonNoise, NegativeBinomialNoise, etc.). + + Notes + ----- + Output preserves input timeline. First len(delay_pmf)-1 days have NaN + expected counts, which are zeroed before sampling (sampled counts ~0). + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + >>> from pyrenew.observation import Counts, NegativeBinomialNoise + >>> import jax.numpy as jnp + >>> import numpyro + >>> + >>> delay_pmf = jnp.array([0.2, 0.5, 0.3]) + >>> counts_obs = Counts( + ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + ... delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + ... noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ... ) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... infections = jnp.ones(30) * 1000 + ... sampled_counts = counts_obs.sample(infections=infections, counts=None) + """ + + def infection_resolution(self) -> str: + """ + Return "aggregate" for aggregated observations. + + Returns + ------- + str + The string "aggregate". + """ + return "aggregate" + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"Counts(ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " + f"delay_distribution_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + + def sample( + self, + infections: ArrayLike, + counts: ArrayLike | None = None, + times: ArrayLike | None = None, + ) -> ObservationSample: + """ + Sample aggregated counts with dense or sparse observations. + + Validation is performed before JAX tracing at runtime, + prior to calling this method. + + Parameters + ---------- + infections : ArrayLike + Aggregate infections from the infection process. + Shape: (n_days,) + counts : ArrayLike | None + Observed counts. Dense: (n_days,), Sparse: (n_obs,), None: prior. + times : ArrayLike | None + Day indices for sparse observations. 
None for dense observations. Ignored when ``counts`` is None. + + Returns + ------- + ObservationSample + Named tuple with `observed` (sampled/conditioned counts) and + `expected` (expected counts before noise). + """ + expected_counts = self._expected_signal(infections) + self._deterministic("expected_counts", expected_counts) + expected_counts_safe = jnp.nan_to_num(expected_counts, nan=0.0) + + # Only use sparse indexing when conditioning on observations + if times is not None and counts is not None: + expected_obs = expected_counts_safe[times] + else: + expected_obs = expected_counts_safe + + observed = self.noise.sample( + name="counts", + expected=expected_obs, + obs=counts, + ) + + return ObservationSample(observed=observed, expected=expected_counts) + + +class CountsBySubpop(_CountBase): + """ + Subpopulation-level count observation. + + Maps subpopulation-level infections to counts through + ascertainment x delay convolution with composable noise model. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1]. + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model (PoissonNoise, NegativeBinomialNoise, etc.). + + Notes + ----- + Output preserves input timeline. First len(delay_pmf)-1 days are NaN. + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + >>> from pyrenew.observation import CountsBySubpop, PoissonNoise + >>> import jax.numpy as jnp + >>> import numpyro + >>> + >>> delay_pmf = jnp.array([0.3, 0.4, 0.3]) + >>> counts_obs = CountsBySubpop( + ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), + ... delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + ... noise=PoissonNoise(), + ... ) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops + ... times = jnp.array([10, 15, 10, 15]) + ... subpop_indices = jnp.array([0, 0, 1, 1]) + ... 
sampled = counts_obs.sample( + ... infections=infections, + ... subpop_indices=subpop_indices, + ... times=times, + ... counts=None, + ... ) + """ + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"CountsBySubpop(ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " + f"delay_distribution_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + + def infection_resolution(self) -> str: + """ + Return "subpop" for subpopulation-level observations. + + Returns + ------- + str + The string "subpop". + """ + return "subpop" + + def sample( + self, + infections: ArrayLike, + subpop_indices: ArrayLike, + times: ArrayLike, + counts: ArrayLike | None = None, + ) -> ObservationSample: + """ + Sample subpopulation-level counts with flexible indexing. + + Validation is performed before JAX tracing at runtime, + prior to calling this method. + + Parameters + ---------- + infections : ArrayLike + Subpopulation-level infections from the infection process. + Shape: (n_days, n_subpops) + subpop_indices : ArrayLike + Subpopulation index for each observation (0-indexed). + Shape: (n_obs,) + times : ArrayLike + Day index for each observation (0-indexed). + Shape: (n_obs,) + counts : ArrayLike | None + Observed counts (n_obs,), or None for prior sampling. + + Returns + ------- + ObservationSample + Named tuple with `observed` (sampled/conditioned counts) and + `expected` (expected counts before noise, shape: n_days x n_subpops). 
+ """ + # Compute expected counts for all subpops + expected_counts_all = self._expected_signal(infections) + + self._deterministic("expected_counts_by_subpop", expected_counts_all) + + # Replace NaN padding with 0 for distribution creation + expected_counts_safe = jnp.nan_to_num(expected_counts_all, nan=0.0) + expected_obs = expected_counts_safe[times, subpop_indices] + + observed = self.noise.sample( + name="counts_by_subpop", + expected=expected_obs, + obs=counts, + ) + + return ObservationSample(observed=observed, expected=expected_counts_all) diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py new file mode 100644 index 00000000..33e38980 --- /dev/null +++ b/pyrenew/observation/measurements.py @@ -0,0 +1,144 @@ +# numpydoc ignore=GL08 +""" +Continuous measurement observation processes. + +Abstract base for any population-level continuous measurements (wastewater, +air quality, serology, etc.) with signal-specific processing. +""" + +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.noise import MeasurementNoise +from pyrenew.observation.types import ObservationSample + + +class Measurements(BaseObservationProcess): + """ + Abstract base for continuous measurement observations. + + Subclasses implement signal-specific transformations from infections + to expected measurement values, then add measurement noise. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + Temporal distribution PMF (e.g., shedding kinetics for wastewater). + noise : MeasurementNoise + Noise model for continuous measurements + (e.g., HierarchicalNormalNoise). + + Notes + ----- + Subclasses must implement ``_expected_signal()`` according to their + specific signal processing (e.g., wastewater shedding kinetics, + dilution factors, etc.). 
+ + See Also + -------- + pyrenew.observation.noise.HierarchicalNormalNoise : + Suitable noise model for sensor-level measurements + pyrenew.observation.base.BaseObservationProcess : + Parent class with common observation utilities + """ + + def __init__( + self, + temporal_pmf_rv: RandomVariable, + noise: MeasurementNoise, + ) -> None: + """ + Initialize measurement observation base. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + Temporal distribution PMF (e.g., shedding kinetics). + noise : MeasurementNoise + Noise model (e.g., HierarchicalNormalNoise with sensor effects). + """ + super().__init__(temporal_pmf_rv=temporal_pmf_rv) + self.noise = noise + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"{self.__class__.__name__}(" + f"temporal_pmf_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + + def infection_resolution(self) -> str: + """ + Return "subpop" for measurement observations. + + Measurement observations require subpopulation-level infections + because each measurement corresponds to a specific catchment area. + + Returns + ------- + str + ``"subpop"`` + """ + return "subpop" + + def sample( + self, + infections: ArrayLike, + subpop_indices: ArrayLike, + sensor_indices: ArrayLike, + times: ArrayLike, + concentrations: ArrayLike | None, + n_sensors: int, + ) -> ObservationSample: + """ + Sample measurements from observed sensors. + + This method does not perform runtime validation of index values + (times, subpop_indices, sensor_indices). Validate observation data + before sampling. + + Transforms infections to expected values via signal-specific processing + (``_expected_signal``), then applies noise model. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days, n_subpops) + subpop_indices : ArrayLike + Subpopulation index for each observation (0-indexed). 
+ Shape: (n_obs,) + sensor_indices : ArrayLike + Sensor index for each observation (0-indexed). + Shape: (n_obs,) + times : ArrayLike + Day index for each observation (0-indexed). + Shape: (n_obs,) + concentrations : ArrayLike | None + Observed measurements (n_obs,), or None for prior sampling. + n_sensors : int + Total number of measurement sensors. + + Returns + ------- + ObservationSample + Named tuple with `observed` (sampled/conditioned measurements) and + `expected` (expected values before noise, shape: n_days x n_subpops). + """ + expected_values = self._expected_signal(infections) + + self._deterministic("expected_log_conc", expected_values) + + expected_obs = expected_values[times, subpop_indices] + + observed = self.noise.sample( + name="concentrations", + expected=expected_obs, + obs=concentrations, + sensor_indices=sensor_indices, + n_sensors=n_sensors, + ) + + return ObservationSample(observed=observed, expected=expected_values) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py new file mode 100644 index 00000000..647b1781 --- /dev/null +++ b/pyrenew/observation/noise.py @@ -0,0 +1,372 @@ +# numpydoc ignore=GL08 +""" +Noise models for observation processes. + +Provides composable noise strategies for count and measurement observations, +separating the noise distribution from the observation structure. + +Count Noise +----------- +- ``PoissonNoise``: Equidispersed counts (variance = mean). No parameters. +- ``NegativeBinomialNoise``: Overdispersed counts (variance > mean). + Takes ``concentration_rv`` (higher = less overdispersion). + +Measurement Noise +----------------- +- ``HierarchicalNormalNoise``: Normal noise with hierarchical sensor effects. + Takes ``sensor_mode_prior_rv`` and ``sensor_sd_prior_rv`` for sensor-level + bias and variability. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable + +_EPSILON = 1e-10 + + +class CountNoise(ABC): + """ + Abstract base for count observation noise models. + + Defines how discrete count observations are distributed around expected values. + """ + + @abstractmethod + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample count observations given expected counts. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected count values (non-negative). + obs : ArrayLike | None + Observed counts for conditioning, or None for prior sampling. + + Returns + ------- + ArrayLike + Sampled or conditioned counts, same shape as expected. + """ + pass # pragma: no cover + + @abstractmethod + def validate(self) -> None: + """ + Validate noise model parameters. + + Raises + ------ + ValueError + If parameters are invalid. + """ + pass # pragma: no cover + + +class PoissonNoise(CountNoise): + """ + Poisson noise for equidispersed counts (variance = mean). + """ + + def __init__(self) -> None: + """Initialize Poisson noise (no parameters).""" + pass + + def __repr__(self) -> str: + """Return string representation.""" + return "PoissonNoise()" + + def validate(self) -> None: + """Validate Poisson noise (always valid).""" + pass + + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample from Poisson distribution. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected count values. + obs : ArrayLike | None + Observed counts for conditioning. + + Returns + ------- + ArrayLike + Poisson-distributed counts. 
+ """ + return numpyro.sample( + name, + dist.Poisson(rate=expected + _EPSILON), + obs=obs, + ) + + +class NegativeBinomialNoise(CountNoise): + """ + Negative Binomial noise for overdispersed counts (variance > mean). + + Uses NB2 parameterization. Higher concentration reduces overdispersion. + + Parameters + ---------- + concentration_rv : RandomVariable + Concentration parameter (must be > 0). + Higher values reduce overdispersion. + + Notes + ----- + The NB2 parameterization has variance = mean + mean^2 / concentration. + As concentration -> infinity, this approaches Poisson. + """ + + def __init__(self, concentration_rv: RandomVariable) -> None: + """ + Initialize Negative Binomial noise. + + Parameters + ---------- + concentration_rv : RandomVariable + Concentration parameter (must be > 0). + Higher values reduce overdispersion. + """ + self.concentration_rv = concentration_rv + + def __repr__(self) -> str: + """Return string representation.""" + return f"NegativeBinomialNoise(concentration_rv={self.concentration_rv!r})" + + def validate(self) -> None: + """ + Validate concentration is positive. + + Raises + ------ + ValueError + If concentration <= 0. + """ + concentration = self.concentration_rv() + if jnp.any(concentration <= 0): + raise ValueError( + f"NegativeBinomialNoise: concentration must be positive, " + f"got {float(concentration)}" + ) + + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample from Negative Binomial distribution. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected count values. + obs : ArrayLike | None + Observed counts for conditioning. + + Returns + ------- + ArrayLike + Negative Binomial-distributed counts. 
+ """ + concentration = self.concentration_rv() + return numpyro.sample( + name, + dist.NegativeBinomial2( + mean=expected + _EPSILON, + concentration=concentration, + ), + obs=obs, + ) + + +class MeasurementNoise(ABC): + """ + Abstract base for continuous measurement noise models. + + Defines how continuous observations are distributed around expected values. + """ + + @abstractmethod + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + **kwargs, + ) -> ArrayLike: + """ + Sample continuous observations given expected values. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected measurement values. + obs : ArrayLike | None + Observed measurements for conditioning, or None for prior sampling. + **kwargs + Additional context (e.g., sensor indices). + + Returns + ------- + ArrayLike + Sampled or conditioned measurements, same shape as expected. + """ + pass # pragma: no cover + + @abstractmethod + def validate(self) -> None: + """ + Validate noise model parameters. + + Raises + ------ + ValueError + If parameters are invalid. + """ + pass # pragma: no cover + + +class HierarchicalNormalNoise(MeasurementNoise): + """ + Normal noise with hierarchical sensor-level effects. + + Observation model: ``obs ~ Normal(expected + sensor_mode, sensor_sd)`` + where sensor_mode and sensor_sd are hierarchically modeled. + + Parameters + ---------- + sensor_mode_prior_rv : RandomVariable + Hierarchical prior for sensor-level modes (log-scale biases). + Must support ``sample(n_groups=...)`` interface. + sensor_sd_prior_rv : RandomVariable + Hierarchical prior for sensor-level SDs (must be > 0). + Must support ``sample(n_groups=...)`` interface. + + Notes + ----- + Expects data already on log scale for wastewater applications. 
+ + See Also + -------- + pyrenew.randomvariable.HierarchicalNormalPrior : + Suitable prior for sensor_mode_prior_rv + pyrenew.randomvariable.GammaGroupSdPrior : + Suitable prior for sensor_sd_prior_rv + """ + + def __init__( + self, + sensor_mode_prior_rv: RandomVariable, + sensor_sd_prior_rv: RandomVariable, + ) -> None: + """ + Initialize hierarchical Normal noise. + + Parameters + ---------- + sensor_mode_prior_rv : RandomVariable + Hierarchical prior for sensor-level modes (log-scale biases). + Must support ``sample(n_groups=...)`` interface. + sensor_sd_prior_rv : RandomVariable + Hierarchical prior for sensor-level SDs (must be > 0). + Must support ``sample(n_groups=...)`` interface. + """ + self.sensor_mode_prior_rv = sensor_mode_prior_rv + self.sensor_sd_prior_rv = sensor_sd_prior_rv + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"HierarchicalNormalNoise(" + f"sensor_mode_prior_rv={self.sensor_mode_prior_rv!r}, " + f"sensor_sd_prior_rv={self.sensor_sd_prior_rv!r})" + ) + + def validate(self) -> None: + """ + Validate noise parameters. + + Notes + ----- + Full validation requires n_groups, which is only available during sample(). + """ + pass + + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + *, + sensor_indices: ArrayLike, + n_sensors: int, + ) -> ArrayLike: + """ + Sample from Normal distribution with sensor-level hierarchical effects. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected log-scale measurement values. + Shape: (n_obs,) + obs : ArrayLike | None + Observed log-scale measurements for conditioning. + Shape: (n_obs,) + sensor_indices : ArrayLike + Sensor index for each observation (0-indexed). + Shape: (n_obs,) + n_sensors : int + Total number of sensors. + + Returns + ------- + ArrayLike + Normal distributed measurements with hierarchical sensor effects. 
+ Shape: (n_obs,) + + Raises + ------ + ValueError + NOTE(review): not raised by this method itself; no positivity check on sensor_sd is done here (confirm where it is enforced). + """ + sensor_mode = self.sensor_mode_prior_rv.sample(n_groups=n_sensors) + sensor_sd = self.sensor_sd_prior_rv.sample(n_groups=n_sensors) + + loc = expected + sensor_mode[sensor_indices] + scale = sensor_sd[sensor_indices] + + return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs) diff --git a/pyrenew/observation/types.py b/pyrenew/observation/types.py new file mode 100644 index 00000000..b494e2e6 --- /dev/null +++ b/pyrenew/observation/types.py @@ -0,0 +1,28 @@ +# numpydoc ignore=GL08 +""" +Return types for observation processes. + +Named tuples providing structured access to observation process outputs. +""" + +from typing import NamedTuple + +from jax.typing import ArrayLike + + +class ObservationSample(NamedTuple): + """ + Return type for observation process sample() methods. + + Attributes + ---------- + observed : ArrayLike + Sampled or conditioned observations. Shape depends on the + observation process and indexing. + expected : ArrayLike + Expected values before noise is applied. Useful for + diagnostics and posterior predictive checks.
+ """ + + observed: ArrayLike + expected: ArrayLike diff --git a/pyrenew/randomvariable/__init__.py b/pyrenew/randomvariable/__init__.py index c599d101..75594c58 100644 --- a/pyrenew/randomvariable/__init__.py +++ b/pyrenew/randomvariable/__init__.py @@ -5,6 +5,11 @@ DynamicDistributionalVariable, StaticDistributionalVariable, ) +from pyrenew.randomvariable.hierarchical import ( + GammaGroupSdPrior, + HierarchicalNormalPrior, + StudentTGroupModePrior, +) from pyrenew.randomvariable.transformedvariable import TransformedVariable __all__ = [ @@ -12,4 +17,7 @@ "StaticDistributionalVariable", "DynamicDistributionalVariable", "TransformedVariable", + "HierarchicalNormalPrior", + "GammaGroupSdPrior", + "StudentTGroupModePrior", ] diff --git a/pyrenew/randomvariable/hierarchical.py b/pyrenew/randomvariable/hierarchical.py new file mode 100644 index 00000000..e97126d4 --- /dev/null +++ b/pyrenew/randomvariable/hierarchical.py @@ -0,0 +1,336 @@ +# numpydoc ignore=GL08 +""" +Hierarchical prior distributions for group-level random effects. + +These classes provide random variables that sample from hierarchical +distributions with a `sample(n_groups=...)` interface, enabling +dynamic group sizes at sample time with proper numpyro plate contexts. +""" + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist + +from pyrenew.metaclass import RandomVariable + + +class HierarchicalNormalPrior(RandomVariable): + """ + Zero-centered Normal prior for group-level effects. + + Samples n_groups values from Normal(0, sd) within a numpyro plate context. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_rv : RandomVariable + RandomVariable returning the standard deviation. + + Notes + ----- + This class is designed for hierarchical models where group effects + are assumed to be drawn from a common distribution centered at zero. + The number of groups is specified at sample time, allowing dynamic + group sizes. 
+ + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable + >>> from pyrenew.randomvariable import HierarchicalNormalPrior + >>> import numpyro + >>> + >>> sd_rv = DeterministicVariable("sd", 0.5) + >>> prior = HierarchicalNormalPrior("site_effects", sd_rv) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... effects = prior.sample(n_groups=5) + >>> effects.shape + (5,) + """ + + def __init__( + self, + name: str, + sd_rv: RandomVariable, + ) -> None: + """ + Default constructor for HierarchicalNormalPrior. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_rv : RandomVariable + RandomVariable returning the standard deviation. + + Returns + ------- + None + """ + if not isinstance(sd_rv, RandomVariable): + raise TypeError( + f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." + ) + + self.name = name + self.sd_rv = sd_rv + + def validate(self): + """Validate the random variable (no-op for this class).""" + pass + + def sample(self, n_groups: int, **kwargs): + """ + Sample group-level effects. + + Parameters + ---------- + n_groups : int + Number of groups. + **kwargs + Additional keyword arguments (unused). + + Returns + ------- + ArrayLike + Array of shape (n_groups,) containing sampled effects. + """ + sd = self.sd_rv() + + with numpyro.plate(f"n_{self.name}", n_groups): + effects = numpyro.sample( + self.name, + dist.Normal(0.0, sd), + ) + return effects + + +class GammaGroupSdPrior(RandomVariable): + """ + Gamma prior for group-level standard deviations, bounded away from zero. + + Samples n_groups values from Gamma(concentration, rate), floored at sd_min + (max(raw, sd_min), not a shift by sd_min), within a numpyro plate context. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_mean_rv : RandomVariable + RandomVariable returning the mean of the Gamma distribution.
+ sd_concentration_rv : RandomVariable + RandomVariable returning the concentration (shape) parameter of Gamma. + sd_min : float, default=0.05 + Minimum SD value (lower bound). + + Notes + ----- + This class parameterizes Gamma by mean and concentration rather than + shape and rate, which is often more interpretable. The rate is computed as + concentration / mean. + + The sd_min floor prevents numerical issues when standard deviations + approach zero. + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable + >>> from pyrenew.randomvariable import GammaGroupSdPrior + >>> import numpyro + >>> + >>> mean_rv = DeterministicVariable("sd_mean", 0.3) + >>> conc_rv = DeterministicVariable("sd_conc", 4.0) + >>> prior = GammaGroupSdPrior("site_sd", mean_rv, conc_rv, sd_min=0.05) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... sds = prior.sample(n_groups=5) + >>> all(sds >= 0.05) + True + """ + + def __init__( + self, + name: str, + sd_mean_rv: RandomVariable, + sd_concentration_rv: RandomVariable, + sd_min: float = 0.05, + ) -> None: + """ + Default constructor for GammaGroupSdPrior. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_mean_rv : RandomVariable + RandomVariable returning the mean of the Gamma distribution. + sd_concentration_rv : RandomVariable + RandomVariable returning the concentration (shape) parameter. + sd_min : float, default=0.05 + Minimum SD value (lower bound). + + Returns + ------- + None + """ + if not isinstance(sd_mean_rv, RandomVariable): + raise TypeError( + f"sd_mean_rv must be a RandomVariable, got {type(sd_mean_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." + ) + if not isinstance(sd_concentration_rv, RandomVariable): + raise TypeError( + f"sd_concentration_rv must be a RandomVariable, got {type(sd_concentration_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." 
+ ) + if sd_min < 0: + raise ValueError(f"sd_min must be non-negative, got {sd_min}") + + self.name = name + self.sd_mean_rv = sd_mean_rv + self.sd_concentration_rv = sd_concentration_rv + self.sd_min = sd_min + + def validate(self): + """Validate the random variable (no-op for this class).""" + pass + + def sample(self, n_groups: int, **kwargs): + """ + Sample group-level standard deviations. + + Parameters + ---------- + n_groups : int + Number of groups. + **kwargs + Additional keyword arguments (unused). + + Returns + ------- + ArrayLike + Array of shape (n_groups,) with values >= sd_min. + """ + sd_mean = self.sd_mean_rv() + concentration = self.sd_concentration_rv() + rate = concentration / sd_mean + + with numpyro.plate(f"n_{self.name}", n_groups): + raw_sd = numpyro.sample( + f"{self.name}_raw", + dist.Gamma(concentration, rate), + ) + + group_sd = numpyro.deterministic( + self.name, + jnp.maximum(raw_sd, self.sd_min), + ) + return group_sd + + +class StudentTGroupModePrior(RandomVariable): + """ + Zero-centered Student-t prior for group-level modes (robust alternative to Normal). + + Samples n_groups values from StudentT(df, 0, sd) within a numpyro plate context. + This is useful when group effects may have heavier tails than a Normal distribution. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_rv : RandomVariable + RandomVariable returning the scale parameter. + df_rv : RandomVariable + RandomVariable returning the degrees of freedom. + + Notes + ----- + The Student-t distribution approaches the Normal distribution as df -> infinity. + Lower df values give heavier tails, making the prior more robust to outliers. + Common choices include df=3 (heavy tails) or df=7 (moderate tails). 
+ + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable + >>> from pyrenew.randomvariable import StudentTGroupModePrior + >>> import numpyro + >>> + >>> sd_rv = DeterministicVariable("scale", 0.5) + >>> df_rv = DeterministicVariable("df", 4.0) + >>> prior = StudentTGroupModePrior("site_modes", sd_rv, df_rv) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... modes = prior.sample(n_groups=5) + >>> modes.shape + (5,) + """ + + def __init__( + self, + name: str, + sd_rv: RandomVariable, + df_rv: RandomVariable, + ) -> None: + """ + Default constructor for StudentTGroupModePrior. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_rv : RandomVariable + RandomVariable returning the scale parameter. + df_rv : RandomVariable + RandomVariable returning the degrees of freedom. + + Returns + ------- + None + """ + if not isinstance(sd_rv, RandomVariable): + raise TypeError( + f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." + ) + if not isinstance(df_rv, RandomVariable): + raise TypeError( + f"df_rv must be a RandomVariable, got {type(df_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." + ) + + self.name = name + self.sd_rv = sd_rv + self.df_rv = df_rv + + def validate(self): + """Validate the random variable (no-op for this class).""" + pass + + def sample(self, n_groups: int, **kwargs): + """ + Sample group-level modes. + + Parameters + ---------- + n_groups : int + Number of groups. + **kwargs + Additional keyword arguments (unused). + + Returns + ------- + ArrayLike + Array of shape (n_groups,) containing sampled modes. 
+ """ + sd = self.sd_rv() + df = self.df_rv() + + with numpyro.plate(f"n_{self.name}", n_groups): + effects = numpyro.sample( + self.name, + dist.StudentT(df=df, loc=0.0, scale=sd), + ) + return effects diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..6f55681a --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,406 @@ +""" +Shared pytest fixtures for PyRenew tests. + +This module provides reusable fixtures for creating observation processes, +test data, and common configurations used across multiple test files. +""" + +import jax.numpy as jnp +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import Counts, NegativeBinomialNoise +from pyrenew.randomvariable import GammaGroupSdPrior, HierarchicalNormalPrior + +# ============================================================================= +# PMF Fixtures +# ============================================================================= + + +@pytest.fixture +def simple_delay_pmf(): + """ + Simple 1-day delay PMF (no delay). + + Returns + ------- + jnp.ndarray + A single-element PMF array representing no delay. + """ + return jnp.array([1.0]) + + +@pytest.fixture +def short_delay_pmf(): + """ + Short 2-day delay PMF. + + Returns + ------- + jnp.ndarray + A 2-element PMF array. + """ + return jnp.array([0.5, 0.5]) + + +@pytest.fixture +def medium_delay_pmf(): + """ + Medium 4-day delay PMF. + + Returns + ------- + jnp.ndarray + A 4-element PMF array. + """ + return jnp.array([0.1, 0.3, 0.4, 0.2]) + + +@pytest.fixture +def realistic_delay_pmf(): + """ + Realistic 10-day delay PMF (shifted gamma-like). + + Returns + ------- + jnp.ndarray + A 10-element PMF array with gamma-like shape. + """ + return jnp.array([0.01, 0.05, 0.10, 0.15, 0.20, 0.20, 0.15, 0.08, 0.04, 0.02]) + + +@pytest.fixture +def long_delay_pmf(): + """ + Long 10-day delay PMF for edge case testing. 
+ + Returns + ------- + jnp.ndarray + A 10-element PMF array. + """ + return jnp.array([0.05, 0.1, 0.15, 0.2, 0.2, 0.15, 0.1, 0.03, 0.01, 0.01]) + + +@pytest.fixture +def simple_shedding_pmf(): + """ + Simple 1-day shedding PMF (no delay). + + Returns + ------- + jnp.ndarray + A single-element PMF array representing no shedding delay. + """ + return jnp.array([1.0]) + + +@pytest.fixture +def short_shedding_pmf(): + """ + Short 3-day shedding PMF. + + Returns + ------- + jnp.ndarray + A 3-element PMF array. + """ + return jnp.array([0.3, 0.4, 0.3]) + + +@pytest.fixture +def medium_shedding_pmf(): + """ + Medium 5-day shedding PMF. + + Returns + ------- + jnp.ndarray + A 5-element PMF array. + """ + return jnp.array([0.1, 0.3, 0.3, 0.2, 0.1]) + + +# ============================================================================= +# Hierarchical Prior Fixtures +# ============================================================================= + + +@pytest.fixture +def sensor_mode_prior(): + """ + Standard hierarchical normal prior for sensor modes. + + Returns + ------- + HierarchicalNormalPrior + A hierarchical normal prior with standard deviation 0.5. + """ + return HierarchicalNormalPrior( + name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + + +@pytest.fixture +def sensor_mode_prior_tight(): + """ + Tight hierarchical normal prior for deterministic-like behavior. + + Returns + ------- + HierarchicalNormalPrior + A hierarchical normal prior with small standard deviation 0.01. + """ + return HierarchicalNormalPrior( + name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd_tight", 0.01) + ) + + +@pytest.fixture +def sensor_sd_prior(): + """ + Standard gamma prior for sensor standard deviations. + + Returns + ------- + GammaGroupSdPrior + A gamma prior for group standard deviations. 
+ """ + return GammaGroupSdPrior( + name="ww_sensor_sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), + sd_min=0.10, + ) + + +@pytest.fixture +def sensor_sd_prior_tight(): + """ + Tight gamma prior for deterministic-like behavior. + + Returns + ------- + GammaGroupSdPrior + A gamma prior with small mean for tight behavior. + """ + return GammaGroupSdPrior( + name="ww_sensor_sd", + sd_mean_rv=DeterministicVariable("sd_mean_tight", 0.01), + sd_concentration_rv=DeterministicVariable("sd_concentration_tight", 4.0), + sd_min=0.005, + ) + + +# ============================================================================= +# Counts Process Fixtures +# ============================================================================= + + +@pytest.fixture +def counts_process(simple_delay_pmf): + """ + Standard Counts observation process with simple delay. + + Returns + ------- + Counts + A Counts observation process with no delay. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + + +@pytest.fixture +def counts_process_medium_delay(medium_delay_pmf): + """ + Counts observation process with medium delay. + + Returns + ------- + Counts + A Counts observation process with 4-day delay. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", medium_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 50.0)), + ) + + +@pytest.fixture +def counts_process_realistic(realistic_delay_pmf): + """ + Counts observation process with realistic delay and ascertainment. + + Returns + ------- + Counts + A Counts observation process with realistic parameters. 
+ """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.005), + delay_distribution_rv=DeterministicPMF("delay", realistic_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 100.0)), + ) + + +class CountsProcessFactory: + """Factory for creating Counts processes with custom parameters.""" + + @staticmethod + def create( + delay_pmf=None, + ascertainment_rate=0.01, + concentration=10.0, + ): + """ + Create a Counts process with specified parameters. + + Returns + ------- + Counts + A Counts observation process with the specified parameters. + """ + if delay_pmf is None: + delay_pmf = jnp.array([1.0]) + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", ascertainment_rate), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", concentration)), + ) + + +@pytest.fixture +def counts_factory(): + """ + Factory fixture for creating custom Counts processes. + + Returns + ------- + CountsProcessFactory + A factory for creating Counts processes. + """ + return CountsProcessFactory() + + +# ============================================================================= +# Infection Fixtures +# ============================================================================= + + +@pytest.fixture +def constant_infections(): + """ + Constant infections array (30 days, 100 infections/day). + + Returns + ------- + jnp.ndarray + A 1D array of shape (30,) with constant value 100. + """ + return jnp.ones(30) * 100 + + +@pytest.fixture +def constant_infections_2d(): + """ + Constant infections array for 2 subpopulations. + + Returns + ------- + jnp.ndarray + A 2D array of shape (30, 2) with constant value 100. + """ + return jnp.ones((30, 2)) * 100 + + +def make_infections(n_days, n_subpops=None, value=100.0): + """ + Create infection arrays for testing. 
+ + Parameters + ---------- + n_days : int + Number of days + n_subpops : int, optional + Number of subpopulations (None for 1D array) + value : float + Constant infection value + + Returns + ------- + jnp.ndarray + Infections array + """ + if n_subpops is None: + return jnp.ones(n_days) * value + return jnp.ones((n_days, n_subpops)) * value + + +def make_spike_infections(n_days, spike_day, spike_value=1000.0, n_subpops=None): + """ + Create spike infection arrays for testing. + + Parameters + ---------- + n_days : int + Number of days + spike_day : int + Day of the spike + spike_value : float + Value at spike + n_subpops : int, optional + Number of subpopulations + + Returns + ------- + jnp.ndarray + Infections array with spike + """ + if n_subpops is None: + infections = jnp.zeros(n_days) + return infections.at[spike_day].set(spike_value) + infections = jnp.zeros((n_days, n_subpops)) + return infections.at[spike_day, :].set(spike_value) + + +def create_mock_infections( + n_days: int, + peak_day: int = 10, + peak_value: float = 1000.0, + shape: str = "spike", +) -> jnp.ndarray: + """ + Create mock infection time series for testing. 
+ + Parameters + ---------- + n_days : int + Number of days + peak_day : int + Day of peak infections + peak_value : float + Peak infection value + shape : str + Shape of the curve: "spike", "constant", or "decay" + + Returns + ------- + jnp.ndarray + Array of infections of shape (n_days,) + """ + if shape == "spike": + infections = jnp.zeros(n_days) + infections = infections.at[peak_day].set(peak_value) + elif shape == "constant": + infections = jnp.ones(n_days) * peak_value + elif shape == "decay": + infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + else: + raise ValueError(f"Unknown shape: {shape}") + + return infections diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py new file mode 100644 index 00000000..ddb4d684 --- /dev/null +++ b/test/test_observation_counts.py @@ -0,0 +1,543 @@ +""" +Unit tests for Counts (aggregated count observations). + +These tests validate the count observation process implementation. +""" + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import ( + Counts, + CountsBySubpop, + NegativeBinomialNoise, + PoissonNoise, +) +from pyrenew.observation.count_observations import _CountBase +from pyrenew.randomvariable import DistributionalVariable + + +def create_mock_infections( + n_days: int, + peak_day: int = 10, + peak_value: float = 1000.0, + shape: str = "spike", +) -> jnp.ndarray: + """ + Create mock infection time series for testing. 
+ + Parameters + ---------- + n_days : int + Number of days + peak_day : int + Day of peak infections + peak_value : float + Peak infection value + shape : str + Shape of the curve: "spike", "constant", or "decay" + + Returns + ------- + jnp.ndarray + Array of infections of shape (n_days,) + """ + if shape == "spike": + infections = jnp.zeros(n_days) + infections = infections.at[peak_day].set(peak_value) + elif shape == "constant": + infections = jnp.ones(n_days) * peak_value + elif shape == "decay": + infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + else: + raise ValueError(f"Unknown shape: {shape}") + + return infections + + +class TestCountsBasics: + """Test basic functionality of aggregated count observation process.""" + + def test_sample_returns_correct_shape(self, counts_process): + """Test that sample returns correct shape.""" + infections = jnp.ones(30) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + counts=None, + ) + + assert result.observed.shape[0] > 0 + assert result.observed.ndim == 1 + assert result.expected.shape == infections.shape + + def test_delay_convolution(self, counts_factory, short_delay_pmf): + """Test that delay is properly applied.""" + process = counts_factory.create(delay_pmf=short_delay_pmf) + + infections = jnp.zeros(30) + infections = infections.at[10].set(1000) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + counts=None, + ) + + # Timeline alignment: output length equals input length + assert result.observed.shape[0] == len(infections) + # First len(delay_pmf)-1 days are NaN (appear as -1 after NegativeBinomial sampling) + assert jnp.all(result.observed[1:] >= 0) + assert jnp.sum(result.observed[result.observed >= 0]) > 0 + + def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): + """Test that ascertainment rate properly scales counts.""" + infections = jnp.ones(20) * 100 + + results = 
[] + for rate_value in [0.01, 0.02, 0.05]: + process = counts_factory.create( + delay_pmf=simple_delay_pmf, + ascertainment_rate=rate_value, + ) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + counts=None, + ) + results.append(jnp.mean(result.observed)) + + # Higher ascertainment rate should lead to more counts + assert results[1] > results[0] + assert results[2] > results[1] + + def test_negative_binomial_observation(self, counts_factory, simple_delay_pmf): + """Test that negative binomial observation is used.""" + process = counts_factory.create( + delay_pmf=simple_delay_pmf, + concentration=5.0, + ) + + infections = jnp.ones(20) * 100 + + samples = [] + for seed in range(5): + with numpyro.handlers.seed(rng_seed=seed): + result = process.sample( + infections=infections, + counts=None, + ) + samples.append(jnp.sum(result.observed)) + + # Should have some variability due to negative binomial sampling + assert jnp.std(jnp.array(samples)) > 0 + + +class TestCountsWithPriors: + """Test aggregated count observation with uncertain parameters.""" + + def test_with_stochastic_ascertainment(self, short_shedding_pmf): + """Test with uncertain ascertainment rate parameter.""" + delay = DeterministicPMF("delay", jnp.array([0.2, 0.5, 0.3])) + ascertainment = DistributionalVariable("ihr", dist.Beta(2, 100)) + concentration = DeterministicVariable("conc", 10.0) + + process = Counts( + ascertainment_rate_rv=ascertainment, + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(concentration), + ) + + infections = jnp.ones(20) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + counts=None, + ) + + assert result.observed.shape[0] > 0 + # Skip NaN padding + valid_counts = result.observed[2:] + assert jnp.all(valid_counts >= 0) + + def test_with_stochastic_concentration(self, simple_delay_pmf): + """Test with uncertain concentration parameter.""" + delay = 
DeterministicPMF("delay", simple_delay_pmf) + ascertainment = DeterministicVariable("ihr", 0.01) + concentration = DistributionalVariable("conc", dist.HalfNormal(10.0)) + + process = Counts( + ascertainment_rate_rv=ascertainment, + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(concentration), + ) + + infections = jnp.ones(20) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + counts=None, + ) + + assert result.observed.shape[0] > 0 + assert jnp.all(result.observed >= 0) + + +class TestCountsEdgeCases: + """Test edge cases and error handling.""" + + def test_zero_infections(self, counts_process): + """Test with zero infections.""" + infections = jnp.zeros(20) + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + counts=None, + ) + + assert result.observed.shape[0] > 0 + + def test_small_infections(self, counts_process): + """Test with small infection values.""" + infections = jnp.ones(20) * 10 + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + counts=None, + ) + + assert result.observed.shape[0] > 0 + assert jnp.all(result.observed >= 0) + + def test_long_delay_distribution(self, counts_factory, long_delay_pmf): + """Test with longer delay distribution.""" + process = counts_factory.create(delay_pmf=long_delay_pmf) + + infections = create_mock_infections(40, peak_day=20, shape="spike") + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + counts=None, + ) + + # Timeline alignment maintained + assert result.observed.shape[0] == infections.shape[0] + # Skip NaN padding: 10-day delay -> first 9 days are NaN + valid_counts = result.observed[9:] + assert jnp.sum(valid_counts) > 0 + + +class TestCountsSparseObservations: + """Test sparse observation support.""" + + def test_sparse_observations(self, counts_process): + """Test with sparse (irregular) 
observations.""" + n_days = 30 + infections = jnp.ones(n_days) * 100 + + # Sparse observations: only days 5, 10, 15, 20 + times = jnp.array([5, 10, 15, 20]) + counts_data = jnp.array([10, 12, 8, 15]) + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + counts=counts_data, + times=times, + ) + + assert result.observed.shape == times.shape + assert jnp.allclose(result.observed, counts_data) + + def test_sparse_vs_dense_sampling(self, counts_process): + """Test that sparse sampling gives different output shape than dense.""" + n_days = 30 + infections = jnp.ones(n_days) * 100 + + # Dense: prior sampling (counts=None, no times) + with numpyro.handlers.seed(rng_seed=42): + dense_result = counts_process.sample( + infections=infections, + counts=None, + ) + + # Sparse with observed data: only some days + times = jnp.array([5, 10, 15, 20]) + sparse_obs_data = jnp.array([10, 12, 8, 15]) + with numpyro.handlers.seed(rng_seed=42): + sparse_result = counts_process.sample( + infections=infections, + counts=sparse_obs_data, + times=times, + ) + + # Dense prior produces full length output + assert dense_result.observed.shape == (n_days,) + + # Sparse observations produce output matching times shape + assert sparse_result.observed.shape == times.shape + assert jnp.allclose(sparse_result.observed, sparse_obs_data) + + def test_prior_sampling_ignores_times(self, counts_process): + """Test that times parameter is ignored when counts=None (prior sampling).""" + n_days = 30 + infections = jnp.ones(n_days) * 100 + times = jnp.array([5, 10, 15, 20]) + + # When counts=None, times is ignored - output is dense + with numpyro.handlers.seed(rng_seed=42): + result_with_times = counts_process.sample( + infections=infections, + counts=None, + times=times, + ) + + with numpyro.handlers.seed(rng_seed=42): + result_without_times = counts_process.sample( + infections=infections, + counts=None, + ) + + # Both should produce dense output of shape 
(n_days,) + assert result_with_times.observed.shape == (n_days,) + assert result_without_times.observed.shape == (n_days,) + # With same seed, outputs should be identical + assert jnp.allclose(result_with_times.observed, result_without_times.observed) + + +class TestCountsBySubpop: + """Test CountsBySubpop for subpopulation-level observations.""" + + def test_sample_returns_correct_shape(self): + """Test that CountsBySubpop sample returns correct shape.""" + delay_pmf = jnp.array([0.3, 0.4, 0.3]) + process = CountsBySubpop( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + ) + + infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops + times = jnp.array([10, 15, 10, 15]) + subpop_indices = jnp.array([0, 0, 1, 1]) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + subpop_indices=subpop_indices, + times=times, + counts=None, + ) + + assert result.observed.shape == times.shape + assert result.expected.shape == infections.shape + + def test_infection_resolution(self): + """Test that CountsBySubpop returns 'subpop' resolution.""" + delay_pmf = jnp.array([1.0]) + process = CountsBySubpop( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + ) + + assert process.infection_resolution() == "subpop" + + +class TestPoissonNoise: + """Test PoissonNoise model.""" + + def test_poisson_counts(self, simple_delay_pmf): + """Test Counts with Poisson noise.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + + infections = jnp.ones(20) * 1000 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + counts=None, + ) + + assert result.observed.shape[0] == 20 + assert 
jnp.all(result.observed >= 0) + + +class TestCountBaseInternalMethods: + """Test internal _CountBase methods for coverage.""" + + def test_count_base_infection_resolution_raises(self, simple_delay_pmf): + """Test that _CountBase.infection_resolution() raises NotImplementedError.""" + + # Create a subclass that doesn't override infection_resolution + class IncompleteCountProcess(_CountBase): + """Incomplete count process for testing.""" + + def sample(self, **kwargs): + """Sample method stub.""" + pass + + process = IncompleteCountProcess( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises( + NotImplementedError, match="Subclasses must implement infection_resolution" + ): + process.infection_resolution() + + +class TestValidationMethods: + """Test validation methods for coverage.""" + + def test_validate_calls_all_validations(self, simple_delay_pmf): + """Test that validate() calls all necessary validations.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + # Should not raise + process.validate() + + def test_validate_invalid_ascertainment_rate_negative(self, simple_delay_pmf): + """Test that validate raises for negative ascertainment rate.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", -0.1), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): + process.validate() + + def test_validate_invalid_ascertainment_rate_greater_than_one( + self, simple_delay_pmf + ): + """Test that validate raises for ascertainment rate > 1.""" + process = 
Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 1.5), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): + process.validate() + + def test_lookback_days(self, simple_delay_pmf, long_delay_pmf): + """Test lookback_days returns PMF length.""" + process_short = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process_short.lookback_days() == 1 + + process_long = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", long_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process_long.lookback_days() == 10 + + def test_infection_resolution_counts(self, simple_delay_pmf): + """Test that Counts returns 'aggregate' resolution.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process.infection_resolution() == "aggregate" + + +class TestNoiseValidation: + """Test noise model validation methods.""" + + def test_poisson_noise_validate(self): + """Test PoissonNoise validate method.""" + noise = PoissonNoise() + # Should not raise - Poisson has no parameters to validate + noise.validate() + + def test_negative_binomial_noise_validate_success(self): + """Test NegativeBinomialNoise validate with valid concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 10.0)) + # Should not raise + noise.validate() + + def test_negative_binomial_noise_validate_zero_concentration(self): + """Test NegativeBinomialNoise validate with zero 
concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 0.0)) + with pytest.raises(ValueError, match="concentration must be positive"): + noise.validate() + + def test_negative_binomial_noise_validate_negative_concentration(self): + """Test NegativeBinomialNoise validate with negative concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", -1.0)) + with pytest.raises(ValueError, match="concentration must be positive"): + noise.validate() + + +class TestBaseObservationProcessValidation: + """Test base observation process PMF validation.""" + + def test_validate_pmf_empty_array(self, simple_delay_pmf): + """Test that _validate_pmf raises for empty array.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + empty_pmf = jnp.array([]) + with pytest.raises(ValueError, match="must return non-empty array"): + process._validate_pmf(empty_pmf, "test_pmf") + + def test_validate_pmf_sum_not_one(self, simple_delay_pmf): + """Test that _validate_pmf raises for PMF not summing to 1.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + bad_pmf = jnp.array([0.3, 0.3, 0.3]) # sums to 0.9 + with pytest.raises(ValueError, match="must sum to 1.0"): + process._validate_pmf(bad_pmf, "test_pmf") + + def test_validate_pmf_negative_values(self, simple_delay_pmf): + """Test that _validate_pmf raises for negative values.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + bad_pmf = jnp.array([1.5, -0.5]) # sums to 1.0 but has negative 
+ with pytest.raises(ValueError, match="must have non-negative values"): + process._validate_pmf(bad_pmf, "test_pmf") + + def test_get_minimum_observation_day(self): + """Test get_minimum_observation_day returns correct value.""" + delay_pmf = jnp.array([0.2, 0.5, 0.3]) # length 3 + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + # First valid day should be len(pmf) - 1 = 2 + assert process.get_minimum_observation_day() == 2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py new file mode 100644 index 00000000..e5c014a2 --- /dev/null +++ b/test/test_observation_measurements.py @@ -0,0 +1,249 @@ +""" +Unit tests for Measurements (continuous measurement observations). + +These tests validate the measurement observation process base class implementation. +""" + +import jax.numpy as jnp +import numpyro +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import HierarchicalNormalNoise, Measurements +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.randomvariable import GammaGroupSdPrior, HierarchicalNormalPrior + + +class ConcreteMeasurements(Measurements): + """Concrete implementation of Measurements for testing.""" + + def __init__(self, temporal_pmf_rv, noise, log10_scale=9.0): + """Initialize the concrete measurements for testing.""" + super().__init__(temporal_pmf_rv=temporal_pmf_rv, noise=noise) + self.log10_scale = log10_scale + + def validate(self) -> None: + """Validate parameters.""" + pmf = self.temporal_pmf_rv() + self._validate_pmf(pmf, "temporal_pmf_rv") + + def lookback_days(self) -> int: + """ + Return temporal PMF length. + + Returns + ------- + int + Length of the temporal PMF. 
+ """ + return len(self.temporal_pmf_rv()) + + def _expected_signal(self, infections): + """ + Natural-log expected signal: log((convolution + 1e-10) * 10**log10_scale). + + Returns + ------- + jnp.ndarray + Natural-log-transformed expected signal. + """ + pmf = self.temporal_pmf_rv() + + # Promote 1D infections to one column so the code below always sees 2D (n_days, n_subpops) + if infections.ndim == 1: + infections = infections[:, jnp.newaxis] + + def convolve_col(col): # numpydoc ignore=GL08 + return self._convolve_with_alignment(col, pmf, 1.0)[0] + + import jax + + expected = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + + # Natural log with a 1e-10 offset, shifted by log10_scale decades (simplified from wastewater model) + log_expected = jnp.log(expected + 1e-10) + self.log10_scale * jnp.log(10) + + return log_expected + + +class TestMeasurementsBase: + """Test Measurements abstract base class.""" + + def test_is_base_observation_process(self): + """Test that Measurements inherits from BaseObservationProcess.""" + assert issubclass(Measurements, BaseObservationProcess) + + def test_infection_resolution_is_subpop(self): + """Test that Measurements returns 'subpop' resolution.""" + shedding_pmf = jnp.array([0.3, 0.4, 0.3]) + sensor_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("sd", 0.5) + ) + sensor_sd_prior = GammaGroupSdPrior( + name="sd", + sd_mean_rv=DeterministicVariable("mean", 0.3), + sd_concentration_rv=DeterministicVariable("conc", 4.0), + ) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + assert process.infection_resolution() == "subpop" + + +class TestHierarchicalNormalNoise: + """Test HierarchicalNormalNoise model.""" + + def test_validate(self): + """Test HierarchicalNormalNoise validate method.""" + sensor_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + sensor_sd_prior = GammaGroupSdPrior( + name="sd", +
sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + # Should not raise - validation is deferred to sample time + noise.validate() + + def test_sample_shape(self): + """Test that HierarchicalNormalNoise produces correct shape.""" + sensor_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + sensor_sd_prior = GammaGroupSdPrior( + name="sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + + expected = jnp.array([1.0, 2.0, 3.0, 4.0]) + sensor_indices = jnp.array([0, 0, 1, 1]) + + with numpyro.handlers.seed(rng_seed=42): + samples = noise.sample( + name="test", + expected=expected, + obs=None, + sensor_indices=sensor_indices, + n_sensors=2, + ) + + assert samples.shape == expected.shape + + def test_sample_with_observations(self): + """Test that HierarchicalNormalNoise conditions on observations.""" + sensor_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + sensor_sd_prior = GammaGroupSdPrior( + name="sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + + expected = jnp.array([1.0, 2.0, 3.0, 4.0]) + obs = jnp.array([1.1, 2.1, 3.1, 4.1]) + sensor_indices = jnp.array([0, 0, 1, 1]) + + with numpyro.handlers.seed(rng_seed=42): + samples = noise.sample( + name="test", + expected=expected, + obs=obs, + sensor_indices=sensor_indices, + n_sensors=2, + ) + + # When obs is provided, samples should equal obs + assert jnp.allclose(samples, obs) + + +class TestConcreteMeasurements: + """Test concrete Measurements implementation.""" + + def test_sample_shape(self): + 
"""Test that sample returns correct shape.""" + shedding_pmf = jnp.array([0.3, 0.4, 0.3]) + sensor_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + sensor_sd_prior = GammaGroupSdPrior( + name="sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + # 30 days, 2 subpops + infections = jnp.ones((30, 2)) * 1000 + subpop_indices = jnp.array([0, 0, 1, 1]) + sensor_indices = jnp.array([0, 0, 1, 1]) + times = jnp.array([10, 15, 10, 15]) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + times=times, + concentrations=None, + n_sensors=2, + ) + + assert result.observed.shape == times.shape + assert result.expected.shape == infections.shape + + def test_expected_signal_stored(self): + """Test that expected_log_conc is stored as deterministic.""" + shedding_pmf = jnp.array([0.5, 0.5]) + sensor_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("mode_sd", 0.01) + ) + sensor_sd_prior = GammaGroupSdPrior( + name="sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.01), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + sd_min=0.001, + ) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + infections = jnp.ones((20, 2)) * 1000 + subpop_indices = jnp.array([0, 1]) + sensor_indices = jnp.array([0, 1]) + times = jnp.array([10, 10]) + + with numpyro.handlers.seed(rng_seed=42): + trace = numpyro.handlers.trace( + lambda: process.sample( + infections=infections, + 
subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + times=times, + concentrations=None, + n_sensors=2, + ) + ).get_trace() + + assert "expected_log_conc" in trace + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_observation_poisson.py b/test/test_observation_poisson.py index b9d975be..8b9c0716 100644 --- a/test/test_observation_poisson.py +++ b/test/test_observation_poisson.py @@ -20,3 +20,10 @@ def test_poisson_obs(): sim_pois = pois(mu=rates) testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois)) + + +def test_poisson_validate(): + """ + Check that PoissonObservation.validate() runs without error. + """ + PoissonObservation.validate() diff --git a/test/test_randomvariable_hierarchical.py b/test/test_randomvariable_hierarchical.py new file mode 100644 index 00000000..89a58c3a --- /dev/null +++ b/test/test_randomvariable_hierarchical.py @@ -0,0 +1,204 @@ +"""Unit tests for hierarchical prior distributions.""" + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicVariable +from pyrenew.randomvariable import ( + DistributionalVariable, + GammaGroupSdPrior, + HierarchicalNormalPrior, + StudentTGroupModePrior, +) + + +class TestHierarchicalNormalPrior: + """Test HierarchicalNormalPrior.""" + + def test_sample_shape(self): + """Test that sample returns correct shape.""" + prior = HierarchicalNormalPrior( + "effect", sd_rv=DeterministicVariable("sd", 1.0) + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + def test_smaller_sd_produces_tighter_distribution(self): + """Test that smaller sd produces samples closer to zero.""" + prior_tight = HierarchicalNormalPrior( + "a", sd_rv=DeterministicVariable("sd_tight", 0.1) + ) + prior_wide = HierarchicalNormalPrior( + "b", sd_rv=DeterministicVariable("sd_wide", 10.0) + ) + + n_samples = 1000 + with 
numpyro.handlers.seed(rng_seed=42): + samples_tight = prior_tight.sample(n_groups=n_samples) + with numpyro.handlers.seed(rng_seed=43): + samples_wide = prior_wide.sample(n_groups=n_samples) + + # Tight prior should have smaller standard deviation + assert jnp.std(samples_tight) < jnp.std(samples_wide) + + def test_validate(self): + """Test that validate() runs without error.""" + prior = HierarchicalNormalPrior( + "effect", sd_rv=DeterministicVariable("sd", 1.0) + ) + prior.validate() # Should not raise + + def test_rejects_non_random_variable_sd(self): + """Test that non-RandomVariable sd_rv is rejected.""" + with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): + HierarchicalNormalPrior("effect", sd_rv=1.0) + + def test_accepts_distributional_variable_for_sd(self): + """Test that DistributionalVariable can be used for sd_rv.""" + sd_rv = DistributionalVariable("sd", dist.HalfNormal(1.0)) + prior = HierarchicalNormalPrior("effect", sd_rv=sd_rv) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + +class TestGammaGroupSdPrior: + """Test GammaGroupSdPrior.""" + + def test_sample_shape(self): + """Test that sample returns correct shape.""" + prior = GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + def test_respects_sd_min(self): + """Test that sd_min is enforced as lower bound.""" + prior = GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.1), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + sd_min=0.5, + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=100) + + assert jnp.all(samples >= 0.5) + + def test_rejects_non_random_variable_params(self): + """Test that non-RandomVariable parameters are 
rejected.""" + with pytest.raises(TypeError, match="sd_mean_rv must be a RandomVariable"): + GammaGroupSdPrior( + "sd", + sd_mean_rv=0.5, + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + + with pytest.raises( + TypeError, match="sd_concentration_rv must be a RandomVariable" + ): + GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=4.0, + ) + + def test_rejects_negative_sd_min(self): + """Test that negative sd_min is rejected.""" + with pytest.raises(ValueError, match="sd_min must be non-negative"): + GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + sd_min=-0.1, + ) + + def test_validate(self): + """Test that validate() runs without error.""" + prior = GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + prior.validate() # Should not raise + + +class TestStudentTGroupModePrior: + """Test StudentTGroupModePrior.""" + + def test_sample_shape(self): + """Test that sample returns correct shape.""" + prior = StudentTGroupModePrior( + "mode", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=DeterministicVariable("df", 4.0), + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + def test_heavier_tails_than_normal(self): + """Test Student-t produces more extreme values than Normal.""" + # df=2 gives very heavy tails + student_prior = StudentTGroupModePrior( + "s", + sd_rv=DeterministicVariable("sd_s", 1.0), + df_rv=DeterministicVariable("df", 2.0), + ) + normal_prior = HierarchicalNormalPrior( + "n", sd_rv=DeterministicVariable("sd_n", 1.0) + ) + + n_samples = 5000 + with numpyro.handlers.seed(rng_seed=42): + student_samples = student_prior.sample(n_groups=n_samples) + with numpyro.handlers.seed(rng_seed=42): + normal_samples = 
normal_prior.sample(n_groups=n_samples) + + # Student-t should have more extreme values (higher max absolute value) + assert jnp.max(jnp.abs(student_samples)) > jnp.max(jnp.abs(normal_samples)) + + def test_rejects_non_random_variable_params(self): + """Test that non-RandomVariable parameters are rejected.""" + with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): + StudentTGroupModePrior( + "mode", + sd_rv=1.0, + df_rv=DeterministicVariable("df", 4.0), + ) + + with pytest.raises(TypeError, match="df_rv must be a RandomVariable"): + StudentTGroupModePrior( + "mode", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=4.0, + ) + + def test_validate(self): + """Test that validate() runs without error.""" + prior = StudentTGroupModePrior( + "mode", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=DeterministicVariable("df", 4.0), + ) + prior.validate() # Should not raise + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])