Decomposing \(R^2\) with principled priors
Introduction
Statistics is largely concerned with explaining variance in observed data. We may wish to estimate differences in outcomes between groups receiving different treatments, or characterise residual spatial clustering after accounting for covariates. Whatever the approach, we try to capture as much structure in the data as possible, ideally reducing the noise or error terms with every useful added model component. A popular and intuitive measure is \(R^2\), the proportion of variance in our data explained by the model, where \(\tau^2\) is the explained variance and \(\sigma^2\) is the residual variance left unexplained:
\[ R^2 = \frac{\tau^2}{\tau^2 + \sigma^2} \tag{1}\]
It is perhaps surprising that priors weren't placed on \(R^2\) directly until recently (Zhang et al. 2022; Aguilar and Bürkner 2023). Typically, independent priors are provided for all terms in the model, e.g. the intercept, fixed effect coefficients, and random effect and residual scales. However, since \(R^2\) is just a deterministic function of the different model components, any combination of parameter priors carries an implied prior on \(R^2\). Consider a regression model where the response and all \(P\) covariates are scaled to mean 0 and unit variance, with "weakly informative" \(\mathcal{N}\left(0, 1\right)\) priors on the coefficients and a weakly informative prior on the residual scale: the prior explained variance \(\sum_p \beta_p^2\) then has expectation \(P\), so the implied prior on \(R^2\) piles up near 1 as \(P\) grows.
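To make this concrete, here is a small Monte Carlo sketch (in Python, purely for illustration; the residual variance and draw counts are my own choices): drawing each coefficient from \(\mathcal{N}(0, 1)\) and computing the implied \(R^2 = \tau^2 / (\tau^2 + \sigma^2)\) shows the prior mean of \(R^2\) marching towards 1 as \(P\) grows.

```python
import numpy as np

rng = np.random.default_rng(1)
n_draws = 5000
sigma2 = 1.0  # residual variance fixed at 1 for illustration

def implied_r2(P):
    # beta_p ~ N(0, 1) independently; prior explained variance is sum(beta_p^2)
    beta = rng.normal(size=(n_draws, P))
    tau2 = (beta ** 2).sum(axis=1)
    return (tau2 / (tau2 + sigma2)).mean()

for P in (1, 10, 100):
    print(P, round(implied_r2(P), 2))
```

With a single covariate the prior mean of \(R^2\) sits around 0.3, but by \(P = 100\) it is essentially 1: the independent priors have already decided the model will fit extremely well.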
Thus, increasing the number of covariates with independent coefficient priors implicitly assumes we are going to explain more of the variance. This is inconsistent with how we know the world works, as exemplified by the piranha problem, which essentially states that a system can only contain so many strong, independent effects.
In this post, I'm going to discuss a particularly appealing joint prior over model components, the \(R^2\) Decomposition (R2D) prior. In a nutshell, we place a prior directly on \(R^2\), derive the explained variance from it using Equation 1, and then decompose this across our various model components with a simplex.
R2D Prior
We start by placing a prior on \(R^2\), for which Aguilar and Bürkner (2023) suggest a Beta distribution. Defining R2 as a [0, 1] bounded parameter in Stan imposes a uniform prior, which will do for now. Assuming normal linear regression, rearranging Equation 1 gives \(\tau^2 = \frac{R^2}{1 - R^2} \sigma^2\) as the total variance explained by the model. We now wish to decompose this variance into the variances of our separate model components, where I'll start with a model that has \(P\) fixed effects, all assumed to be standardised with mean 0 and unit variance. We decompose \(\tau^2\) across our \(P\) predictors with a simplex \(\boldsymbol{\phi}\), which can be modeled with appropriate distributions. Originally this was done with the Dirichlet distribution (resulting in the R2D2 prior name, after Dirichlet Decomposition), but recent work suggests that the logit-normal may be preferred to capture more complex dependence structures (Aguilar and Bürkner 2025). I won't go into that here, and will just use a symmetric Dirichlet with concentration 0.5, which favours sparser partitions. The coefficients are then modeled as
\[ \beta_p \sim \mathcal{N} \left(0, \tau^2 \phi_p \right), \quad p = 1, \dots, P \tag{2}\]
Centered parameterisations like this will often run into sampling issues, so instead we adopt the non-centered parameterisation1, where
\[ \begin{aligned} \beta_p &= z_p \cdot \sqrt{\tau^2 \phi_p} \\ z_p &\sim \mathcal{N} \left(0, 1 \right). \end{aligned} \tag{3}\]
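As a quick sanity check on Equation 3 (a Python sketch with made-up values for \(\tau^2\) and \(\phi_p\), not part of the model code): rescaling a standard normal draw yields exactly the distribution Equation 2 specifies, so the two parameterisations differ only in how the sampler explores them.

```python
import numpy as np

rng = np.random.default_rng(2)
tau2, phi_p = 4.0, 0.25  # example values for one coefficient's variance share
n = 200_000

# centered: draw beta_p directly with variance tau2 * phi_p
centered = rng.normal(0.0, np.sqrt(tau2 * phi_p), size=n)

# non-centered: draw a standard normal z_p and rescale it
z = rng.normal(size=n)
non_centered = z * np.sqrt(tau2 * phi_p)

# both should have standard deviation sqrt(tau2 * phi_p) = 1 here
print(centered.std(), non_centered.std())
```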
Our probabilistic model is encoded in a Stan program (Carpenter et al. 2017).
Code
data {
int<lower=0> N, P; // number of observations and fixed effects
matrix[N, P] x; // covariates
vector[N] y; // observations
int<lower=0, upper=1> R2D; // R2D indicator
}
transformed data {
real y_mean = mean(y), y_sd = sd(y), inv_y_sd = inv(y_sd); // mean/SD priors
vector[P] shapes = rep_vector(0.5, P); // Dirichlet shapes
}
parameters {
real alpha; // intercept
real<lower=0> sigma; // residual SD
real<lower=0, upper=1> R2; // proportion explained variance
simplex[P] phi; // variance decomposition
vector[P] z; // coefficient z-scores
}
transformed parameters {
real tau2 = R2 / (1 - R2) * square(sigma); // explained variance
vector[P] scales = sqrt(tau2 * phi), // prior scales
beta = R2D ? scales .* z : z; // coefficients
}
model {
y ~ normal_id_glm(x, alpha, beta, sigma);
alpha ~ normal(y_mean, 1);
sigma ~ exponential(inv_y_sd);
phi ~ dirichlet(shapes);
z ~ std_normal();
}
generated quantities {
vector[N] log_lik;
for (n in 1:N) {
log_lik[n] = normal_lpdf(y[n] | alpha + dot_product(x[n], beta), sigma);
}
}

One of the applications where these priors shine is regression with large numbers of predictors, so let's try \(P = 2N = 200\), where only the first 5 coefficients are simulated as non-zero. I'll fit the model both with and without the R2D prior, where the latter case just uses standard normal priors for all of the coefficients.
Code
# simulate
N <- 100
P <- 200
alpha <- rnorm(1)
sigma <- rexp(1, 0.5)
R2 <- 0.8
tau2 <- R2 / (1 - R2) * sigma^2
u <- c(rgamma(5, 1, 1), rep(0, P - 5))
phi <- u / sum(u) # Dirichlet decomposition
scales <- sqrt(phi * tau2)
beta <- rnorm(P, 0, scales)
x <- matrix(rnorm(N * P), N, P)
y <- rnorm(N, alpha + x %*% beta, sigma)
# fit
stan_data <- list(N = N, P = P, x = x, y = y)
fits <- map(0:1, \(x) {
stan_data$R2D <- x
lm$sample(stan_data, chains = 8, iter_warmup = 200, iter_sampling = 500, refresh = 0)
})

Wow: the R2D model not only recovered the correct effects with better predictive performance, but also ran considerably faster than the default model. I'll show the estimates for the first 20 coefficients.
Code
# plot first 20
library(tidybayes)
map(fits, ~spread_rvars(., beta[p]) |>
mutate(truth = .env$beta)) |>
list_rbind(names_to = "R2D") |>
filter(p <= 20) |>
ggplot(aes(factor(p))) +
facet_wrap(~ factor(R2D, labels = c("N(0, 1)", "R2D")),
scales = "free_y", ncol = 1) +
geom_hline(yintercept = 0, linetype = "dashed", colour = green) +
stat_pointinterval(aes(ydist = beta),
point_interval = median_hdci, .width = 0.95,
size = 0.5, linewidth = 0.5) +
geom_point(aes(y = truth),
colour = green, position = position_nudge(x = -0.2)) +
labs(x = "Predictor", y = "Coefficient")

[[1]]
Computed from 4000 by 100 log-likelihood matrix.
Estimate SE
elpd_loo -36.7 1.2
p_loo 108.5 1.2
looic 73.5 2.5
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.0, 0.0]).
Pareto k diagnostic values:
Count Pct. Min. ESS
(-Inf, 0.7] (good) 0 0.0% <NA>
(0.7, 1] (bad) 94 94.0% <NA>
(1, Inf) (very bad) 6 6.0% <NA>
See help('pareto-k-diagnostic') for details.
[[2]]
Computed from 4000 by 100 log-likelihood matrix.
Estimate SE
elpd_loo 127.0 6.8
p_loo 56.4 5.0
looic -253.9 13.5
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.3, 0.9]).
Pareto k diagnostic values:
Count Pct. Min. ESS
(-Inf, 0.7] (good) 64 64.0% 147
(0.7, 1] (bad) 32 32.0% <NA>
(1, Inf) (very bad) 4 4.0% <NA>
See help('pareto-k-diagnostic') for details.
elpd_diff se_diff
model2 0.0 0.0
model1 -163.7 7.0
To see the variance decomposition in action, we can also look at the variance partitions of the coefficients, again just shown for the first 20:
Code
fits[[2]] |>
spread_rvars(phi[p]) |>
mutate(truth = .env$phi) |>
filter(p <= 20) |>
ggplot(aes(factor(p), ydist = phi)) +
stat_pointinterval(point_interval = median_hdci, .width = 0.95,
size = 0.5, linewidth = 0.5) +
geom_point(aes(y = truth), colour = green, position = position_nudge(x = -0.2)) +
scale_y_continuous(expand = c(0, 0)) +
labs(x = "Predictor", y = "Variance Partition")

How about random effects?
Random effects are usually modeled as normally distributed, \(\boldsymbol{\epsilon} \sim \mathcal{N} \left( 0, \lambda \right)\), where the scale \(\lambda\) is estimated by the model. This is just another variance component we can fold into the R2D prior: each random effect receives one partition of the explained variance, just as each fixed effect does. If we have \(R\) random effects with \(L_r\) levels each, our variance partition vector now has \(P + R\) elements, with \(\sum_{r=1}^R L_r\) coefficients corresponding to the realisations of all of the random effects. In Stan, we would adjust our model as follows.
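The only fiddly part is indexing the ragged stacked vector of random-effect realisations. Here is a Python sketch of the same bookkeeping the Stan transformed data block performs (the level counts are made up for illustration): the start of each effect's block is the cumulative sum of the level counts, shifted back to the block's first element.

```python
import numpy as np

L = np.array([3, 5, 2])  # levels per random effect (made-up example)

# 1-based start index of each effect's block in the stacked z-vector,
# mirroring cumulative_sum(L) followed by L_idx[i] -= L[i] - 1 in Stan
starts = np.cumsum(L) - (L - 1)
print(starts)  # each effect's segment begins here

# 0-based slices then recover each effect's block of z-scores
z = np.arange(L.sum())
blocks = [z[s - 1 : s - 1 + l] for s, l in zip(starts, L)]
print([b.tolist() for b in blocks])
```

Each block is then rescaled by its own \(\sqrt{\tau^2 \phi_{P+r}}\), exactly as the fixed effect coefficients are.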
Code
data {
int<lower=0> N, P, R; // number of observations, fixed and random effects
array[R] int<lower=1> L; // random effect levels
matrix[N, P] x; // covariates
array[N, R] int<lower=1> r; // random levels
vector[N] y; // observations
}
transformed data {
int L_sum = sum(L); // total random effect realisations
array[R] int L_idx = cumulative_sum(L); // for indexing
for (i in 1:R) {
L_idx[i] -= L[i] - 1;
}
real y_mean = mean(y), y_sd = sd(y), inv_y_sd = inv(y_sd);
vector[P + R] shapes = rep_vector(0.5, P + R); // Dirichlet shapes
}
parameters {
real alpha; // intercept
real<lower=0> sigma; // residual SD
real<lower=0, upper=1> R2; // proportion explained variance
simplex[P + R] phi; // variance decomposition
vector[P] beta_z; // coefficient z-scores
vector[L_sum] epsilon_z; // random effect z-scores
}
transformed parameters {
real tau2 = R2 / (1 - R2) * square(sigma); // explained variance
vector[P + R] scales = sqrt(tau2 * phi); // prior scales
vector[P] beta = scales[:P] .* beta_z; // fixed effect coefficients
matrix[max(L), R] epsilon; // random effects
for (i in 1:R) {
epsilon[:L[i], i] = scales[P + i] .* segment(epsilon_z, L_idx[i], L[i]);
}
vector[N] mu = alpha + x * beta;
for (i in 1:R) {
mu += epsilon[r[:, i], i];
}
}
model {
y ~ normal(mu, sigma);
alpha ~ normal(y_mean, 1);
sigma ~ exponential(inv_y_sd);
R2 ~ beta_proportion(0.8, 0.1);
phi ~ dirichlet(shapes);
beta_z ~ std_normal();
epsilon_z ~ std_normal();
}
generated quantities {
vector[N] log_lik;
for (n in 1:N) {
log_lik[n] = normal_lpdf(y[n] | mu[n], sigma);
}
}

Let's do another simulation that also has random effects, first with \(P = 10\) and \(R = 4\), while keeping \(R^2\) the same. Scales and variance partitions are recovered pretty well.
Code
# create more non-zero scales and new coefficients
P <- 10
R <- 4
u <- c(rgamma(5, 1, 1), rep(0, P - 5), rgamma(R, 1, 1))
phi <- u / sum(u)
scales <- sqrt(phi * tau2)
beta <- rnorm(P, 0, scales[1:P])
x <- matrix(rnorm(N * P), N, P)
# random effects
L <- rpois(R, 10)
epsilon <- map2(L, scales[(P + 1):(P + R)], ~rnorm(.x, 0, .y))
r <- map(L, ~sample(., N, replace = T))
re_effects <- map2(epsilon, r, ~.x[.y]) |> simplify2array()
# sample
y <- rnorm(N, alpha + x %*% beta + rowSums(re_effects), sigma)
# fit
# lmm <- cmdstan_model(here("blog/lmm.stan"))
fit <- lmm$sample(list(N = N, P = P, R = R, L = L, x = x, r = simplify2array(r), y = y),
chains = 8, iter_warmup = 200, iter_sampling = 500, refresh = 0)

Running MCMC with 8 parallel chains...
Chain 1 finished in 1.2 seconds.
Chain 2 finished in 1.1 seconds.
Chain 3 finished in 1.1 seconds.
Chain 4 finished in 1.1 seconds.
Chain 7 finished in 1.1 seconds.
Chain 6 finished in 1.2 seconds.
Chain 8 finished in 1.1 seconds.
Chain 5 finished in 2.9 seconds.
All 8 chains finished successfully.
Mean chain execution time: 1.3 seconds.
Total execution time: 3.1 seconds.
Code
fit |>
gather_rvars(scales[v], phi[v]) |>
mutate(truth = c(scales, .env$phi),
type = if_else(v > P, "Random", "Fixed")) |>
ggplot(aes(factor(v), ydist = .value)) +
facet_grid(factor(.variable, labels = c("phi", "lambda")) |> fct_rev() ~ type,
scales = "free", labeller = label_parsed) +
stat_pointinterval(point_interval = median_hdci, .width = 0.95,
size = 0.5, linewidth = 0.5) +
geom_point(aes(y = truth), colour = green, position = position_nudge(x = -0.2)) +
labs(x = "Component", y = "Posterior")

Let's try it again with \(P = 50\). Here we start running into trouble. I'm only plotting the first 10 coefficients.
Code
# create more non-zero scales and new coefficients
P <- 50
R <- 4
u <- c(rgamma(5, 1, 1), rep(0, P - 5), rgamma(R, 1, 1))
phi <- u / sum(u)
scales <- sqrt(phi * tau2)
beta <- rnorm(P, 0, scales[1:P])
x <- matrix(rnorm(N * P), N, P)
# random effects
L <- rpois(R, 10)
epsilon <- map2(L, scales[(P + 1):(P + R)], ~rnorm(.x, 0, .y))
r <- map(L, ~sample(., N, replace = T))
re_effects <- map2(epsilon, r, ~.x[.y]) |> simplify2array()
# sample
y <- rnorm(N, alpha + x %*% beta + rowSums(re_effects), sigma)
# fit
fit <- lmm$sample(list(N = N, P = P, R = R, L = L, x = x, r = simplify2array(r), y = y),
chains = 8, iter_warmup = 1000, iter_sampling = 1000, refresh = 0)

Running MCMC with 8 parallel chains...
Chain 1 finished in 2.0 seconds.
Chain 3 finished in 1.9 seconds.
Chain 6 finished in 1.9 seconds.
Chain 7 finished in 2.0 seconds.
Chain 2 finished in 2.0 seconds.
Chain 8 finished in 2.0 seconds.
Chain 4 finished in 2.1 seconds.
Chain 5 finished in 2.2 seconds.
All 8 chains finished successfully.
Mean chain execution time: 2.0 seconds.
Total execution time: 2.3 seconds.
Code
fit |>
gather_rvars(scales[v], phi[v]) |>
mutate(truth = c(scales, .env$phi),
type = if_else(v > P, "Random", "Fixed")) |>
ggplot(aes(factor(v), ydist = .value)) +
facet_grid(factor(.variable, labels = c("phi", "lambda")) |> fct_rev() ~ type,
scales = "free", labeller = label_parsed) +
stat_pointinterval(point_interval = median_hdci, .width = 0.95,
size = 0.5, linewidth = 0.5) +
geom_point(aes(y = truth), colour = green, position = position_nudge(x = -0.2)) +
labs(x = "Component", y = "Posterior")

References
Footnotes
Gaussian random variables \(y \sim \mathcal{N}\left(\mu, \sigma \right)\) can always be written as \(y = \mu + z \cdot \sigma\) where \(z \sim \mathcal{N} \left(0, 1 \right)\), which usually samples better unless there is lots of data.↩︎