Inference Methods

Method Selection Guide

MCMC Samplers (Exact Inference)

| Method | Best For | Speed | GPU | Notes |
|---|---|---|---|---|
| nutpie | Default choice | 2-5x faster than PyMC NUTS | No | Rust-based, excellent adaptation |
| NumPyro | Large models, GPU | Fast | Yes | JAX-based, vectorized chains |
| PyMC NUTS | Compatibility | Baseline | No | Most tested, fallback option |

Approximate Inference (Fast but Inexact)

| Method | Best For | Speed | GPU | Notes |
|---|---|---|---|---|
| ADVI | Quick approximations | Fast | No | Mean-field or full-rank Gaussian |
| DADVI | Stable VI | Very fast | Yes | Deterministic gradients |
| Pathfinder | Initialization, screening | Very fast | Yes | Quasi-Newton optimization paths |

When to use approximate inference:

  • Model screening before committing to full MCMC
  • Very large datasets where MCMC is prohibitively slow
  • Finding good initial values for MCMC
  • Posteriors that are approximately Gaussian

Caution: Approximate methods underestimate posterior uncertainty and may miss multimodality. Always validate with MCMC when possible.
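To see why uncertainty shrinks, consider an idealized case: a Gaussian posterior over two parameters with unit variances and correlation rho. For a Gaussian target, the factorized (mean-field) Gaussian that minimizes KL(q || p), the objective ADVI-style methods optimize, has variances 1 / Lambda_ii, where Lambda is the target's precision matrix; for this 2x2 case that is 1 - rho^2. A small numpy sketch of that closed form (the 2x2 target is an assumption chosen for illustration):

```python
import numpy as np

def mean_field_variances(cov):
    # For a Gaussian target with covariance `cov`, the KL(q || p)-optimal
    # factorized Gaussian has variance 1 / Lambda_ii, with Lambda = inv(cov).
    precision = np.linalg.inv(cov)
    return 1.0 / np.diag(precision)

for rho in (0.0, 0.5, 0.9, 0.99):
    cov = np.array([[1.0, rho], [rho, 1.0]])
    mf_var = mean_field_variances(cov)
    print(f"rho={rho:4.2f}  true var=1.00  mean-field var={mf_var[0]:.4f}")
```

At rho = 0.9 the mean-field variance is 0.19, roughly a fifth of the true marginal variance, even though the approximation's mean is exact in this Gaussian case.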


Initialization and Common Failures

“Initial evaluation of model at starting point failed”

This error occurs when the log-probability is -inf or NaN at initial parameter values.

Common causes and fixes:

| Cause | Fix |
|---|---|
| Data outside distribution support | Verify observed data matches likelihood bounds |
| Jitter pushes parameters outside constraints | Use init="adapt_diag" (no jitter) |
| Invalid default starting values | Specify initvals={"param": value} |
| Constant response variable | Ensure target variable has variance |

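import numpy as np
import pymc as pm

# `model` is an existing pm.Model and `p` its number of regression
# coefficients (both placeholders). The fixes are alternatives: try one at a time.
with model:
    # Fix 1: Reduce/eliminate initialization jitter
    idata = pm.sample(init="adapt_diag")

    # Fix 2: Specify valid starting values
    idata = pm.sample(initvals={"sigma": 1.0, "beta": np.zeros(p)})

    # Fix 3: Use ADVI for more robust initialization
    idata = pm.sample(init="advi+adapt_diag")

# Debugging: check which variables have invalid log-probabilities
model.point_logps()  # log-probability of each variable at the initial point
model.debug()        # checks the model at the initial point and reports invalid terms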
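The first and last rows of the table above can be checked before sampling. A minimal numpy sketch; check_observed is a hypothetical helper, not a PyMC function, and you must supply support bounds that match the likelihood you intend to use:

```python
import numpy as np

def check_observed(y, lower=-np.inf, upper=np.inf, inclusive=True):
    """Hypothetical pre-sampling check (not part of PyMC).

    lower/upper are the support bounds of the planned likelihood;
    inclusive=False treats them as open bounds.
    Returns a list of human-readable problems; empty means none detected.
    """
    y = np.asarray(y, dtype=float).ravel()
    problems = []
    n_bad = int(np.sum(~np.isfinite(y)))
    if n_bad:
        problems.append(f"{n_bad} non-finite values (NaN or inf)")
    finite = y[np.isfinite(y)]
    if inclusive:
        outside = (finite < lower) | (finite > upper)
    else:
        outside = (finite <= lower) | (finite >= upper)
    if outside.any():
        problems.append(f"{int(outside.sum())} values outside the support ({lower}, {upper})")
    # A constant response breaks standardization (division by zero sd)
    if finite.size > 1 and finite.min() == finite.max():
        problems.append("response is constant (zero variance)")
    return problems
```

For example, call check_observed(y, lower=0) before an Exponential likelihood, or check_observed(y, lower=0, upper=1, inclusive=False) for Beta data, where exact 0s and 1s are a common culprit.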

The MCMC Prior Sampling Fallacy

Common mistake: Using pm.sample() to sample from the prior distribution.

# BAD: pm.sample() uses MCMC even without observations
with prior_model:
    prior = pm.sample(draws=1000)  # slow, poor convergence for discrete vars

# GOOD: Use ancestral sampling for priors
with prior_model:
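    prior = pm.sample_prior_predictive(draws=1000)  # fast, exact independent draws

pm.sample_prior_predictive() performs ancestral sampling: each variable is drawn directly from its distribution after its parents, in dependency order. The draws are exact and independent, so there are no MCMC convergence issues to diagnose, and it is typically much faster than pm.sample().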
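The principle is simple enough to write by hand. A numpy sketch for a hypothetical model mu ~ Normal(0, 1), sigma ~ HalfNormal(1), y ~ Normal(mu, sigma), drawing each variable after its parents; it illustrates the idea only and is not PyMC's implementation:

```python
import numpy as np

def ancestral_prior_draws(n_obs, draws=1000, seed=42):
    # Hypothetical model: mu ~ Normal(0, 1), sigma ~ HalfNormal(1),
    # y_i ~ Normal(mu, sigma). Parents are drawn before children.
    rng = np.random.default_rng(seed)
    mu = rng.normal(0.0, 1.0, size=draws)
    sigma = np.abs(rng.normal(0.0, 1.0, size=draws))  # HalfNormal(1)
    # Each (mu, sigma) pair generates one simulated dataset of n_obs values
    y = rng.normal(mu[:, None], sigma[:, None], size=(draws, n_obs))
    return {"mu": mu, "sigma": sigma, "y": y}
```

Every row of the returned "y" array is an independent prior predictive dataset, with no chains, tuning, or convergence checks involved.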


MCMC Samplers

nutpie (Default)

Rust-based sampler with excellent mass matrix adaptation. Use as default.

Basic Usage

import pymc as pm

with pm.Model() as model:
    # ... model specification ...
    pass

# Sample with nutpie backend
with model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        nuts_sampler="nutpie",
        random_seed=42,
    )

    # IMPORTANT: nutpie doesn't store log_likelihood automatically
    # Compute it explicitly if you need LOO-CV or LOO-PIT
    pm.compute_log_likelihood(idata)
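For intuition about what pm.compute_log_likelihood adds: the log_likelihood group holds one log-density value per posterior draw and observation, with dims (chain, draw, observation). A hand-rolled numpy sketch for a hypothetical Normal likelihood (illustration only; in practice call pm.compute_log_likelihood):

```python
import numpy as np

def pointwise_normal_loglik(y, mu_draws, sigma_draws):
    # y: (n_obs,) observed data; mu_draws, sigma_draws: (chain, draw) posterior samples.
    # Returns log Normal(y_i | mu, sigma) for every draw and observation,
    # shape (chain, draw, n_obs): the layout LOO-CV works from.
    mu = mu_draws[..., None]
    sigma = sigma_draws[..., None]
    z = (y - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)
```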

Configuration Options

with model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        nuts_sampler="nutpie",
        random_seed=42,
        progressbar=True,
        target_accept=0.8,  # increase for difficult posteriors
        cores=4,            # number of parallel chains
    )
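target_accept is the average acceptance probability that step-size adaptation aims for during tuning. A higher target forces smaller leapfrog steps, which helps with difficult geometry at the cost of more gradient evaluations per draw. NUTS implementations commonly tune the step size with Nesterov dual averaging, as described in the original NUTS paper. The numpy sketch below applies that update to a made-up acceptance curve, accept(step) = exp(-step**2), standing in for the real acceptance statistic; the constants are the paper's suggested defaults, and the exact adaptation in PyMC or nutpie may differ:

```python
import numpy as np

def toy_accept_prob(step_size):
    # Hypothetical acceptance curve: larger steps -> lower acceptance
    return np.exp(-step_size**2)

def dual_averaging(target_accept, n_tune=2000, init_step=1.0,
                   gamma=0.05, t0=10.0, kappa=0.75):
    mu = np.log(10.0 * init_step)  # point the log step size is shrunk toward
    h_bar = 0.0                    # running average of (target - observed acceptance)
    log_step = np.log(init_step)
    log_step_bar = 0.0             # averaged iterate, used after tuning
    for m in range(1, n_tune + 1):
        accept = toy_accept_prob(np.exp(log_step))
        w = 1.0 / (m + t0)
        h_bar = (1.0 - w) * h_bar + w * (target_accept - accept)
        log_step = mu - np.sqrt(m) / gamma * h_bar
        eta = m ** (-kappa)
        log_step_bar = eta * log_step + (1.0 - eta) * log_step_bar
    return np.exp(log_step_bar)

for delta in (0.8, 0.9, 0.95):
    # Compare with the exact solution of exp(-step**2) = delta
    print(delta, dual_averaging(delta), np.sqrt(-np.log(delta)))
```

The adapted step shrinks as the target rises, which is why raising target_accept slows sampling but reduces divergences.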

When to Use PyMC NUTS Instead

  • Debugging model specification issues (temporary only — switch back to nutpie after debugging)
  • Model requires compound step methods mixing NUTS with discrete samplers

Note: nutpie supports all standard PyMC distributions and operations including pytensor.scan, GaussianRandomWalk, AR, GARCH11, HSGP, Mixture, NormalMixture, and DensityDist. “Complex model” is never a reason to avoid nutpie.

NumPyro/JAX Backend

JAX-based sampling with GPU support and vectorized chains.

Setup

import pymc as pm

# Optional: configure JAX for GPU
import jax
jax.config.update("jax_platform_name", "gpu")  # or "cpu"

Basic Usage

with pm.Model() as model:
    # ... model specification ...
    pass

# Sample with NumPyro NUTS
with model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        nuts_sampler="numpyro",
        random_seed=42,
    )

Vectorized Chains (GPU Efficient)

# Run all chains in parallel on GPU
with model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        nuts_sampler="numpyro",
        nuts_sampler_kwargs={"chain_method": "vectorized"},
    )
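chain_method="vectorized" runs all chains as one batched computation (NumPyro vectorizes over the chain dimension with jax.vmap) rather than one process or device per chain. The same idea in plain numpy, using random-walk Metropolis instead of NUTS to keep it short; the standard normal target and step size are assumptions:

```python
import numpy as np

def log_target(x):
    # Hypothetical target: independent standard normal in each dimension
    return -0.5 * np.sum(x**2, axis=-1)

def vectorized_rwm(n_chains=4, n_draws=2000, dim=2, step=0.8, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_chains, dim))          # one current state per chain
    logp = log_target(x)
    samples = np.empty((n_chains, n_draws, dim))  # (chain, draw, dim) layout
    accepted = np.zeros(n_chains)
    for t in range(n_draws):
        # One array operation proposes and accepts/rejects for every chain at once
        proposal = x + step * rng.normal(size=x.shape)
        logp_prop = log_target(proposal)
        accept = np.log(rng.uniform(size=n_chains)) < logp_prop - logp
        x = np.where(accept[:, None], proposal, x)
        logp = np.where(accept, logp_prop, logp)
        accepted += accept
        samples[:, t] = x
    return samples, accepted / n_draws
```

Because the per-iteration work is a single batched operation, adding chains mostly widens the arrays instead of adding processes, which is what makes this layout efficient on a GPU.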

Configuration

with model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        nuts_sampler="numpyro",
        target_accept=0.9,
        nuts_sampler_kwargs={"nuts_kwargs": {"max_tree_depth": 12}},  # forwarded to NumPyro's NUTS kernel
        progressbar=True,
    )
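max_tree_depth caps how many times NUTS doubles its trajectory in one iteration, so each increment roughly doubles the worst-case cost of a draw. Under the standard doubling scheme (implementations count depth slightly differently), the arithmetic is:

```python
def max_leapfrog_steps(max_tree_depth):
    # Doubling j adds 2**j leapfrog steps, so a trajectory reaching
    # depth d takes 1 + 2 + ... + 2**(d - 1) = 2**d - 1 steps.
    return 2**max_tree_depth - 1

for depth in (10, 12):
    print(depth, max_leapfrog_steps(depth))
```

Going from NumPyro's default depth of 10 to 12 raises the cap from 1023 to 4095 leapfrog steps per draw, so raise it only when draws are actually hitting the maximum depth.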

When to Use NumPyro

  • Large models that benefit from GPU
  • Many chains needed (vectorization efficient)
  • Already in JAX ecosystem

PyMC NUTS

Legacy sampler. Only use for debugging or when nutpie and numpyro cannot be installed.

with model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        random_seed=42,
        target_accept=0.8,
    )

Approximate Inference

Variational Inference

ADVI (Automatic Differentiation Variational Inference)

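Approximates the posterior with a mean-field Gaussian: independent normal distributions on the unconstrained (transformed) parameters, so correlations between parameters are ignored.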

with model:
    approx = pm.fit(
        n=30000,
        method="advi",
        callbacks=[pm.callbacks.CheckParametersConvergence()],
    )

# Draw samples from the approximation
idata = approx.sample(1000)
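Under the hood, ADVI maximizes the evidence lower bound (ELBO) by stochastic gradient ascent, using the reparameterization z = mu + sigma * eps with eps ~ Normal(0, 1) so that gradients flow through the random draws. A 1-D numpy sketch of that core loop, assuming a hypothetical Normal(3, 2^2) target on an unconstrained parameter and plain SGD; PyMC's implementation additionally handles parameter transforms and uses its own optimizer:

```python
import numpy as np

def grad_log_p(z):
    # Gradient of the hypothetical target log-density, Normal(3, 2^2)
    return -(z - 3.0) / 4.0

def advi_1d(n_iter=5000, batch=64, lr=0.05, seed=0):
    rng = np.random.default_rng(seed)
    mu, log_sigma = 0.0, 0.0  # variational parameters of q = Normal(mu, sigma)
    trace = []
    for _ in range(n_iter):
        eps = rng.standard_normal(batch)
        sigma = np.exp(log_sigma)
        z = mu + sigma * eps  # reparameterized draws from q
        g = grad_log_p(z)
        grad_mu = g.mean()
        # Chain rule through z, plus the entropy gradient (entropy = log_sigma + const)
        grad_log_sigma = (g * sigma * eps).mean() + 1.0
        mu += lr * grad_mu  # stochastic gradient ascent on the ELBO
        log_sigma += lr * grad_log_sigma
        trace.append((mu, np.exp(log_sigma)))
    tail = np.array(trace[-1000:])  # average late iterates to damp gradient noise
    return tail[:, 0].mean(), tail[:, 1].mean()
```

Because the Gaussian family contains this target exactly, advi_1d() should land close to mean 3 and sd 2; on real posteriors the mean-field restriction is what causes the variance shrinkage noted earlier.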

Full-Rank ADVI

Better posterior approximation (captures correlations) at higher cost:

with model:
    approx = pm.fit(
        n=50000,
        method="fullrank_advi",
    )
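The extra cost comes from the covariance parameterization. Assuming the covariance is parameterized by a lower-triangular Cholesky factor (the usual choice), the number of variational parameters grows quadratically with the number of unconstrained model parameters, versus linearly for mean-field:

```python
def n_variational_params(dim, full_rank=False):
    if full_rank:
        # mean vector + lower-triangular Cholesky factor of the covariance
        return dim + dim * (dim + 1) // 2
    # mean vector + one (log-)scale per dimension
    return 2 * dim

for dim in (10, 100, 1000):
    print(dim, n_variational_params(dim), n_variational_params(dim, full_rank=True))
```

At 1,000 model parameters that is 2,000 variational parameters for mean-field versus 501,500 for full-rank.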

DADVI (Deterministic ADVI)

More stable variational inference from pymc-extras:

import pymc_extras as pmx

with model:
    idata = pmx.fit(
        method="dadvi",
        num_steps=10000,
        random_seed=42,
    )

DADVI advantages:

  • Deterministic gradients (no Monte Carlo noise)
  • Faster convergence
  • More stable optimization
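The determinism comes from drawing the Monte Carlo samples once and reusing them for every ELBO evaluation, which turns the objective into an ordinary deterministic function that standard optimizers can minimize to convergence. A numpy sketch of that idea for a hypothetical Normal(3, 2^2) target, with 30 fixed draws and plain gradient descent (a production implementation would use a stronger optimizer):

```python
import numpy as np

def neg_elbo_and_grad(params, eps):
    # Hypothetical target: Normal(3, 2^2). Negative ELBO (up to a constant)
    # estimated with a FIXED array of standard-normal draws `eps`.
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    r = mu + sigma * eps - 3.0
    value = np.mean(r**2) / 8.0 - log_sigma  # -E[log p] - entropy
    grad = np.array([np.mean(r) / 4.0, sigma * np.mean(r * eps) / 4.0 - 1.0])
    return value, grad

rng = np.random.default_rng(0)
eps = rng.standard_normal(30)  # drawn once, reused for every evaluation
params = np.array([0.0, 0.0])  # (mu, log_sigma)
for _ in range(2000):          # deterministic gradient descent
    _, grad = neg_elbo_and_grad(params, eps)
    params = params - 0.2 * grad
mu_hat, sigma_hat = params[0], np.exp(params[1])
```

Evaluating the objective twice at the same parameters gives identical values, and the optimizer converges to the optimum of the fixed-draw objective: here sigma_hat equals 2 / np.std(eps) rather than exactly 2, an error that shrinks as more fixed draws are used.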

Pathfinder

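Quasi-Newton variational method provided by pymc-extras: it runs an L-BFGS optimization, builds a Gaussian approximation at each point along the optimization path, and keeps the one with the highest ELBO. Very fast for quick approximations or initialization.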

import pymc_extras as pmx

with model:
    # Pathfinder lives in pymc-extras; pm.fit does not provide it
    idata = pmx.fit(method="pathfinder")
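A 1-D toy version makes the three steps concrete: follow an optimization path, fit a Gaussian at each point using a curvature estimate, and keep the candidate with the best ELBO. This sketch assumes a hypothetical Normal(3, 2^2) target, uses plain gradient ascent in place of L-BFGS, and uses a secant estimate (the 1-D analogue of a quasi-Newton update) for the curvature; it is not the pymc-extras implementation:

```python
import numpy as np

def log_p(x):
    # Hypothetical unnormalized target: Normal(3, 2^2)
    return -((x - 3.0) ** 2) / 8.0

def grad_log_p(x):
    return -(x - 3.0) / 4.0

def pathfinder_1d(x0=-5.0, lr=0.5, n_steps=60, n_elbo_draws=200, seed=0):
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n_elbo_draws)  # shared draws for all candidates
    # 1. Optimization path (gradient ascent stands in for L-BFGS)
    path = [x0]
    for _ in range(n_steps):
        path.append(path[-1] + lr * grad_log_p(path[-1]))
    # 2. Gaussian at each point, scale from secant curvature; 3. keep the best ELBO
    best_elbo, best_mean, best_sd = -np.inf, None, None
    for x_prev, x in zip(path[:-1], path[1:]):
        dx, dg = x - x_prev, grad_log_p(x) - grad_log_p(x_prev)
        if dx == 0.0 or dg / dx >= 0.0:
            continue  # curvature estimate unusable (not locally concave)
        sd = np.sqrt(-dx / dg)
        elbo = np.mean(log_p(x + sd * eps)) + np.log(sd)  # entropy up to a constant
        if elbo > best_elbo:
            best_elbo, best_mean, best_sd = elbo, x, sd
    return best_mean, best_sd
```

For this quadratic log-density the secant curvature is exact, so every candidate has sd 2 and the ELBO picks the path point nearest the mode.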

Multi-Path Pathfinder

import pymc_extras as pmx

with model:
    idata = pmx.fit(
        method="pathfinder",
        num_paths=8,  # multiple optimization paths
        maxcor=10,    # L-BFGS history size (stored correction pairs)
    )
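With several paths, the draws from each path's Gaussian are pooled and reweighted toward the target by importance resampling; smoothed variants such as Pareto-smoothed importance sampling are used in practice. A plain (unsmoothed) importance-resampling sketch in numpy, assuming a hypothetical Normal(3, 2^2) target and four made-up per-path approximations:

```python
import numpy as np

def log_normal_pdf(x, mean, sd):
    return -0.5 * ((x - mean) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)

def combine_paths(path_approxs, log_target, draws_per_path=2000, n_out=4000, seed=0):
    # path_approxs: list of (mean, sd) Gaussians, one per path
    rng = np.random.default_rng(seed)
    xs, log_w = [], []
    for mean, sd in path_approxs:
        x = rng.normal(mean, sd, size=draws_per_path)
        xs.append(x)
        log_w.append(log_target(x) - log_normal_pdf(x, mean, sd))  # log p - log q_k
    xs, log_w = np.concatenate(xs), np.concatenate(log_w)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    w /= w.sum()
    return rng.choice(xs, size=n_out, replace=True, p=w)

def log_target(x):
    # Hypothetical unnormalized target: Normal(3, 2^2)
    return -((x - 3.0) ** 2) / 8.0

paths = [(2.5, 2.0), (3.4, 2.2), (3.0, 1.8), (2.8, 2.5)]  # hypothetical per-path fits
resampled = combine_paths(paths, log_target)
```

Individual paths can be off in location or scale; the weights correct for that, so the resampled draws track the target more closely than any single path.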

When to Use Pathfinder

  • Quick posterior approximation (seconds vs minutes)
  • Finding good initial values for MCMC
  • Model screening before full inference
  • When posterior is approximately Gaussian

Combining Methods

Pathfinder Initialization + MCMC

Use Pathfinder to find good starting points, then run MCMC for accurate inference:

import pymc_extras as pmx

with model:
    # Quick pathfinder approximation (Pathfinder lives in pymc-extras)
    pathfinder_idata = pmx.fit(method="pathfinder")

    # Extract initial values
    init_vals = {
        var.name: pathfinder_idata.posterior[var.name].mean(dim=["chain", "draw"]).values
        for var in model.free_RVs
    }

    # Run MCMC with good initialization
    idata = pm.sample(initvals=init_vals)

VI for Screening, MCMC for Final Inference

# Screen model with VI (fast)
with model:
    vi_approx = pm.fit(n=30000)

# If model looks reasonable, run full MCMC
with model:
    idata = pm.sample()

VI for Large Data, MCMC for Validation

# Full data: use VI for speed
with full_model:
    vi_approx = pm.fit(n=30000)

# Validation subset: use MCMC for accurate uncertainty
with subset_model:
    idata = pm.sample()
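Building subset_model requires choosing which rows to keep; a uniform random subset avoids biasing the validation fit. A minimal sketch; validation_subset is a hypothetical helper, and y and X in the usage note are placeholders for your data arrays:

```python
import numpy as np

def validation_subset(n_obs, subset_size, seed=42):
    # Hypothetical helper: random row indices, drawn without replacement,
    # for fitting `subset_model` with MCMC.
    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(n_obs, size=subset_size, replace=False))
```

For example, idx = validation_subset(len(y), 2000) followed by y[idx] and X[idx] gives the data for subset_model. Keep in mind that the subset posterior is wider than the full-data posterior simply because it uses less data.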