Prior Selection Guide
Table of Contents
- The Foundational Role of Priors
- Hierarchy of Prior Informativeness
- Weakly Informative Defaults
- Regression and GLM Priors
- High-Dimensional and Sparse Priors
- Hierarchical Model Priors
- Prior Predictive Checking
- Sensitivity Analysis
- Expert Elicitation
- Practical Implementation Guidelines
- Domain-Specific Guidance
The Foundational Role of Priors
In Bayesian inference, the prior distribution encodes pre-experimental knowledge, structural assumptions, and epistemic uncertainty. It acts as a weight function penalizing parameter regions inconsistent with theoretical constraints.
Key insight: While the prior’s influence typically diminishes as data increases, it remains critical in: - High-dimensional settings (p > n) - Small-sample scenarios - Models with weak identification
Analogy: Choosing a prior is like selecting foundations for a building on unknown soil. A flat prior is building without foundations—risky unless ground is solid. A weakly informative prior provides sturdy pilings that prevent collapse while adjusting as the structure settles. An informative prior builds on existing foundations—faster if correct, catastrophic if flawed.
Hierarchy of Prior Informativeness
| Prior Level | Intended Use | Recommendation |
|---|---|---|
| Flat | “Objective” estimation; no preference | Discouraged. Can lead to improper posteriors and lacks regularization. |
| Super-vague | Minimal influence (e.g., N(0, 10⁶)) | Not recommended. Causes numerical instability and funnel geometries. |
| Weakly Informative | Regularization; ruling out absurdities | Preferred default. Provides stability while letting data dominate. |
| Specifically Informative | Expert belief or historical data | Use with caution. Requires formal elicitation or meta-analysis. |
# Examples of the hierarchy
# AVOID: Flat/super-vague
beta_bad = pm.Flat("beta") # improper, avoid
beta_vague = pm.Normal("beta", mu=0, sigma=1e6) # numerical issues
# PREFERRED: Weakly informative
beta = pm.Normal("beta", mu=0, sigma=2.5) # regularizing, stable
# USE CAREFULLY: Informative (requires justification)
beta_informed = pm.Normal("beta", mu=0.3, sigma=0.1) # from prior studyWeakly Informative Defaults
Location Parameters (means, intercepts, coefficients)
# Standardize predictors first, then use unit-scale priors
beta = pm.Normal("beta", mu=0, sigma=1) # coefficients on standardized scale
intercept = pm.Normal("intercept", mu=0, sigma=2.5) # wider for intercepts
# Student-t for robustness to outlier effects (nu between 3-7)
beta = pm.StudentT("beta", nu=4, mu=0, sigma=2.5)Scale Parameters (standard deviations)
Modern practice favors half-Normal, half-t, or Exponential priors.
# Half-Normal (recommended default)
sigma = pm.HalfNormal("sigma", sigma=1)
# Half-Student-t (robust, good for hierarchical variance components)
sigma = pm.HalfStudentT("sigma", nu=4, sigma=1)
# Exponential (stronger regularization toward 0)
sigma = pm.Exponential("sigma", lam=1)Avoid InverseGamma for variances: Despite historical popularity, InverseGamma can be surprisingly informative near zero, pushing posteriors away from zero even when groups are identical.
# AVOID in hierarchical models:
variance = pm.InverseGamma("variance", alpha=0.01, beta=0.01) # problematic
# PREFER:
sigma = pm.HalfNormal("sigma", sigma=1)Correlation Matrices (LKJ Prior)
The LKJ prior is standard for correlation matrices: - η = 1: Uniform over all valid correlation matrices - η > 1: Concentrates toward identity matrix (lower correlations expected) - η < 1: Favors extreme correlations
# LKJ prior on correlation matrix
chol, corr, stds = pm.LKJCholeskyCov(
"chol", n=n_dims, eta=2.0, # mild shrinkage toward identity
sd_dist=pm.HalfNormal.dist(sigma=1)
)
# Use in Multivariate Normal
vals = pm.MvNormal("vals", mu=0, chol=chol, observed=data)Probability Parameters
# Beta prior for bounded [0, 1]
p = pm.Beta("p", alpha=1, beta=1) # uniform
p = pm.Beta("p", alpha=2, beta=2) # slight mode at 0.5
# Logit-normal for softer constraints
logit_p = pm.Normal("logit_p", mu=0, sigma=1.5)
p = pm.Deterministic("p", pm.math.sigmoid(logit_p))Count/Rate Parameters
# Gamma for rates (positive, right-skewed)
rate = pm.Gamma("rate", alpha=2, beta=1)
# Log-normal when multiplicative effects expected
rate = pm.LogNormal("rate", mu=0, sigma=1)Regression and GLM Priors
Coefficients
For standardized predictors, default to N(0, 1) or N(0, 2.5):
# Linear regression (standardized predictors)
beta = pm.Normal("beta", mu=0, sigma=1, dims="predictors")
# Logistic regression (wider to allow strong effects)
beta = pm.Normal("beta", mu=0, sigma=2.5, dims="predictors")
# Robust to outlier predictors (Student-t, 3 < nu < 7)
beta = pm.StudentT("beta", nu=4, mu=0, sigma=2.5, dims="predictors")Intercepts
# If outcome standardized
intercept = pm.Normal("intercept", mu=0, sigma=1)
# If outcome on original scale, center on domain-reasonable value
# Example: human height in cm
intercept = pm.Normal("intercept", mu=170, sigma=20)High-Dimensional and Sparse Priors
When p > n, regularization is essential for the bias-variance trade-off.
Laplace (Lasso) Prior
Continuous shrinkage with tall mode at zero and thick tails. Efficient but shrinks even large coefficients.
# Manual Laplace prior
beta = pm.Laplace("beta", mu=0, b=1, dims="features")Horseshoe Prior
Global-local shrinkage: global hyperparameter shrinks all coefficients while local half-Cauchy allows large signals to escape.
# Manual horseshoe (no built-in pmx.Horseshoe exists)
tau = pm.HalfCauchy("tau", beta=1) # global shrinkage
lam = pm.HalfCauchy("lam", beta=1, dims="features") # local shrinkage
beta = pm.Normal("beta", mu=0, sigma=tau * lam, dims="features")Sampling challenges: The Horseshoe creates a “double-funnel” geometry (massive spike at zero + heavy tails) that is extremely difficult for NUTS. Divergences are common unless using very high target_accept (0.99+). Consider:
- Using the Regularized Horseshoe parameterization (see below)
- Using simpler Laplace prior if full sparsity isn’t required
- Increasing
target_acceptto 0.99 and allowing longer sampling
Regularized (Finnish) Horseshoe
Adds a finite “slab” width to prevent infinite estimates in cases of data separation (e.g., logistic regression). Must be implemented manually:
import pytensor.tensor as pt
# Regularized horseshoe (manual implementation)
# D0 = expected number of non-zero coefficients, D = total features, N = observations
tau = pm.HalfStudentT("tau", nu=2, sigma=D0 / (D - D0) * sigma / np.sqrt(N))
lam = pm.HalfStudentT("lam", nu=5, dims="features")
c2 = pm.InverseGamma("c2", alpha=1, beta=1) # slab variance
z = pm.Normal("z", 0, 1, dims="features")
# Regularized shrinkage factor
lam_tilde = pt.sqrt(c2 / (c2 + tau**2 * lam**2))
beta = pm.Deterministic("beta", z * tau * lam * lam_tilde, dims="features")R2D2 Prior
Induces a prior directly on R² (variance explained), often more interpretable. Available in pymc-extras.
import pymc_extras as pmx
# X must be centered; returns (residual_sigma, coefficients)
# output_sigma and input_sigma are required
output_sigma = y.std()
input_sigma = X.std(axis=0)
residual_sigma, beta = pmx.R2D2M2CP(
"r2d2",
output_sigma=output_sigma,
input_sigma=input_sigma,
dims="features",
r2=0.5, # prior mean R²
r2_std=0.2, # uncertainty in R²
)Spike-and-Slab
“Gold standard” for variable selection using discrete mixture of point mass at zero and diffuse slab. Computationally demanding.
# Spike-and-slab (requires careful implementation)
import pytensor.tensor as pt
inclusion = pm.Bernoulli("inclusion", p=0.5, dims="features")
beta_slab = pm.Normal("beta_slab", mu=0, sigma=2, dims="features")
beta = pm.Deterministic("beta", inclusion * beta_slab, dims="features")Hierarchical Model Priors
Hierarchical models use “partial pooling” where group estimates are informed by both their own data and the population distribution.
Variance Components
# Half-Student-t with nu=3-4 is robust default
# Prevents variances from collapsing to zero while remaining broad
sigma_group = pm.HalfStudentT("sigma_group", nu=4, sigma=1)
# For moderate pooling
sigma_group = pm.HalfNormal("sigma_group", sigma=0.5)
# AVOID: uniform on large range
sigma_group = pm.Uniform("sigma_group", 0, 100) # problematicSmall Number of Groups
With few groups (< 5-10), broader Cauchy priors can cause underpooling. Use tighter priors:
# Few groups: more informative prior prevents underpooling
sigma_group = pm.HalfNormal("sigma_group", sigma=0.5)
# Many groups: can use broader priors
sigma_group = pm.HalfCauchy("sigma_group", beta=1)Non-Centered vs Centered Parameterization
When group-level variance is small relative to data, use non-centered:
# Non-centered (preferred when data per group is sparse)
sigma_group = pm.HalfNormal("sigma_group", sigma=1)
z = pm.Normal("z", mu=0, sigma=1, dims="groups")
group_effect = pm.Deterministic("group_effect", z * sigma_group, dims="groups")
# Centered (can work when lots of data per group)
sigma_group = pm.HalfNormal("sigma_group", sigma=1)
group_effect = pm.Normal("group_effect", mu=0, sigma=sigma_group, dims="groups")Prior Predictive Checking
Always simulate from priors before fitting to validate that priors generate plausible data.
with model:
prior_pred = pm.sample_prior_predictive(draws=500)
# Visualize prior predictive distribution
import arviz as az
az.plot_ppc(prior_pred, group="prior", kind="cumulative")
# Numerical summary
prior_y = prior_pred.prior_predictive["y"].values.flatten()
print(f"Prior predictive range: [{prior_y.min():.2f}, {prior_y.max():.2f}]")
print(f"Prior predictive mean: {prior_y.mean():.2f}")
print(f"Prior predictive std: {prior_y.std():.2f}")Warning Signs
| Problem | Symptom | Solution |
|---|---|---|
| Implausible values | Negative counts, >100% probabilities, heights > 10m | Tighten priors or add constraints |
| Too concentrated | Prior predictive doesn’t cover observed data range | Widen priors |
| Extremely wide | Prior generates absurd values (humans 100m tall) | Use domain knowledge to constrain |
| Miscentered | Prior predictive systematically biased | Adjust location parameters |
# Example: checking for implausible values
prior_y = prior_pred.prior_predictive["y"].values.flatten()
# For count data - should be non-negative
if (prior_y < 0).any():
print("WARNING: Prior allows negative counts!")
# For proportions - should be in [0, 1]
if (prior_y < 0).any() or (prior_y > 1).any():
print("WARNING: Prior allows invalid proportions!")
# For physical measurements - check against domain knowledge
if prior_y.max() > 300: # e.g., human height in cm
print("WARNING: Prior allows implausible heights!")Sensitivity Analysis
Evaluates posterior stability under prior perturbations. Critical in small datasets where prior influence is strongest.
Basic Sensitivity Check
def fit_with_prior(prior_sd):
"""Fit model with different prior scales."""
with pm.Model() as m:
beta = pm.Normal("beta", mu=0, sigma=prior_sd)
sigma = pm.HalfNormal("sigma", sigma=1)
y = pm.Normal("y", mu=beta * X, sigma=sigma, observed=y_obs)
trace = pm.sample()
return trace
# Compare posteriors under different priors
traces = {}
for sd in [0.5, 1.0, 2.0, 5.0]:
traces[sd] = fit_with_prior(sd)
# Check if posteriors are similar
for sd, trace in traces.items():
print(f"Prior SD={sd}: posterior mean={trace.posterior['beta'].mean():.3f}")Prior-Likelihood Conflict
Occurs when data substantially contradicts the prior. Can detect using prior-posterior comparison:
# Compare prior and posterior
az.plot_dist_comparison(trace, var_names=["beta"])
# Large divergence suggests prior-likelihood conflict
prior_mean = 0 # from prior specification
posterior_mean = trace.posterior["beta"].mean().item()
posterior_sd = trace.posterior["beta"].std().item()
# Check if prior is many SDs from posterior
if abs(prior_mean - posterior_mean) > 3 * posterior_sd:
print("WARNING: Potential prior-likelihood conflict")Expert Elicitation
When data is scarce (rare diseases, novel phenomena), priors must be constructed through formal elicitation. For formal elicitation protocols (SHELF framework, roulette method, expert aggregation), consult domain-specific elicitation literature such as O’Hagan et al. (2006) Uncertain Judgements: Eliciting Experts’ Probabilities.
Practical Implementation Guidelines
1. Scale Your Data
Rescale predictors and outcomes to unit scale (order of magnitude ~1):
# Standardize predictors
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)
# Standardize outcome (optional but helpful)
y_scaled = (y - y.mean()) / y.std()
# Now use simple priors
with pm.Model() as model:
beta = pm.Normal("beta", mu=0, sigma=1, dims="predictors")
intercept = pm.Normal("intercept", mu=0, sigma=1)
sigma = pm.HalfNormal("sigma", sigma=1)
mu = intercept + pm.math.dot(X_scaled, beta)
y_obs = pm.Normal("y", mu=mu, sigma=sigma, observed=y_scaled)2. Use Coords and Dims
Makes prior specification clearer and InferenceData more interpretable:
coords = {
"predictors": ["age", "income", "education"],
"obs_id": range(len(y))
}
with pm.Model(coords=coords) as model:
beta = pm.Normal("beta", mu=0, sigma=1, dims="predictors")3. Document Your Prior Choices
# Good practice: document prior justification
with pm.Model() as model:
# Prior: N(0, 2.5) allows coefficients up to ~5 on standardized scale
# Justification: covers effect sizes seen in similar studies
beta = pm.Normal("beta", mu=0, sigma=2.5, dims="predictors")
# Prior: HalfNormal(1) - weakly informative, allows broad range
# On standardized outcome, SD > 3 would be extreme
sigma = pm.HalfNormal("sigma", sigma=1)4. Reparameterize When Needed
If models show poor convergence (funnels, divergences), reparameterize:
# Centered (may cause funnels with small group sizes)
with pm.Model() as model_centered:
mu = pm.Normal("mu", mu=0, sigma=1)
sigma_group = pm.HalfNormal("sigma_group", sigma=1)
group_effect = pm.Normal("group_effect", mu=mu, sigma=sigma_group, dims="groups")
# Non-centered (more robust)
with pm.Model() as model_noncentered:
mu = pm.Normal("mu", mu=0, sigma=1)
sigma_group = pm.HalfNormal("sigma_group", sigma=1)
z = pm.Normal("z", mu=0, sigma=1, dims="groups") # standard normal
group_effect = pm.Deterministic("group_effect", mu + z * sigma_group, dims="groups")Domain-Specific Guidance
Biostatistics/Epidemiology
# Log-odds ratios: N(0, 2.5) allows ORs from ~0.01 to ~100
log_or = pm.Normal("log_or", mu=0, sigma=2.5)
# Hazard ratios: similar on log scale
log_hr = pm.Normal("log_hr", mu=0, sigma=1)
# Prevalence: Beta unless strong prior info
prevalence = pm.Beta("prevalence", alpha=1, beta=1) # or (2, 2) for mode at 0.5Physical Sciences
# Incorporate physical constraints
# Example: diffusion coefficient (must be positive)
D = pm.LogNormal("D", mu=np.log(1e-9), sigma=0.5) # centered on typical value
# Use informative priors from previous experiments
# Example: speed of sound with known uncertainty
c = pm.Normal("c", mu=343, sigma=5) # m/s in air at 20°CMachine Learning / Prediction
# Focus on predictive performance
# Use shrinkage priors for automatic relevance determination
import pymc_extras as pmx
# R2D2 prior - requires output_sigma and input_sigma
output_sigma = y.std()
input_sigma = X.std(axis=0)
residual_sigma, beta = pmx.R2D2M2CP(
"r2d2",
output_sigma=output_sigma,
input_sigma=input_sigma,
dims="features",
r2=0.5,
r2_std=0.25,
)
# Evaluate via LOO-CV
trace = pm.sample()
loo = az.loo(trace)