State-Centered Temporal Processes for PyRenew #810

@cdc-mitzimorris

PyRenew's current production temporal processes are innovation-based: wherever HEW uses ARProcess, PyRenew enforces non-centered transition innovations. No centered state-space temporal process is currently used for Rt or for the other AR temporal components.

This matters for data-rich time series. When the likelihood strongly identifies the latent Rt trajectory, a non-centered innovation parameterization can be inefficient: after cumulative reconstruction, changing one early standardized innovation shifts every downstream state, which forces small HMC step sizes and deep trees. A centered state-space parameterization may be substantially faster for these models.
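A minimal NumPy illustration of this coupling, assuming a plain random walk reconstructed as x = x0 + sigma * cumsum(z) (a simplification of the AR-based processes; the helper name is hypothetical). Perturbing only the first standardized innovation moves every later state.

```python
import numpy as np


def reconstruct_random_walk(x0: float, sigma: float, z: np.ndarray) -> np.ndarray:
    """Innovation-based (non-centered) reconstruction of a random-walk path.

    Hypothetical helper: x[0] = x0 and x[t] = x[t-1] + sigma * z[t-1].
    """
    increments = sigma * np.asarray(z, dtype=float)
    return np.concatenate([[x0], x0 + np.cumsum(increments)])


rng = np.random.default_rng(0)
z = rng.standard_normal(9)
x = reconstruct_random_walk(0.0, 0.1, z)

# Nudge only the first standardized innovation.
z_perturbed = z.copy()
z_perturbed[0] += 1.0
x_perturbed = reconstruct_random_walk(0.0, 0.1, z_perturbed)

# x[0] is unchanged; every later state moves by sigma * 1.0 = 0.1.
shift = x_perturbed - x
print(np.round(shift, 3))
```

A sampler working in z coordinates therefore sees a posterior in which early innovations are tightly coupled to every observation that follows them.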

Goal

Add temporal process classes that support a centered state-space parameterization through the existing pyrenew.latent.TemporalProcess API.

These classes should be usable with existing PyRenew model construction machinery, including PopulationInfections, SubpopulationInfections, WeeklyTemporalProcess, StepwiseTemporalProcess, PyrenewBuilder, and MultiSignalModel.

Proposed Terminology

Use the term state-centered temporal process or centered state-space temporal process.

The key distinction is:

Current innovation-based style:
  sample innovations -> deterministically scan/integrate -> state path

State-centered style:
  sample state path directly -> apply Markov transition priors to states

Proposed API

Add new classes in pyrenew.latent.temporal_processes that satisfy the existing TemporalProcess protocol directly:

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "temporal",
    *,
    first_day_dow: int | None = None,
) -> ArrayLike

Each implementation must return an array of shape:

(n_timepoints, n_processes)

Proposed classes:

CenteredRandomWalk
CenteredAR1
CenteredDifferencedAR1

Alternative names:

StateCenteredRandomWalk
StateCenteredAR1
StateCenteredDifferencedAR1

Centered* is concise, but the documentation should be explicit that "centered" means state-centered.

CenteredRandomWalk

Model:

x[0] = initial_value
x[t] ~ Normal(x[t - 1], sigma)

This is equivalent to a RW1 prior on the state path, but the sampled latent variables are the states, not standardized innovations.
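The equivalence can be checked numerically. In this minimal NumPy sketch (helper names are hypothetical), the state-centered log density of a path equals the innovation-based log density of z = diff(x) / sigma plus the log-Jacobian -(T - 1) * log(sigma). Both describe the same prior, expressed in different coordinates.

```python
import numpy as np


def normal_logpdf(value, loc, scale):
    """Elementwise log density of Normal(loc, scale)."""
    value, loc = np.asarray(value, float), np.asarray(loc, float)
    return -0.5 * ((value - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def centered_rw_logp(x, sigma):
    """State-centered RW1: x[t] ~ Normal(x[t-1], sigma), with x[0] fixed."""
    return normal_logpdf(x[1:], x[:-1], sigma).sum()


def innovation_rw_logp(z):
    """Innovation-based RW1: z[t] ~ Normal(0, 1)."""
    return normal_logpdf(z, 0.0, 1.0).sum()


sigma = 0.2
x = np.array([0.0, 0.1, 0.35, 0.3, 0.5])
z = np.diff(x) / sigma

# The map x -> z has a triangular Jacobian with 1/sigma on the diagonal.
lhs = centered_rw_logp(x, sigma)
rhs = innovation_rw_logp(z) - (len(x) - 1) * np.log(sigma)
print(lhs, rhs)
```

Because the priors agree, any difference in sampler behavior comes only from the geometry of the coordinates that HMC explores.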

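Implementation sketch (uses numpyro.contrib.control_flow.scan, because plain jax.lax.scan does not handle NumPyro sample sites):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

sigma = self.innovation_sd_rv()
x0 = _prepare_initial_value(initial_value, n_processes)  # shape (n_processes,)
if n_timepoints == 1:
    return x0[jnp.newaxis, :]

def transition(prev_x, _):
    # Sample the next state directly, centered on the previous state.
    x_t = numpyro.sample(
        f"{name_prefix}_state",
        dist.Normal(prev_x, sigma),
    )
    return x_t, x_t

# scan stacks the site over time: shape (n_timepoints - 1, n_processes).
_, xs = scan(transition, x0, xs=None, length=n_timepoints - 1)
return jnp.concatenate([x0[jnp.newaxis, :], xs], axis=0)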

CenteredAR1

Model:

x[0] ~ Normal(initial_value, sigma / sqrt(1 - phi^2))
x[t] ~ Normal(phi * x[t - 1], sigma)

This matches the current PyRenew AR behavior, where the AR process reverts toward zero and initial_value is the prior location for the initial state.

If a nonzero long-run mean is needed later, that should be a separate explicit extension rather than an implicit behavior change.

CenteredDifferencedAR1

This is the key analogue for HEW-style Rt dynamics.

Current HEW Rt is approximately:

x[0] = initial_value
delta[0] ~ Normal(0, stationary_sd)
delta[t] = phi * delta[t - 1] + sigma * z[t]
x[t] = x[t - 1] + delta[t - 1]

A state-centered analogue should sample the x path directly:

x[0] = initial_value
x[1] ~ Normal(x[0], stationary_sd)

for t >= 2:
  previous_delta = x[t - 1] - x[t - 2]
  x[t] ~ Normal(x[t - 1] + phi * previous_delta, sigma)

where:

stationary_sd = sigma / sqrt(1 - phi^2)

This avoids making the entire Rt path a deterministic cumulative sum of sampled transition innovations.
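A NumPy check that the state-centered recursion defines the same path as the innovation-based one. It assumes the indexing x[t] = x[t - 1] + delta[t - 1], so the first increment delta[0] produces x[1], and both helper names are hypothetical. Feeding the same standardized shocks through both recursions yields the same path; only the coordinates a sampler would see differ.

```python
import numpy as np


def innovation_diff_ar1(x0, phi, sigma, z):
    """Innovation-based form: AR(1) increments, integrated by cumulative sum."""
    stationary_sd = sigma / np.sqrt(1.0 - phi**2)
    delta = np.empty_like(z)
    delta[0] = stationary_sd * z[0]  # delta[0] ~ Normal(0, stationary_sd)
    for t in range(1, len(z)):
        delta[t] = phi * delta[t - 1] + sigma * z[t]
    return np.concatenate([[x0], x0 + np.cumsum(delta)])


def state_centered_diff_ar1(x0, phi, sigma, z):
    """State-centered form: each state is centered on a function of earlier states."""
    stationary_sd = sigma / np.sqrt(1.0 - phi**2)
    x = np.empty(len(z) + 1)
    x[0] = x0
    x[1] = x[0] + stationary_sd * z[0]  # x[1] ~ Normal(x[0], stationary_sd)
    for t in range(2, len(x)):
        previous_delta = x[t - 1] - x[t - 2]
        # x[t] ~ Normal(x[t-1] + phi * previous_delta, sigma)
        x[t] = x[t - 1] + phi * previous_delta + sigma * z[t - 1]
    return x


rng = np.random.default_rng(1)
z = rng.standard_normal(20)
a = innovation_diff_ar1(0.5, 0.7, 0.05, z)
b = state_centered_diff_ar1(0.5, 0.7, 0.05, z)
print(np.max(np.abs(a - b)))  # agreement up to floating-point rounding
```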

Reparameterization Control

Existing ARProcess hard-codes LocScaleReparam(0) around its transition innovation sample site, which forces fully non-centered innovation sampling.

The proposed state-space temporal processes should not hard-code this behavior. They should expose transition sample sites directly as state variables and default to centered sampling:

x[t] ~ Normal(transition_mean[t], transition_sd)

However, the implementation should allow optional NumPyro reparameterization of those transition sample sites, for example via an optional reparam argument:

CenteredAR1(..., reparam=None)                # default centered
CenteredAR1(..., reparam=LocScaleReparam(0))  # optional NCP

This separates two modeling concerns:

1. Which latent object is sampled: states or innovations?
2. How does NumPyro parameterize each Normal sample site: centered or non-centered?

This also keeps the new classes useful for benchmarking centered, non-centered, and partially centered parameterizations.

Integration With Existing Builders

The new classes should be drop-in replacements anywhere a TemporalProcess is currently accepted.

Example:

latent = PopulationInfections(
    I0_rv=...,
    gen_int_rv=...,
    log_rt_time_0_rv=...,
    single_rt_process=CenteredDifferencedAR1(
        autoreg_rv=...,
        innovation_sd_rv=...,
    ),
)

Weekly Rt should work through the existing wrapper:

single_rt_process=WeeklyTemporalProcess(
    CenteredDifferencedAR1(
        autoreg_rv=...,
        innovation_sd_rv=...,
    ),
    start_dow=...,
)

This should allow direct performance comparisons using the existing integration fixtures in test/integration/conftest.py and the existing PyrenewBuilder and MultiSignalModel examples.

Testing Plan

Add unit tests in test/test_temporal_processes.py:

  • Shape checks for (n_timepoints, n_processes).
  • Scalar and vector initial_value handling.
  • n_timepoints == 1.
  • Scalar and vector hyperparameter behavior.
  • Compatibility with StepwiseTemporalProcess and WeeklyTemporalProcess.
  • Trace checks showing that state sample sites are present and transition innovations are not hard-coded as non-centered.
  • Optional reparameterization checks using LocScaleReparam(0).

Add integration tests or fixtures parallel to existing examples:

he_model_centered_ar1
he_weekly_rt_model_centered_ar1
he_weekly_rt_model_centered_differenced_ar1

Use the same synthetic data, priors, observation models, and MCMC settings as the existing innovation-based versions to compare sampler performance.

Non-Goals

Do not retrofit pyrenew.process.DifferencedProcess to provide state-centered sampling. That abstraction is fundamentally "sample differences, integrate to states."

Do not change existing AR1, DifferencedAR1, or RandomWalk behavior in the initial implementation. Add centered variants first so existing models remain stable and performance can be compared directly.

Open Questions

  • Should class names use Centered* or StateCentered*?
  • Should CenteredAR1 support a nonzero long-run mean now, or should it match the current zero-mean AR behavior exactly?
  • Should optional reparameterization accept one reparameterizer for all transition sites, or separate reparameterizers for initial-state and transition-state sample sites?
  • Should CenteredDifferencedAR1 expose the first transition as a state sample (x[1]) as proposed, or expose an initial rate sample and then state transitions thereafter?
