Decomposing \(R^2\) with principled priors

Published July 18, 2025

Introduction

Statistics is largely concerned with explaining variance in observed data. We may wish to estimate differences in outcomes between groups receiving different treatments, or characterise residual spatial clustering after accounting for covariates. Whatever the approach, we try to capture as much structure in the data as possible, ideally reducing the noise or error term with every useful model component we add. A popular and intuitive summary is \(R^2\), the proportion of variance in the data explained by the model. Writing \(\tau^2\) for the explained variance and \(\sigma^2\) for the residual variance not explained by the model,

\[ R^2 = \frac{\tau^2}{\tau^2 + \sigma^2} \tag{1}\]

It is perhaps surprising that priors were not placed on \(R^2\) directly until recently (Zhang et al. 2022, Aguilar and Bürkner 2023). Typically, independent priors are specified for every term in the model, e.g. the intercept, fixed effect coefficients, and random effect and residual scales. However, since \(R^2\) is a deterministic function of these model components, any combination of priors on the parameters implies a prior on \(R^2\). Consider a regression model where the response and all \(P\) covariates are standardised to mean 0 and variance 1, with the covariates independent, and with "weakly informative" \(\mathcal{N}\left(0, 1\right)\) priors on the coefficients and a half-normal \(\mathcal{N}^+\left(0, 1\right)\) prior on the residual standard deviation. Each coefficient then contributes an expected variance of 1, so the expected explained variance is \(P\) while the expected residual variance is 1, and the prior mean of \(R^2\) is \(P / (P + 1)\), shown below.

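library(tidyverse)
green <- "forestgreen"  # accent colour for the plots below (assumed; the post's own palette is not shown)

tibble(P = 0:22, 
       `R^2` = P / (P + 1)) |>  # prior mean of R^2 under independent N(0, 1) coefficient priors
  ggplot(aes(P, `R^2`)) + 
  geom_line() + 
  scale_x_continuous(breaks = seq(5, 20, 5), expand = c(0, 0)) + 
  scale_y_continuous(breaks = seq(0.2, 0.8, 0.2), limits = c(0, 1), expand = c(0, 0)) + 
  labs(y = expression(R^2))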
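The curve above is an analytic expectation. A quick prior-predictive simulation confirms it. This is a minimal sketch that assumes independent unit-variance covariates, so the explained variance is simply the sum of squared coefficients, and a half-normal(0, 1) prior on the residual SD. `implied_r2()` is a hypothetical helper, not part of any package.

```r
# Prior-predictive sketch: implied R^2 when each coefficient has an independent N(0, 1) prior.
# Assumptions: independent unit-variance covariates, so the explained variance is sum(beta^2),
# and a half-normal(0, 1) prior on the residual SD.
implied_r2 <- function(P, n_draws = 4000) {
  tau2 <- vapply(seq_len(n_draws), function(i) sum(rnorm(P)^2), numeric(1))
  sigma <- abs(rnorm(n_draws))  # half-normal(0, 1) residual SD
  tau2 / (tau2 + sigma^2)       # Equation 1, draw by draw
}

set.seed(1)
Ps <- c(1, 5, 20)
mc_mean <- sapply(Ps, function(P) mean(implied_r2(P)))
round(cbind(P = Ps, monte_carlo = mc_mean, analytic = Ps / (Ps + 1)), 3)
```

Under these assumptions the explained variance is \(\chi^2_P\) and the residual variance is \(\chi^2_1\). The implied \(R^2\) therefore follows a Beta(\(P/2\), \(1/2\)) distribution, whose mean is exactly \(P/(P+1)\).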

Thus, increasing the number of covariates in the model with independent priors for coefficients implicitly assumes we are going to explain more of the variance. This is inconsistent with how we know the world works, as exemplified by the piranha problem, which essentially states we can only have so many strong, independent effects in a system.

In this post, I'm going to discuss a particularly appealing joint prior over model components, the \(R^2\) Decomposition (R2D) prior. In a nutshell, we place a prior on \(R^2\), derive the explained variance from it using Equation 1, and then decompose this variance across our model components using a simplex.

R2D Prior

We start by placing a prior on \(R^2\), for which Aguilar and Bürkner (2023) suggest a Beta distribution. Declaring R2 as a parameter bounded to [0, 1] in Stan, without a sampling statement, implies a uniform prior, which will do for now. Assuming we are using normal linear regression, rearranging Equation 1 gives the total variance explained by the model, \(\tau^2 = \frac{R^2}{1 - R^2} \sigma^2\). We now wish to decompose this variance into the variances of the separate model components. I'll start with a model that has \(P\) fixed effects, all of which are assumed to be standardised with mean 0 and unit variance. We decompose \(\tau^2\) across the \(P\) predictors with a simplex \(\boldsymbol{\phi}\), whose elements are non-negative and sum to 1, and which can be given any suitable distribution on the simplex. Originally this was the Dirichlet distribution (hence the name R2D2, for \(R^2\)-induced Dirichlet Decomposition), but recent work suggests that the logit-normal may be preferred because it can capture more complex dependence structures (Aguilar and Bürkner 2025). I won't go into that here. The Stan programs below use a symmetric Dirichlet prior with all concentration parameters set to 0.5, which favours sparser decompositions than the uniform prior Stan implies for a simplex declared without an explicit prior. The coefficients are then given normal priors with variance \(\tau^2 \phi_p\):

\[ \beta_p \sim \mathcal{N} \left(0, \tau^2 \phi_p \right), \quad p = 1, \dots, P \tag{2}\]

Centered parameterisations like this often run into sampling difficulties, particularly when the data are only weakly informative about the coefficients, so instead we adopt the non-centered parameterisation (see footnote 1), where

\[ \begin{aligned} \beta_p &= z_p \cdot \sqrt{\tau^2 \phi_p} \\ z_p &\sim \mathcal{N} \left(0, 1 \right). \end{aligned} \tag{3}\]
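To see why this keeps the explained variance in check, here is a minimal prior simulation of Equations 1 to 3. It assumes a uniform prior on \(R^2\), a symmetric Dirichlet(0.5) on \(\boldsymbol{\phi}\) drawn via normalised gamma variates, a residual SD fixed at 1, and independent unit-variance covariates. `r2d_draw()` is a hypothetical helper. Because the \(\phi_p\) sum to one, the realised explained variance averages out to \(\tau^2\) whether there are 5 or 200 predictors.

```r
# Minimal prior sketch of the R2D prior (hypothetical helper, not the post's code).
# Assumptions: R2 ~ Uniform(0, 1), phi ~ symmetric Dirichlet(0.5), residual SD fixed at 1.
r2d_draw <- function(P, sigma = 1, shape = 0.5) {
  R2 <- runif(1)
  tau2 <- R2 / (1 - R2) * sigma^2   # explained variance from Equation 1
  g <- rgamma(P, shape = shape)     # Dirichlet draw via normalised gamma variates
  phi <- g / sum(g)
  z <- rnorm(P)                     # non-centered z-scores
  beta <- z * sqrt(tau2 * phi)      # Equation 3
  # sum(beta^2) is the explained variance for independent, unit-variance covariates
  c(tau2 = tau2, explained = sum(beta^2))
}

set.seed(2)
ratio <- sapply(c(5, 200), function(P) {
  draws <- replicate(4000, r2d_draw(P))
  mean(draws["explained", ] / draws["tau2", ])  # E[sum(beta^2) / tau2] = sum(phi) = 1
})
ratio
```

Contrast this with independent \(\mathcal{N}(0, 1)\) priors, where the expected explained variance grows linearly with \(P\).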

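Our probabilistic model is encoded in the Stan program below (Carpenter et al. 2017). The R2D data flag switches between the R2D prior (R2D = 1) and independent standard normal priors on the coefficients (R2D = 0).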

data {
  int<lower=0> N, P;  // number of observations and fixed effects
  matrix[N, P] x;  // covariates
  vector[N] y;  // observations
  int<lower=0, upper=1> R2D;  // R2D indicator
}

transformed data {
  real y_mean = mean(y), y_sd = sd(y), inv_y_sd = inv(y_sd);  // mean/SD priors
  vector[P] shapes = rep_vector(0.5, P);  // Dirichlet shapes
}

parameters {
  real alpha;  // intercept
  real<lower=0> sigma;  // residual SD
  real<lower=0, upper=1> R2;  // proportion explained variance
  simplex[P] phi;  // variance decomposition
  vector[P] z;  // coefficient z-scores
}

transformed parameters {
  real tau2 = R2 / (1 - R2) * square(sigma);  // explained variance
  vector[P] scales = sqrt(tau2 * phi),  // prior scales
            beta = R2D ? scales .* z : z;  // coefficients
}

model {
  y ~ normal_id_glm(x, alpha, beta, sigma);
  alpha ~ normal(y_mean, 1);
  sigma ~ exponential(inv_y_sd);
  phi ~ dirichlet(shapes);
  z ~ std_normal();
}

generated quantities {
  vector[N] log_lik;  // pointwise log-likelihood, used by loo()
  for (n in 1:N) {
    log_lik[n] = normal_lpdf(y[n] | alpha + dot_product(x[n], beta), sigma);
  }
}

One of the applications where these priors shine is with large numbers of predictors. Let's try this with \(P = 2N = 200\), where only the first 5 coefficients are simulated as non-zero. I'll fit the model both with and without the R2D prior, where the latter case just has standard normal priors on all of the coefficients.

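library(cmdstanr)
library(tidyverse)

# `lm` is assumed to be the CmdStanModel compiled from the Stan program above
# (e.g. with cmdstan_model()); note that this name masks stats::lm()

# simulate
N <- 100
P <- 200
alpha <- rnorm(1)
sigma <- rexp(1, 0.5)
R2 <- 0.8
tau2 <- R2 / (1 - R2) * sigma^2
u <- c(rgamma(5, 1, 1), rep(0, P - 5))  # only the first 5 predictors get non-zero variance
phi <- u / sum(u)  # Dirichlet decomposition
scales <- sqrt(phi * tau2)
beta <- rnorm(P, 0, scales)
x <- matrix(rnorm(N * P), N, P)
y <- rnorm(N, alpha + x %*% beta, sigma)

# fit without (R2D = 0) and with (R2D = 1) the R2D prior
stan_data <- list(N = N, P = P, x = x, y = y)
fits <- map(0:1, \(r2d) {
  stan_data$R2D <- r2d
  lm$sample(stan_data, chains = 8, iter_warmup = 200, iter_sampling = 500, refresh = 0)
})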
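The target \(R^2 = 0.8\) fixes the prior variances, not the properties of any single dataset. With five randomly drawn non-zero coefficients and a finite sample, the realised proportion of explained variance varies between simulations. The check below is a sketch that repeats the simulation design above under an arbitrary seed, so it does not reproduce the post's data. It computes Equation 1 from the in-sample variance of the true linear predictor.

```r
# Sanity check on the simulation design; the seed is arbitrary, so the numbers
# will not match the post's data. local() keeps these objects out of the session.
set.seed(3)
realised_R2 <- local({
  N <- 100
  P <- 200
  sigma <- rexp(1, 0.5)
  tau2 <- 0.8 / (1 - 0.8) * sigma^2       # target R^2 of 0.8
  u <- c(rgamma(5, 1, 1), rep(0, P - 5))
  beta <- rnorm(P, 0, sqrt(u / sum(u) * tau2))
  x <- matrix(rnorm(N * P), N, P)
  eta <- drop(x %*% beta)                  # true linear predictor; the intercept adds no variance
  var(eta) / (var(eta) + sigma^2)          # Equation 1 with the in-sample variance of eta
})
realised_R2
```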

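Wow, the R2D model not only recovered the correct effects with better estimated predictive performance, but also ran considerably faster than the default model. (Many Pareto \(k\) values exceed 0.7 for both models, though, so the PSIS-LOO estimates below are unreliable and the comparison is indicative at best.) I'll show the estimates for the first 20 coefficients.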

# plot first 20
library(tidybayes)
map(fits, ~spread_rvars(., beta[p]) |> 
      mutate(truth = .env$beta)) |> 
  list_rbind(names_to = "R2D") |> 
  filter(p <= 20) |>
  ggplot(aes(factor(p))) + 
  facet_wrap(~ factor(R2D, labels = c("N(0, 1)", "R2D")), 
             scales = "free_y", ncol = 1) + 
  geom_hline(yintercept = 0, linetype = "dashed", colour = green) +
  stat_pointinterval(aes(ydist = beta), 
                     point_interval = median_hdci, .width = 0.95, 
                     size = 0.5, linewidth = 0.5) + 
  geom_point(aes(y = truth),
             colour = green, position = position_nudge(x = -0.2)) + 
  labs(x = "Predictor", y = "Coefficient")

# loo
library(loo)
loos <- map(fits, ~.$loo())
loos
[[1]]

Computed from 4000 by 100 log-likelihood matrix.

         Estimate  SE
elpd_loo    -36.7 1.2
p_loo       108.5 1.2
looic        73.5 2.5
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.0, 0.0]).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)      0     0.0%   <NA>    
   (0.7, 1]   (bad)      94    94.0%   <NA>    
   (1, Inf)   (very bad)  6     6.0%   <NA>    
See help('pareto-k-diagnostic') for details.

[[2]]

Computed from 4000 by 100 log-likelihood matrix.

         Estimate   SE
elpd_loo    127.0  6.8
p_loo        56.4  5.0
looic      -253.9 13.5
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.3, 0.9]).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)     64    64.0%   147     
   (0.7, 1]   (bad)      32    32.0%   <NA>    
   (1, Inf)   (very bad)  4     4.0%   <NA>    
See help('pareto-k-diagnostic') for details.
loo_compare(loos)
       elpd_diff se_diff
model2    0.0       0.0 
model1 -163.7       7.0 

To see the variance decomposition in action, we can also look at the variance partitions of the coefficients, again just shown for the first 20:

fits[[2]] |> 
  spread_rvars(phi[p]) |> 
  mutate(truth = .env$phi) |> 
  filter(p <= 20) |> 
  ggplot(aes(factor(p), ydist = phi)) + 
  stat_pointinterval(point_interval = median_hdci, .width = 0.95, 
                     size = 0.5, linewidth = 0.5) + 
  geom_point(aes(y = truth), colour = green, position = position_nudge(x = -0.2)) + 
  scale_y_continuous(expand = c(0, 0)) + 
  labs(x = "Predictor", y = "Variance Partition")

How about random effects?

Random effects are usually modeled as normally distributed, \(\boldsymbol{\epsilon} \sim \mathcal{N} \left( 0, \lambda \right)\), where the scale \(\lambda\) is estimated by the model. This is just another variance component we can add to the R2D prior. Just as each fixed effect receives a share of the explained variance, so does each random effect, and that effect's levels are realisations drawn with its share as their variance. If we have \(R\) random effects with \(L_r\) levels each, the variance partition vector now has \(P + R\) elements. The model then has \(P\) fixed effect coefficients plus \(\sum_{r=1}^R L_r\) random effect coefficients, one for each level of each random effect. In Stan, we would adjust our model as follows, this time also replacing the implicit uniform prior on \(R^2\) with a Beta prior with mean 0.8.

data {
  int<lower=0> N, P, R;  // number of observations, fixed and random effects
  array[R] int<lower=1> L;  // random effect levels
  matrix[N, P] x;  // covariates
  array[N, R] int<lower=1> r;  // random levels
  vector[N] y;  // observations
}

transformed data {
  int L_sum = sum(L);  // total random effect realisations
  array[R] int L_idx = cumulative_sum(L);  // for indexing
  for (i in 1:R) {
    L_idx[i] -= L[i] - 1;
  }
  real y_mean = mean(y), y_sd = sd(y), inv_y_sd = inv(y_sd);
  vector[P + R] shapes = rep_vector(0.5, P + R);  // Dirichlet shapes
}

parameters {
  real alpha;  // intercept
  real<lower=0> sigma;  // residual SD
  real<lower=0, upper=1> R2;  // proportion explained variance
  simplex[P + R] phi;  // variance decomposition
  vector[P] beta_z;  // coefficient z-scores
  vector[L_sum] epsilon_z;  // random effect z-scores
}

transformed parameters {
  real tau2 = R2 / (1 - R2) * square(sigma);  // explained variance
  vector[P + R] scales = sqrt(tau2 * phi);  // prior scales
  vector[P] beta = scales[:P] .* beta_z;  // fixed effect coefficients
  matrix[max(L), R] epsilon;  // random effects
  for (i in 1:R) {
    epsilon[:L[i], i] = scales[P + i] * segment(epsilon_z, L_idx[i], L[i]);  // scale this effect's z-scores
  }
  vector[N] mu = alpha + x * beta;
  for (i in 1:R) {
    mu += epsilon[r[:, i], i];
  }
}

model {
  y ~ normal(mu, sigma);
  alpha ~ normal(y_mean, 1);
  sigma ~ exponential(inv_y_sd);
  R2 ~ beta_proportion(0.8, 0.1);  // mean 0.8, concentration 0.1, i.e. Beta(0.08, 0.02): very diffuse and U-shaped
  phi ~ dirichlet(shapes);
  beta_z ~ std_normal();
  epsilon_z ~ std_normal();
}

generated quantities {
  vector[N] log_lik;
  for (n in 1:N) {
    log_lik[n] = normal_lpdf(y[n] | mu[n], sigma);
  }
}
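The transformed data block converts the level counts L into the start position of each random effect inside the flat epsilon_z vector. That way segment(epsilon_z, L_idx[i], L[i]) returns that effect's z-scores. Here is the same index arithmetic in R, as a sketch with made-up level counts. L_demo and the stand-in epsilon_z are hypothetical.

```r
# Hypothetical level counts for R = 3 random effects
L_demo <- c(3, 5, 2)
# start positions, mirroring cumulative_sum(L) followed by L_idx[i] -= L[i] - 1 in Stan
L_idx <- cumsum(L_demo) - (L_demo - 1)
# stand-in for the flat vector of random effect z-scores (length sum(L))
epsilon_z <- seq_len(sum(L_demo))
# R equivalent of Stan's segment(epsilon_z, L_idx[i], L[i]) for each random effect
segments <- lapply(seq_along(L_demo), function(i) epsilon_z[L_idx[i] + seq_len(L_demo[i]) - 1])
L_idx  # 1 4 9
```

Each element of segments holds one random effect's z-scores, and together they cover the flat vector exactly once.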

Let’s do another simulation that also has random effects, first with \(P = 10\) and \(R = 4\), while keeping \(R^2\) the same. Scales and variance partitions are recovered pretty well.

# create more non-zero scales and new coefficients
P <- 10
R <- 4
u <- c(rgamma(5, 1, 1), rep(0, P - 5), rgamma(R, 1, 1))
phi <- u / sum(u)
scales <- sqrt(phi * tau2)
beta <- rnorm(P, 0, scales[1:P])
x <- matrix(rnorm(N * P), N, P)

# random effects
L <- rpois(R, 10)
epsilon <- map2(L, scales[(P + 1):(P + R)], ~rnorm(.x, 0, .y))
r <- map(L, ~sample(., N, replace = TRUE))  # level of each random effect for every observation
re_effects <- map2(epsilon, r, ~.x[.y]) |> simplify2array()

# sample
y <- rnorm(N, alpha + x %*% beta + rowSums(re_effects), sigma)

# fit
# lmm <- cmdstan_model(here("blog/lmm.stan"))
fit <- lmm$sample(list(N = N, P = P, R = R, L = L, x = x, r = simplify2array(r), y = y), 
                  chains = 8, iter_warmup = 200, iter_sampling = 500, refresh = 0)
Running MCMC with 8 parallel chains...
Chain 1 finished in 1.2 seconds.
Chain 2 finished in 1.1 seconds.
Chain 3 finished in 1.1 seconds.
Chain 4 finished in 1.1 seconds.
Chain 7 finished in 1.1 seconds.
Chain 6 finished in 1.2 seconds.
Chain 8 finished in 1.1 seconds.
Chain 5 finished in 2.9 seconds.

All 8 chains finished successfully.
Mean chain execution time: 1.3 seconds.
Total execution time: 3.1 seconds.
fit |> 
  gather_rvars(scales[v], phi[v]) |> 
  mutate(truth = c(.env$scales, .env$phi),
         type = if_else(v > P, "Random", "Fixed")) |> 
  ggplot(aes(factor(v), ydist = .value)) + 
  facet_grid(factor(.variable, labels = c("phi", "lambda")) |> fct_rev() ~ type, 
             scales = "free", labeller = label_parsed) + 
  stat_pointinterval(point_interval = median_hdci, .width = 0.95, 
                     size = 0.5, linewidth = 0.5) + 
  geom_point(aes(y = truth), colour = green, position = position_nudge(x = -0.2)) +  
  labs(x = "Component", y = "Posterior")

Let’s try it again with \(P = 50\). Here we start running into trouble. I’m only plotting the first 10 coefficients.

# create more non-zero scales and new coefficients
P <- 50
R <- 4
u <- c(rgamma(5, 1, 1), rep(0, P - 5), rgamma(R, 1, 1))
phi <- u / sum(u)
scales <- sqrt(phi * tau2)
beta <- rnorm(P, 0, scales[1:P])
x <- matrix(rnorm(N * P), N, P)

# random effects
L <- rpois(R, 10)
epsilon <- map2(L, scales[(P + 1):(P + R)], ~rnorm(.x, 0, .y))
r <- map(L, ~sample(., N, replace = TRUE))  # level of each random effect for every observation
re_effects <- map2(epsilon, r, ~.x[.y]) |> simplify2array()

# sample
y <- rnorm(N, alpha + x %*% beta + rowSums(re_effects), sigma)

# fit
fit <- lmm$sample(list(N = N, P = P, R = R, L = L, x = x, r = simplify2array(r), y = y), 
                  chains = 8, iter_warmup = 1000, iter_sampling = 1000, refresh = 0)
Running MCMC with 8 parallel chains...
Chain 1 finished in 2.0 seconds.
Chain 3 finished in 1.9 seconds.
Chain 6 finished in 1.9 seconds.
Chain 7 finished in 2.0 seconds.
Chain 2 finished in 2.0 seconds.
Chain 8 finished in 2.0 seconds.
Chain 4 finished in 2.1 seconds.
Chain 5 finished in 2.2 seconds.

All 8 chains finished successfully.
Mean chain execution time: 2.0 seconds.
Total execution time: 2.3 seconds.
fit |> 
  gather_rvars(scales[v], phi[v]) |> 
  mutate(truth = c(.env$scales, .env$phi),
         type = if_else(v > P, "Random", "Fixed")) |> 
  filter(v <= 10 | v > P) |>  # first 10 fixed effects plus all random effects
  ggplot(aes(factor(v), ydist = .value)) + 
  facet_grid(factor(.variable, labels = c("phi", "lambda")) |> fct_rev() ~ type, 
             scales = "free", labeller = label_parsed) + 
  stat_pointinterval(point_interval = median_hdci, .width = 0.95, 
                     size = 0.5, linewidth = 0.5) + 
  geom_point(aes(y = truth), colour = green, position = position_nudge(x = -0.2)) +  
  labs(x = "Component", y = "Posterior")

References

Aguilar, J. E., and P.-C. Bürkner. 2023. Intuitive joint priors for Bayesian linear multilevel models: The R2D2M2 prior. Electronic Journal of Statistics 17:1711–1767.
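Aguilar, J. E., and P.-C. Bürkner. 2025. Generalized decomposition priors on R2. arXiv preprint (February 2025).
Carpenter, B., A. Gelman, M. D. Hoffman, D. Lee, B. Goodrich, M. Betancourt, M. Brubaker, J. Guo, P. Li, and A. Riddell. 2017. Stan: A probabilistic programming language. Journal of Statistical Software 76:1–32.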
Zhang, Y. D., B. P. Naughton, H. D. Bondell, and B. J. Reich. 2022. Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association 117:862–874.

Footnotes

  1. Gaussian random variables \(y \sim \mathcal{N}\left(\mu, \sigma \right)\), with \(\sigma\) the standard deviation, can always be written as \(y = \mu + z \cdot \sigma\) where \(z \sim \mathcal{N} \left(0, 1 \right)\). Sampling \(z\) rather than \(y\) usually works better, unless there is a lot of data.