BART (Bayesian Additive Regression Trees)

BART is a nonparametric regression approach using an ensemble of trees with a Bayesian prior. Available via pymc-bart.

Table of Contents

Basic Usage

import pymc as pm
import pymc_bart as pmb

with pm.Model() as bart_model:
    # BART prior over the regression function
    mu = pmb.BART("mu", X=X, Y=y, m=50)

    # Observation noise
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

    idata = pm.sample()

Regression

Continuous Outcome

with pm.Model() as regression_bart:
    mu = pmb.BART("mu", X=X_train, Y=y_train, m=50)
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_train)

    idata = pm.sample()

# Predictions
with regression_bart:
    pmb.set_data({"mu": X_test})
    ppc = pm.sample_posterior_predictive(idata)

Heteroscedastic Regression

with pm.Model() as hetero_bart:
    # Mean function
    mu = pmb.BART("mu", X=X, Y=y, m=50)

    # Variance function (also BART)
    log_sigma = pmb.BART("log_sigma", X=X, Y=y, m=20)
    sigma = pm.Deterministic("sigma", pm.math.exp(log_sigma))

    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

Classification

Binary Classification

with pm.Model() as binary_bart:
    # BART on latent scale
    mu = pmb.BART("mu", X=X, Y=y, m=50)

    # Probit or logit link
    p = pm.Deterministic("p", pm.math.sigmoid(mu))

    y_obs = pm.Bernoulli("y_obs", p=p, observed=y)

    idata = pm.sample()

Multiclass Classification

with pm.Model(coords={"class": classes}) as multiclass_bart:
    # Separate BART for each class (one-vs-rest style)
    mu = pmb.BART("mu", X=X, Y=y_onehot, m=50, dims="class")

    # Softmax
    p = pm.Deterministic("p", pm.math.softmax(mu, axis=-1))

    y_obs = pm.Categorical("y_obs", p=p, observed=y)

Variable Importance

Compute Variable Importance

# After sampling
vi = pmb.compute_variable_importance(idata, X, method="VI")

# Plot
pmb.plot_variable_importance(vi, X)

Methods

  • "VI": Based on inclusion frequency in trees
  • "backward": Backward elimination importance
# Backward elimination (more expensive but often better)
vi_backward = pmb.compute_variable_importance(
    idata, X, method="backward", random_seed=42
)

Partial Dependence

1D Partial Dependence

# Partial dependence for variable at index 0
pmb.plot_pdp(idata, X=X, Y=y, xs_interval="quantiles", var_idx=[0])

# Multiple variables
pmb.plot_pdp(idata, X=X, Y=y, var_idx=[0, 1, 2])

2D Partial Dependence (Interaction)

# Interaction between variables 0 and 1
pmb.plot_pdp(idata, X=X, Y=y, var_idx=[0, 1], grid="wide")

Individual Conditional Expectation (ICE)

pmb.plot_ice(idata, X=X, Y=y, var_idx=0)

Configuration

Key Parameters

mu = pmb.BART(
    "mu",
    X=X,
    Y=y,
    m=50,              # number of trees (default 50, more = smoother)
    alpha=0.95,        # prior probability tree has depth 1
    beta=2.0,          # controls depth of trees
    split_prior=None,  # prior on split variable selection
)

Number of Trees (m)

  • m=50: Good default
  • m=100-200: Smoother fit, more computation
  • m=20-30: Faster, may underfit
# More trees for complex functions
mu = pmb.BART("mu", X=X, Y=y, m=100)

Tree Depth (alpha, beta)

Controls tree complexity via prior P(node is terminal at depth d) = alpha * (1 + d)^(-beta)

  • Higher alpha or lower beta: Deeper trees
  • Default alpha=0.95, beta=2 works well

Split Prior

Control which variables are preferred for splitting:

# Uniform (default)
split_prior = None

# Favor first 3 variables
split_prior = [2, 2, 2, 1, 1, 1, 1]  # length = n_features

mu = pmb.BART("mu", X=X, Y=y, split_prior=split_prior)

Combining BART with Parametric Components

BART + Linear

with pm.Model() as semi_parametric:
    # Linear component for known effects
    beta = pm.Normal("beta", 0, 1, shape=p_linear)
    linear = pm.math.dot(X_linear, beta)

    # BART for nonlinear/interaction effects
    nonlinear = pmb.BART("nonlinear", X=X_nonlinear, Y=y, m=50)

    mu = linear + nonlinear
    sigma = pm.HalfNormal("sigma", 1)
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

BART + Random Effects

with pm.Model(coords={"group": groups}) as bart_mixed:
    # Group random effects
    sigma_group = pm.HalfNormal("sigma_group", 1)
    alpha = pm.Normal("alpha", 0, sigma_group, dims="group")

    # BART for fixed effects
    mu_bart = pmb.BART("mu_bart", X=X, Y=y, m=50)

    mu = mu_bart + alpha[group_idx]
    sigma = pm.HalfNormal("sigma", 1)
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

Out-of-Sample Prediction

# Fit model
with bart_model:
    idata = pm.sample()

# Predict on new data
with bart_model:
    pmb.set_data({"mu": X_new})
    ppc = pm.sample_posterior_predictive(idata, var_names=["y_obs"])

# Extract predictions
y_pred = ppc.posterior_predictive["y_obs"]

Convergence Diagnostics

BART uses a particle Gibbs sampler, so standard MCMC diagnostics apply:

import arviz as az

az.plot_trace(idata, var_names=["sigma"])
az.summary(idata, var_names=["sigma"])

# For BART predictions, check posterior predictive
az.plot_ppc(idata)