Interpreting Models Has Never Been Easier
I recommend reading the journal article presenting the marginaleffects package by Arel-Bundock V, Greifer N, and others.
To demonstrate the kinds of questions that analysts and scientists need to ask of their models (and how easily they can be answered), in this post I'll use the marginaleffects package in R. (I also use the Quarto CLI to render Quarto markdown to regular markdown.)
```r
library(marginaleffects)
library(tidyverse)
```
To get some data to work with, I use the built-in mtcars dataset and these variables:
- mpg = miles per US gallon (1 mpg is about 0.425 km/l)
- cyl = number of cylinders
- hp = gross horsepower
- wt = weight (1000 lbs)
```r
data(mtcars)
d <- mtcars |>
  as_tibble() |>
  select(mpg, cyl, wt, hp) |>
  arrange(cyl) |>
  mutate(cyl = as.factor(cyl))
print(d)
```
Returns…
```
# A tibble: 32 × 4
     mpg cyl      wt    hp
   <dbl> <fct> <dbl> <dbl>
 1  22.8 4      2.32    93
 2  24.4 4      3.19    62
 3  22.8 4      3.15    95
 4  32.4 4      2.2     66
 5  30.4 4      1.62    52
 6  33.9 4      1.84    65
 7  21.5 4      2.46    97
 8  27.3 4      1.94    66
 9  26   4      2.14    91
10  30.4 4      1.51   113
# ℹ 22 more rows
```
Normally you would think very carefully about the data-generating mechanisms and construct a model of them, possibly iteratively, and often with a single question in mind. We will ignore all of this, since our focus is on a later step: model interpretation. Think of the following simple model as a stand-in for almost any model you might fit.
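```r
# Note: hp * wt already expands to hp + wt + hp:wt, so this formula
# is equivalent to mpg ~ cyl + hp * wt.
m <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d)
```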
The most common way to interpret models is simply to report the coefficients from the fitted model with their confidence intervals or p-values.
```r
summary(m)
```
Returns…
```
Call:
lm(formula = mpg ~ cyl + hp + wt + (hp * wt), data = d)

Residuals:
    Min      1Q  Median      3Q     Max 
-3.5309 -1.6451 -0.4154  1.3838  4.4788 

Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept) 47.337329   4.679790  10.115 1.67e-10 ***
cyl6        -1.259073   1.489594  -0.845 0.405685    
cyl8        -1.454339   2.063696  -0.705 0.487246    
hp          -0.103331   0.031907  -3.238 0.003274 ** 
wt          -7.306337   1.675258  -4.361 0.000181 ***
hp:wt        0.023951   0.008966   2.671 0.012865 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 2.203 on 26 degrees of freedom
Multiple R-squared:  0.888,	Adjusted R-squared:  0.8664 
F-statistic: 41.21 on 5 and 26 DF,  p-value: 1.503e-11
```
Maybe one of the coefficients has a causal interpretation and the others merely adjust for confounders. Either way, these coefficients are quite difficult to understand and don't directly answer the questions people usually care about. Some model parameters are not interpretable at all without some post-processing; for example, because of the hp:wt interaction, the hp coefficient is the slope of hp for a car with zero weight.
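As a small illustration of that post-processing, here is a base-R sketch that refits the same model on the built-in mtcars data (d_base and m_base are my own names; I convert cyl with factor() instead of the tidyverse pipeline so the snippet runs on its own). It contrasts the raw hp coefficient, the slope at wt = 0, with the slope at the average observed weight, which combines two coefficients.

```r
# Refit the same model with base R only (rows keep mtcars' original order).
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

b <- coef(m_base)

# The "hp" coefficient alone is the slope of hp when wt = 0.
slope_at_zero_wt <- unname(b["hp"])

# With the interaction, d(mpg)/d(hp) = b_hp + b_hp:wt * wt,
# so evaluate it at the average observed weight instead.
slope_at_mean_wt <- unname(b["hp"] + b["hp:wt"] * mean(d_base$wt))

print(c(at_zero_wt = slope_at_zero_wt, at_mean_wt = slope_at_mean_wt))
```

Using the coefficients in the summary above, this gives roughly -0.103 versus -0.026 mpg per horsepower: the number printed in the coefficient table is about four times steeper than the slope for a car of typical weight.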
1. Predictions
A model doesn’t make much sense if its predictions don’t make sense. We should look at predictions for different full and partial observations of interest.
To get what we want, we should define five things (as described by marginaleffects):
- Quantity – what to predict.
- Grid – inputs to predict for.
- Aggregation – whether and how to summarize predictions.
- Uncertainty – how to measure uncertainty of prediction.
- Test – what kind of hypothesis to test.
Read the example call of predictions() below, with some alternatives in comments. First, we get plain predictions for the observed data.
```r
predictions(
  model = m,
  # 1. Quantity
  # Predict the value of the response variable.
  type = "response",
  ## Or predict on the link (linear predictor) scale.
  # type = "link",
  ## Or predict class probabilities (for model types that support it).
  # type = "probs",
  # Don't transform predicted values.
  transform = NULL,
  # 2. Grid
  # Predict for the observed values in the data.
  newdata = NULL,
  ## Or use "grid" to predict for points spread across the data space.
  # newdata = "grid",
  ## Or use datagrid() to define custom grids of points, like
  ## counterfactual datasets with some values fixed for all units.
  # newdata = datagrid(cyl = c(4, 6, 8), grid_type = "counterfactual"),
  # Don't create counterfactual copies: one prediction per row of newdata.
  # (The variables argument is another interface for counterfactuals.)
  variables = NULL,
  ## Predict one value for every counterfactual alternative.
  # variables = "cyl",
  ## Predict values for a subset of the counterfactual alternatives.
  # variables = list("hp" = "minmax"),
  # 3. Aggregation
  # Don't aggregate estimates.
  by = FALSE,
  byfun = NULL,
  # No averaging weights.
  wts = FALSE,
  # 4. Uncertainty
  # Use the fitted model's covariance matrix for uncertainty.
  vcov = TRUE,
  # Confidence level of the intervals.
  conf_level = 0.90,
  # Use a normal distribution (Student's t with infinite degrees of freedom).
  df = Inf,
  # Use finite differences with forward steps for the numerical
  # derivatives behind the delta method standard errors.
  numderiv = "fdforward",
  # 5. Testing
  # Test whether the predicted value is 0 (the default null hypothesis).
  hypothesis = 0,
  ## Test whether the deviation from the mean is zero (predicts the deviations).
  # hypothesis = "meandev",
  # Don't test for equivalence, non-inferiority or non-superiority.
  equivalence = NULL,
  # No multiple testing adjustments.
  p_adjust = NULL
)
```
Returns…
```
 Estimate Std. Error     z Pr(>|z|)     S 5.0 % 95.0 %
     25.9      0.688 37.71   <0.001   Inf  24.8   27.1
     22.4      1.161 19.26   <0.001 272.2  20.5   24.3
     21.7      1.143 18.96   <0.001 263.8  19.8   23.6
     27.9      0.726 38.48   <0.001   Inf  26.7   29.1
     32.2      1.317 24.43   <0.001 435.6  30.0   34.3
--- 22 rows omitted. See ?avg_predictions and ?print.marginaleffects --- 
     17.6      0.968 18.20   <0.001 243.5  16.0   19.2
     15.0      0.719 20.91   <0.001 320.2  13.9   16.2
     15.8      0.732 21.63   <0.001 342.2  14.6   17.0
     15.5      1.064 14.55   <0.001 156.9  13.7   17.2
     13.8      1.513  9.14   <0.001  63.8  11.3   16.3
Columns: rowid, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, mpg, cyl, hp, wt 
Type:  response 
```
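For a plain lm, the point estimates above are ordinary fitted values, so they can be cross-checked with base R's predict(). This sketch is my own cross-check, not how marginaleffects works internally; it keeps mtcars' original row order rather than the cyl-sorted order of d, and predict.lm() builds its 90% intervals from a t distribution with the 26 residual degrees of freedom, whereas the call above used df = Inf, so the interval ends will differ slightly from the table.

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# One prediction per observed row, with 90% confidence intervals
# (t-based with the residual degrees of freedom, unlike df = Inf above).
p <- predict(m_base, newdata = d_base, interval = "confidence", level = 0.90)
print(head(p))
```

The Datsun 710, which is the first row of d, should get the same point estimate as the first row of the table (25.9).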
Then let's aggregate these predictions by the number of cylinders. I still write out all the arguments to make the five-concept framework clear.
```r
predictions(
  model = m,
  # 1. Quantity
  # Predict the response.
  type = "response",
  transform = NULL,
  # 2. Grid
  # Predict for the observed values.
  newdata = NULL,
  variables = NULL,
  # 3. Aggregation
  # Aggregate predictions by the cyl variable.
  by = "cyl",
  # The summary is the mean of predictions.
  byfun = mean,
  wts = FALSE,
  # 4. Uncertainty
  # Use the default delta method to get 90% confidence intervals
  # based on the model's covariance matrix.
  vcov = TRUE,
  conf_level = 0.90,
  numderiv = "fdforward",
  df = Inf,
  # 5. Testing
  hypothesis = 0,
  p_adjust = NULL,
  equivalence = NULL
)
```
Returns…
```
 cyl Estimate Std. Error    z Pr(>|z|)     S 5.0 % 95.0 %
   4     26.7      0.664 40.1   <0.001   Inf  25.6   27.8
   6     19.7      0.833 23.7   <0.001 410.5  18.4   21.1
   8     15.1      0.589 25.6   <0.001 479.6  14.1   16.1
Columns: cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
Type:  response 
```
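Here by = "cyl" with byfun = mean averages the unit-level predictions within each cyl group. A base-R sketch of the point estimates only (no standard errors; d_base and m_base are my own names): because the model contains an intercept and cyl dummies, the OLS residuals sum to zero within each cylinder group, so these averages coincide with the observed group means of mpg.

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# Average the fitted (predicted) values within each cylinder group.
avg_by_cyl <- sapply(split(fitted(m_base), d_base$cyl), mean)

# Observed group means of mpg, for comparison.
obs_by_cyl <- sapply(split(d_base$mpg, d_base$cyl), mean)
print(rbind(avg_by_cyl, obs_by_cyl))
```

Both rows should read about 26.7, 19.7 and 15.1, matching the table. In other words, this particular summary describes the observed groups; it says nothing about what cylinders do.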
Last, let's compute the same mean predictions for the different numbers of cylinders, but this time counterfactually. This means that cyl is set to 4, then 6, then 8 for every car, producing three counterfactual copies of the data; the predictions are then averaged within each value of cyl. (This makes sense only under some particular causal assumptions.) We also rely on some defaults to make the call shorter.
```r
predictions(
  model = m,
  # 1. Quantity
  type = "response",
  # 2. Grid
  newdata = datagrid(cyl = c(4, 6, 8), grid_type = "counterfactual"),
  # 3. Aggregation
  by = "cyl"
  # 4. Uncertainty.
  # 5. Testing.
)
```
Returns…
```
 cyl Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
   4     21.0      1.216 17.3   <0.001 219.8  18.6   23.4
   6     19.7      0.936 21.1   <0.001 325.9  17.9   21.6
   8     19.5      1.042 18.8   <0.001 258.6  17.5   21.6
Columns: cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
Type:  response 
```
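The counterfactual grid can be reproduced by hand for the point estimates: copy the data three times, overwrite cyl in each copy, predict, and average over all 32 cars. A base-R sketch without uncertainty (names are my own); keeping the factor levels of the original data is what lets predict() accept the overwritten column.

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# For each cylinder level: set cyl to that level for every car,
# predict, and average the predictions over all 32 cars.
cf_means <- sapply(levels(d_base$cyl), function(lvl) {
  cf <- d_base
  cf$cyl <- factor(rep(lvl, nrow(cf)), levels = levels(d_base$cyl))
  mean(predict(m_base, newdata = cf))
})
print(cf_means)
```

Because cyl enters this model additively, the differences between these counterfactual averages are exactly the cyl6 and cyl8 coefficients, unlike the observed group averages in the previous call.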
To avoid typing as much as possible, we could've used the variables argument to define the counterfactual grid and combined it with a variant of the function, avg_predictions(), which averages the predictions by default…
```r
avg_predictions(m, variables = "cyl")
```
Returns…
```
 cyl Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
   4     21.0      1.216 17.3   <0.001 219.8  18.6   23.4
   6     19.7      0.936 21.1   <0.001 325.9  17.9   21.6
   8     19.5      1.042 18.8   <0.001 258.6  17.5   21.6
Columns: cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
Type:  response 
```
2. Comparisons
The most important questions to ask of a model are usually about some comparison, difference, association, or effect. Comparisons are functions of the predictions.
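To make "functions of the predictions" concrete, here is a base-R sketch of a unit-level comparison: predict each observed car twice, once as observed (lo) and once with hp increased by 10 (hi), then take the difference or the ratio of the two predictions. It mirrors the idea behind comparison = "difference" and "ratio" but is not marginaleffects' own code; the +10 step and the names are my choices for illustration.

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# Two versions of the same data: observed (lo) and with hp + 10 (hi).
pred_lo <- predict(m_base, newdata = d_base)
pred_hi <- predict(m_base, newdata = transform(d_base, hp = hp + 10))

# A comparison is any function of the two predictions.
difference <- pred_hi - pred_lo
ratio <- pred_hi / pred_lo

# For this model the difference has a closed form: 10 * (b_hp + b_hp:wt * wt).
b <- coef(m_base)
closed_form <- 10 * (b[["hp"]] + b[["hp:wt"]] * d_base$wt)
print(head(cbind(difference, ratio, closed_form)))
```

The difference varies from car to car only through wt, which is exactly why the choice of grid (whose wt values?) matters once you summarize it.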
Once again, we need to define the five concepts and I’ll show how to do this explicitly first.
- Quantity – what kind of estimates?
- Grid – on which inputs are the estimates computed?
- Aggregation – are the individual estimates aggregated?
- Uncertainty – how is the uncertainty of the estimate measured?
- Testing – what kind of statement about the estimate is statistically tested?
Here it is crucial to pay attention to the grid, because it can be very confusing what kind of measure is actually being computed. Which units are included? Which values are counterfactual or synthetic, and which are observed?
```r
comparisons(
  model = m,
  # 1. Quantities
  # Predict responses for the comparisons.
  type = "response",
  # Use a difference as the measure of comparison.
  comparison = "difference",
  # Or comparison = "ratio",
  # Or comparison = function (hi, lo) { hi / lo },
  # Compare the reference level (4) to the others (6, 8), and a +10
  # increment in horsepower.
  # (Note again that this argument leads to counterfactual datasets.)
  variables = list(cyl = "reference", hp = 10),
  # Separate comparisons, not the joint change of both.
  cross = FALSE,
  # Don't transform estimates at the end.
  transform = NULL,
  # 2. Grid
  # Calculate the differences for one synthetic unit with numeric
  # variables at their means and factors at their modes.
  newdata = "mean",
  # 3. Aggregation
  # Don't aggregate estimates (just one row in newdata anyway).
  by = FALSE,
  wts = FALSE,
  # 4. Uncertainty
  vcov = TRUE,
  conf_level = 0.95,
  df = Inf,
  numderiv = "fdforward",
  # NEW: Step size for numerical derivatives (slope estimation).
  eps = NULL,
  # 5. Testing
  # Test whether the difference is zero.
  hypothesis = 0,
  equivalence = NULL,
  p_adjust = NULL
)
```
Returns…
```
 Term Contrast Estimate Std. Error      z Pr(>|z|)   S  2.5 % 97.5 % cyl  hp   wt
  cyl    6 - 4   -1.259      1.490 -0.845   0.3980 1.3 -4.179   1.66   8 147 3.22
  cyl    8 - 4   -1.454      2.064 -0.705   0.4810 1.1 -5.499   2.59   8 147 3.22
   hp      +10   -0.263      0.109 -2.421   0.0155 6.0 -0.475  -0.05   8 147 3.22
Columns: rowid, term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted, cyl, hp, wt, mpg 
Type:  response 
```
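The cyl, hp and wt columns in that output describe the single synthetic unit that newdata = "mean" produced: hp and wt are the sample means (about 147 and 3.22), while cyl, a factor, is set to its most frequent level, 8. A base-R sketch of building that row by hand and evaluating the +10 hp difference at it (names are my own):

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# Synthetic "typical" unit: means for numeric variables, mode for the factor.
cyl_mode <- names(which.max(table(d_base$cyl)))
typical <- data.frame(
  cyl = factor(cyl_mode, levels = levels(d_base$cyl)),
  hp = mean(d_base$hp),
  wt = mean(d_base$wt)
)

# +10 hp difference evaluated at the typical unit only.
diff_at_mean <- predict(m_base, transform(typical, hp = hp + 10)) -
  predict(m_base, typical)
print(typical)
print(diff_at_mean)
```

No car in the data has exactly these values, so this estimate describes a synthetic unit rather than any observed one.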
Let's imagine horsepower has a causal interpretation given our model. We could estimate the effect of a +20 increase in horsepower like this (leveraging some defaults again):
```r
comparisons(
  model = m,
  # 1. Quantity
  type = "response",
  comparison = "difference",
  variables = list(hp = 20),
  # 2. Grid
  # Estimates on the observed values.
  newdata = NULL,
  # 3. Aggregation
  # Average all unit-level estimates.
  by = TRUE,
  # 4. Uncertainty
  # 5. Testing
  hypothesis = 0
)
```
Returns…
```
 Term  Contrast Estimate Std. Error     z Pr(>|z|)   S  2.5 % 97.5 %
   hp mean(+20)   -0.525      0.217 -2.42   0.0155 6.0 -0.951   -0.1
Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
Type:  response 
```
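With newdata = NULL and by = TRUE, the estimate is the average of the unit-level +20 differences over the 32 observed cars. A base-R sketch of the point estimate only (the standard error needs the delta method and is omitted; names are my own):

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# Unit-level effects of +20 hp on the observed grid, then their average.
unit_diffs <- predict(m_base, transform(d_base, hp = hp + 20)) -
  predict(m_base, d_base)
avg_effect <- mean(unit_diffs)
print(avg_effect)
```

Because the unit-level difference is linear in wt here, averaging over the observed cars gives the same number as evaluating at the mean weight, which is why this is exactly twice the +10 estimate from the mean-grid call. With a nonlinear model the two grids would generally disagree.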
To save typing, the package also lets you write the same thing as just…
```r
avg_comparisons(m, variables = list(hp = 20))
```
Returns…
```
 Term  Contrast Estimate Std. Error     z Pr(>|z|)   S  2.5 % 97.5 %
   hp mean(+20)   -0.525      0.217 -2.42   0.0155 6.0 -0.951   -0.1
Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
Type:  response 
```
But it's almost always better to be explicit about all the arguments you use (in science, so that it's immediately clear which method was used).
3. Slopes
Instead of comparing a sizable change in a continuous variable, like a +20 increase in horsepower for every observed unit, it's also useful to look at infinitesimally small changes, that is, the rate of change at specific points (partial derivatives).
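A slope is the limit of a comparison as the step shrinks. The base-R sketch below approximates dY/dX for hp at every observed car with a forward finite difference, one common way to compute a derivative numerically, and compares it with the analytic derivative of this model. The step size h is my own choice for illustration, not necessarily what marginaleffects uses when eps = NULL.

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

# Forward finite difference: (f(x + h) - f(x)) / h, for each observed car.
h <- 1e-4 * diff(range(d_base$hp))  # illustrative step size (assumption)
numeric_slope <- (predict(m_base, transform(d_base, hp = hp + h)) -
  predict(m_base, d_base)) / h

# Analytic partial derivative of this model with respect to hp.
b <- coef(m_base)
analytic_slope <- b[["hp"]] + b[["hp:wt"]] * d_base$wt

print(head(cbind(numeric_slope, analytic_slope)))
```

Since this model is linear in hp, the two columns agree up to floating-point error for any step size; for curved models a smaller step trades truncation error for rounding error.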
Again, we need to think about:
- Quantity – which kind of slope measure
- Grid – on which input points
- Aggregation – average slope in subgroups
- Uncertainty – how to measure uncertainty of slopes
- Testing – which statement about the slopes to assess statistically
```r
slopes(
  model = m,
  # 1. Quantity
  type = "response",
  # Use the regular partial derivative.
  slope = "dydx",
  ## Or you could use other kinds of slopes, like
  # slope = "eyex" # Elasticity: dY/dX * X / Y
  # slope = "dyex" # dY/dX * X
  # Estimate slopes for the two continuous variables.
  variables = c("hp", "wt"),
  # 2. Grid
  # Estimate slopes conditional on the observed values.
  newdata = NULL,
  # 3. Aggregation
  # Don't aggregate unit-level slopes.
  by = FALSE,
  # No weighted averaging.
  wts = FALSE,
  # 4. Uncertainty
  # Variance-covariance matrix from the model.
  vcov = TRUE,
  conf_level = 0.95,
  # Default numerical differentiation for the delta method.
  numderiv = "fdforward",
  # Degrees of freedom.
  df = Inf,
  # Default step size in differentiation.
  eps = NULL,
  # 5. Testing
  # Test whether the slope is zero.
  hypothesis = 0,
  # No interval-based testing.
  equivalence = NULL,
  # No multiple testing adjustments.
  p_adjust = NULL
)
```
Returns…
```
 Term Estimate Std. Error      z Pr(>|z|)    S   2.5 %   97.5 %
   hp  -0.0478     0.0142 -3.364   <0.001 10.3 -0.0756 -0.01994
   hp  -0.0269     0.0109 -2.474   0.0134  6.2 -0.0483 -0.00559
   hp  -0.0279     0.0109 -2.549   0.0108  6.5 -0.0493 -0.00644
   hp  -0.0506     0.0149 -3.394   <0.001 10.5 -0.0799 -0.02140
   hp  -0.0646     0.0189 -3.416   <0.001 10.6 -0.1017 -0.02756
--- 54 rows omitted. See ?avg_slopes and ?print.marginaleffects --- 
   wt  -3.7137     0.6795 -5.465   <0.001 24.4 -5.0455 -2.38179
   wt  -1.4383     0.9207 -1.562   0.1182  3.1 -3.2428  0.36618
   wt  -3.1149     0.6500 -4.792   <0.001 19.2 -4.3889 -1.84084
   wt  -0.9832     1.0485 -0.938   0.3484  1.5 -3.0383  1.07183
   wt   0.7173     1.5976  0.449   0.6534  0.6 -2.4139  3.84851
Columns: rowid, term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted, mpg, cyl, hp, wt 
Type:  response 
```
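The alternative slope types commented in the call are rescalings of dY/dX. A base-R sketch for hp, assuming the standard definitions eyex = dY/dX * X / Y and dyex = dY/dX * X with Y the predicted mpg (names are my own). The last step cross-checks the elasticity as a numerical derivative of log(mpg) with respect to log(hp).

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)

b <- coef(m_base)
y_hat <- unname(predict(m_base, d_base))
dydx <- b[["hp"]] + b[["hp:wt"]] * d_base$wt  # analytic dY/dX for hp

dyex <- dydx * d_base$hp          # change in mpg per proportional change in hp
eyex <- dydx * d_base$hp / y_hat  # elasticity: % change in mpg per % change in hp

# Cross-check: elasticity as d log(Y) / d log(X), via a small step in log(hp).
h <- 1e-6
y_step <- unname(predict(m_base, transform(d_base, hp = hp * exp(h))))
eyex_numeric <- (log(y_step) - log(y_hat)) / h

print(head(cbind(dydx, dyex, eyex, eyex_numeric)))
```

Elasticities are unit-free, which makes them handy when comparing variables measured on very different scales, such as horsepower and weight in 1000 lbs.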
For example, you could look at the mean slopes of horsepower and weight within each cylinder group, computed on the observed values.
```r
slopes(
  model = m,
  # 1. Quantity
  type = "response",
  slope = "dydx",
  variables = c("hp", "wt"),
  # 2. Grid
  newdata = NULL,
  # 3. Aggregation
  by = "cyl"
  # 4. Uncertainty
  # Defaults...
  # 5. Testing
  # Defaults...
)
```
Returns…
```
 Term    Contrast cyl Estimate Std. Error      z Pr(>|z|)    S   2.5 %   97.5 %
   hp mean(dY/dX)   4 -0.04859     0.0144 -3.375  < 0.001 10.4 -0.0768 -0.02037
   hp mean(dY/dX)   6 -0.02867     0.0110 -2.608  0.00910  6.8 -0.0502 -0.00713
   hp mean(dY/dX)   8 -0.00755     0.0123 -0.615  0.53839  0.9 -0.0316  0.01649
   wt mean(dY/dX)   4 -5.32710     1.0331 -5.157  < 0.001 21.9 -7.3519 -3.30230
   wt mean(dY/dX)   6 -4.37745     0.7889 -5.549  < 0.001 25.1 -5.9236 -2.83126
   wt mean(dY/dX)   8 -2.29540     0.7294 -3.147  0.00165  9.2 -3.7249 -0.86588
Columns: term, contrast, cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
Type:  response 
```
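Because the hp:wt interaction is the only thing that makes these slopes vary, the group means can be reproduced from the coefficients: each car's hp slope is b_hp + b_hp:wt * wt, its wt slope is b_wt + b_hp:wt * hp, and by = "cyl" averages them within cylinder groups. A base-R sketch of the point estimates (names are my own; no uncertainty):

```r
d_base <- mtcars
d_base$cyl <- factor(d_base$cyl)
m_base <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d_base)
b <- coef(m_base)

# Unit-level analytic slopes on the observed grid.
slope_hp <- b[["hp"]] + b[["hp:wt"]] * d_base$wt
slope_wt <- b[["wt"]] + b[["hp:wt"]] * d_base$hp

# Average within each cylinder group, like by = "cyl".
mean_slopes <- rbind(
  hp = sapply(split(slope_hp, d_base$cyl), mean),
  wt = sapply(split(slope_wt, d_base$cyl), mean)
)
print(mean_slopes)
```

The hp slope shrinks toward zero in the 8-cylinder group simply because those cars are heavier on average, which is the interaction at work rather than a separate cylinder effect.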