eteppo

Interpreting Models Has Never Been Easier

Published: 2024-07-02
Updated: 2024-07-02

I recommend reading the journal article presenting the marginaleffects package by Arel-Bundock V, Greifer N, and others.

To demonstrate the kinds of questions that analysts and scientists need to ask of their models (and how easily they can be answered), I'll use the marginaleffects package in R in this post. (I also use the Quarto CLI to render Quarto markdown to regular markdown.)

library(marginaleffects)
library(tidyverse)

To get some data to work with, I use the mtcars dataset with:

  • mpg = miles per gallon (1 mpg is about 0.42 km/l)
  • cyl = number of cylinders
  • hp = gross horsepower
  • wt = weight (1000 lbs)

data(mtcars)

d <- mtcars |>
  as_tibble() |>
  select(mpg, cyl, wt, hp) |>
  arrange(cyl) |>
  mutate(cyl = as.factor(cyl))
    
print(d)

Returns…

    # A tibble: 32 × 4
         mpg cyl      wt    hp
       <dbl> <fct> <dbl> <dbl>
     1  22.8 4      2.32    93
     2  24.4 4      3.19    62
     3  22.8 4      3.15    95
     4  32.4 4      2.2     66
     5  30.4 4      1.62    52
     6  33.9 4      1.84    65
     7  21.5 4      2.46    97
     8  27.3 4      1.94    66
     9  26   4      2.14    91
    10  30.4 4      1.51   113
    # ℹ 22 more rows

Normally you would think very carefully about the data-generating mechanisms and build a model of them, possibly iteratively, often with a single question in mind. We will skip all of this because our focus is on a later step: model interpretation. Think of the following simple model as a stand-in for almost any model you might use.

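# hp * wt already expands to hp + wt + hp:wt, so the extra hp + wt terms are redundant (but harmless).
m <- lm(mpg ~ cyl + hp + wt + (hp * wt), data = d)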

The most common way to interpret models is simply to report the coefficients from the fitted model with their confidence intervals or p-values.

summary(m)

Returns…

    Call:
    lm(formula = mpg ~ cyl + hp + wt + (hp * wt), data = d)

    Residuals:
        Min      1Q  Median      3Q     Max 
    -3.5309 -1.6451 -0.4154  1.3838  4.4788 

    Coefficients:
                 Estimate Std. Error t value Pr(>|t|)    
    (Intercept) 47.337329   4.679790  10.115 1.67e-10 ***
    cyl6        -1.259073   1.489594  -0.845 0.405685    
    cyl8        -1.454339   2.063696  -0.705 0.487246    
    hp          -0.103331   0.031907  -3.238 0.003274 ** 
    wt          -7.306337   1.675258  -4.361 0.000181 ***
    hp:wt        0.023951   0.008966   2.671 0.012865 *  
    ---
    Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

    Residual standard error: 2.203 on 26 degrees of freedom
    Multiple R-squared:  0.888, Adjusted R-squared:  0.8664 
    F-statistic: 41.21 on 5 and 26 DF,  p-value: 1.503e-11

Perhaps one of the coefficients has a causal interpretation and the other variables are confounders. Either way, these coefficients are quite difficult to understand and don't answer the questions people usually care about. For example, because of the hp:wt interaction, the hp coefficient is the slope of hp for a car that weighs nothing. Some model parameters are not interpretable at all without some post-processing.
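To see how much the hp slope depends on weight, here is a minimal base-R sketch (it does not use marginaleffects) that refits the same model and evaluates the implied slope of hp, β_hp + β_hp:wt × wt, at a few weights. The names hp_slope and w are just for illustration.

```r
# Refit the post's model with base R only (same data and formula).
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)
b <- coef(m)

# Slope of mpg with respect to hp at weight w: b_hp + b_hp:wt * w.
# At w = 0 this is just the reported hp coefficient.
w <- c(zero = 0, min = min(d$wt), mean = mean(d$wt), max = max(d$wt))
hp_slope <- unname(b["hp"]) + unname(b["hp:wt"]) * w
round(hp_slope, 4)
```

Because β_hp:wt is positive, the hp slope gets less negative as cars get heavier.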

1. Predictions

A model doesn’t make much sense if its predictions don’t make sense. We should look at predictions for different full and partial observations of interest.

To get what we want, we should define five things (as described by marginaleffects):

  1. Quantity – what to predict.
  2. Grid – inputs to predict for.
  3. Aggregation – whether and how to summarize predictions.
  4. Uncertainty – how to measure uncertainty of prediction.
  5. Test – what kind of hypothesis to test.

Below is an example call to predictions(), with some alternatives in comments. First, we get ordinary predictions for the observed data.

predictions(
  model = m,

  # 1. Quantity
  # Predict value of the response variable.
  type = "response",
  ## Or predict on the link scale (before applying the inverse link function).
  # type = "link",
  ## Or, for models that support it (e.g. ordinal models), predict outcome probabilities.
  # type = "probs",
  # Don't transform predicted values.
  transform = NULL,

  # 2. Grid
  # Predict for the observed values in data.
  newdata = NULL,
  ## Or use "grid" to predict for points across the whole dataspace.
  # newdata = "grid",
  ## Or use datagrid() to define custom grids of points like
  ## counterfactual datasets with some values fixed for all.
  # newdata = datagrid(cyl = c(4, 6, 8), grid_type = "counterfactual"),
  
  # Predict one value for each input row, without counterfactual replication.
  # (The variables argument is another interface for counterfactuals.)
  variables = NULL,
  ## Predict one value for every counterfactual alternative.
  # variables = "cyl",
  ## Predict values for a subset of the counterfactual alternatives. 
  # variables = list("hp" = "minmax"), 

  # 3. Aggregation
  # Don't aggregate estimates.
  by = FALSE,
  byfun = NULL,
  # No weights when averaging.
  wts = FALSE,

  # 4. Uncertainty
  # Use supplied model's covariance matrix for uncertainty. 
  vcov = TRUE,
  # Confidence interval of interest.
  conf_level = 0.90,
  # Use the normal distribution (infinite degrees of freedom); a finite value gives Student's t.
  df = Inf,
  # Use the "finite difference method with forward differences" for
  # calculating delta method standard errors.
  numderiv = "fdforward",

  # 5. Testing
  # Test whether the predicted value is 0 (default).
  hypothesis = 0,
  ## Test whether deviation from mean is zero (predicts the deviations).
  # hypothesis = "meandev",
  # Don't test for equivalence, non-inferiority or non-superiority.
  equivalence = NULL,
  # No multiple testing adjustments.
  p_adjust = NULL
)

Returns…

     Estimate Std. Error     z Pr(>|z|)     S 5.0 % 95.0 %
         25.9      0.688 37.71   <0.001   Inf  24.8   27.1
         22.4      1.161 19.26   <0.001 272.2  20.5   24.3
         21.7      1.143 18.96   <0.001 263.8  19.8   23.6
         27.9      0.726 38.48   <0.001   Inf  26.7   29.1
         32.2      1.317 24.43   <0.001 435.6  30.0   34.3
    --- 22 rows omitted. See ?avg_predictions and ?print.marginaleffects --- 
         17.6      0.968 18.20   <0.001 243.5  16.0   19.2
         15.0      0.719 20.91   <0.001 320.2  13.9   16.2
         15.8      0.732 21.63   <0.001 342.2  14.6   17.0
         15.5      1.064 14.55   <0.001 156.9  13.7   17.2
         13.8      1.513  9.14   <0.001  63.8  11.3   16.3
    Columns: rowid, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, mpg, cyl, hp, wt 
    Type:  response 
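As a cross-check: for a linear model on the response scale these are just the fitted values, so I would expect base R's predict.lm() with se.fit = TRUE to give the same estimates and standard errors. The sketch below assumes that, and builds the 90% interval with a normal critical value to mirror df = Inf (predict.lm's own interval = "confidence" uses a t distribution instead).

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)

# Unit-level predictions and their standard errors.
p <- predict(m, newdata = d, se.fit = TRUE)

# 90% interval with a normal (z) critical value, as with df = Inf above.
z <- qnorm(0.95)
manual <- data.frame(
  estimate  = p$fit,
  std.error = p$se.fit,
  conf.low  = p$fit - z * p$se.fit,
  conf.high = p$fit + z * p$se.fit
)
head(round(manual, 3))
```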

Next, let's aggregate these predictions by number of cylinders. I still write out all the arguments to make the five-concept framework clear.

predictions(
  model = m,

  # 1. Quantity
  # Predict the response.
  type = "response",
  transform = NULL,

  # 2. Grid
  # Predict for the observed values.
  newdata = NULL,
  variables = NULL,

  # 3. Aggregation
  # Aggregate predictions by cyl-variable.
  by = "cyl",
  # The summary is the mean of predictions.
  byfun = mean,
  wts = FALSE,

  # 4. Uncertainty.
  # Use default delta method to get 90% confidence intervals
  # based on the model's covariance matrix.
  vcov = TRUE,
  conf_level = 0.90,
  numderiv = "fdforward",
  df = Inf,

  # 5. Testing
  hypothesis = 0,
  p_adjust = NULL,
  equivalence = NULL
)

Returns…

     cyl Estimate Std. Error    z Pr(>|z|)     S 5.0 % 95.0 %
       4     26.7      0.664 40.1   <0.001   Inf  25.6   27.8
       6     19.7      0.833 23.7   <0.001 410.5  18.4   21.1
       8     15.1      0.589 25.6   <0.001 479.6  14.1   16.1

    Columns: cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
    Type:  response 
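What does aggregating by cyl compute here? A minimal base-R sketch, assuming the same data and model: average the fitted values within each observed cyl group. Because cyl is a factor in an OLS model with an intercept, the residuals sum to zero within each level, so these averages also equal the raw group means of mpg.

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)

# Mean of unit-level predictions within each observed cyl group.
avg_by_cyl <- tapply(fitted(m), d$cyl, mean)
# Raw group means of the outcome, for comparison.
obs_by_cyl <- tapply(d$mpg, d$cyl, mean)
round(cbind(avg_by_cyl, obs_by_cyl), 2)
```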

Finally, let's compute the same mean predictions for each number of cylinders, but this time counterfactually. This means that cyl is set to 4, then 6, then 8 for every car, giving three copies of the dataset; the predictions are then averaged within each copy (that is, by cyl). (This makes sense only under particular causal assumptions.)

We also leverage some defaults to make the call shorter.

predictions(
  model = m,
  # 1. Quantity
  type = "response",
  # 2. Grid
  newdata = datagrid(cyl = c(4, 6, 8), grid_type = "counterfactual"),
  # 3. Aggregation
  by = "cyl"
  # 4. Uncertainty.
  # 5. Testing.
)

Returns…

     cyl Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
       4     21.0      1.216 17.3   <0.001 219.8  18.6   23.4
       6     19.7      0.936 21.1   <0.001 325.9  17.9   21.6
       8     19.5      1.042 18.8   <0.001 258.6  17.5   21.6

    Columns: cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
    Type:  response 
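A minimal base-R sketch of the same counterfactual computation, assuming the same data and model (the object cf is a hypothetical helper, not part of marginaleffects): copy the whole dataset once per cyl level, set cyl to that level for every car, predict, and average. Since cyl does not interact with anything in this model, the differences between the averages equal the cyl6 and cyl8 coefficients.

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)

# One counterfactual copy of the data per cyl level; average its predictions.
cf <- sapply(levels(d$cyl), function(lv) {
  d_cf <- d
  d_cf$cyl <- factor(rep(lv, nrow(d_cf)), levels = levels(d$cyl))
  mean(predict(m, newdata = d_cf))
})
round(cf, 1)

# Differences from the cyl = 4 average (equal to the cyl coefficients here).
cf - cf[["4"]]
```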

To minimize typing, we could have used the variables argument to define the counterfactual grid and combined it with avg_predictions(), a variant of the function that aggregates by default…

avg_predictions(m, variables = "cyl")

Returns…

     cyl Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
       4     21.0      1.216 17.3   <0.001 219.8  18.6   23.4
       6     19.7      0.936 21.1   <0.001 325.9  17.9   21.6
       8     19.5      1.042 18.8   <0.001 258.6  17.5   21.6

    Columns: cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
    Type:  response 

2. Comparisons

The most important questions to ask of a model are usually about some comparison: a difference, an association, or an effect. Comparisons are functions of predictions, such as the difference or ratio between two predicted values.

Once again, we need to define the five concepts, and I'll first show how to do this explicitly.

  1. Quantity – what kind of estimates?
  2. Grid – on which inputs are the estimates computed?
  3. Aggregation – are the individual estimates aggregated?
  4. Uncertainty – how is the uncertainty of the estimate measured?
  5. Testing – what kind of statement about the estimate is statistically tested?

Here it is crucial to pay attention to the grid, because it can be very confusing what kind of measure is actually being computed. Which units are included? Which values are counterfactual/synthetic, and which are observed?

comparisons(
  model = m,

  # 1. Quantities
  # Predict responses for the comparisons.
  type = "response",
  # Use a difference as a measure of comparison.
  comparison = "difference",
  # Or comparison = "ratio",
  # Or comparison = function (hi, lo) { hi / lo },

  # Compare the reference level (4) to the others (6, 8), and a +10 increment in horsepower.
  # (Note again that this argument leads to counterfactual datasets.)
  variables = list(cyl = "reference", hp = 10),
  # Separate comparisons, not the joint change of both.
  cross = FALSE,

  # Don't transform estimates at the end.
  transform = NULL,

  # 2. Grid
  # Calculate the differences at the mean of the observed values (the mode for factors such as cyl).
  newdata = "mean",

  # 3. Aggregation
  # Don't aggregate estimates (just one row in newdata anyway).
  by = FALSE,
  wts = FALSE,

  # 4. Uncertainty
  vcov = TRUE,
  conf_level = 0.95,
  df = Inf,
  numderiv = "fdforward",
  # New: step size for numerical derivatives (slope estimation).
  eps = NULL,

  # 5. Testing
  # Test whether the difference is zero.
  hypothesis = 0,
  equivalence = NULL,
  p_adjust = NULL
)

Returns…

     Term Contrast Estimate Std. Error      z Pr(>|z|)   S  2.5 % 97.5 % cyl  hp   wt
      cyl    6 - 4   -1.259      1.490 -0.845   0.3980 1.3 -4.179   1.66   8 147 3.22
      cyl    8 - 4   -1.454      2.064 -0.705   0.4810 1.1 -5.499   2.59   8 147 3.22
      hp     +10     -0.263      0.109 -2.421   0.0155 6.0 -0.475  -0.05   8 147 3.22

    Columns: rowid, term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted, cyl, hp, wt, mpg 
    Type:  response 
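To make the grid explicit, here is a base-R sketch of the same two contrasts at newdata = "mean", assuming that grid puts numeric variables at their means and cyl at its most frequent level (as the output's cyl = 8 suggests). The cyl contrast reproduces the cyl6 coefficient, and the hp contrast equals 10 × (β_hp + β_hp:wt × mean(wt)).

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)

# One-row grid: means for numeric variables, most frequent level for cyl.
mode_cyl <- names(which.max(table(d$cyl)))
at_mean <- data.frame(
  cyl = factor(mode_cyl, levels = levels(d$cyl)),
  hp  = mean(d$hp),
  wt  = mean(d$wt)
)

# Contrast 1: cyl = 6 versus cyl = 4 on that row.
lo <- transform(at_mean, cyl = factor("4", levels = levels(d$cyl)))
hi <- transform(at_mean, cyl = factor("6", levels = levels(d$cyl)))
diff_cyl6 <- unname(predict(m, hi) - predict(m, lo))

# Contrast 2: hp + 10 on the same row.
diff_hp10 <- unname(predict(m, transform(at_mean, hp = hp + 10)) - predict(m, at_mean))
c(cyl_6_vs_4 = diff_cyl6, hp_plus_10 = diff_hp10)
```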

Let's imagine horsepower has a causal interpretation given our model. We could estimate the average effect of a 20-unit increase in horsepower like this (again leveraging some defaults).

comparisons(
  model = m,

  # 1. Quantity
  type = "response",
  comparison = "difference",
  variables = list(hp = 20),

  # 2. Grid
  # Estimates on the observed values.
  newdata = NULL,
  
  # 3. Aggregation
  by = TRUE,

  # 4. Uncertainty
  # 5. Testing
  hypothesis = 0
)

Returns…

     Term  Contrast Estimate Std. Error     z Pr(>|z|)   S  2.5 % 97.5 %
       hp mean(+20)   -0.525      0.217 -2.42   0.0155 6.0 -0.951   -0.1

    Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
    Type:  response 
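Conceptually, this average comparison shifts hp by 20 for every observed car, takes the unit-level differences in predictions, and averages them. A minimal base-R sketch under the same data and model; for this particular model the result also has a closed form, 20 × (β_hp + β_hp:wt × mean(wt)).

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)
b <- coef(m)

# Unit-level effect of +20 hp: predict on a shifted copy of the data.
unit_diff <- predict(m, newdata = transform(d, hp = hp + 20)) -
  predict(m, newdata = d)
avg_diff <- mean(unit_diff)

# Each unit-level difference is 20 * (b_hp + b_hp:wt * wt), so the average
# is the same expression evaluated at mean(wt).
analytic <- 20 * (b[["hp"]] + b[["hp:wt"]] * mean(d$wt))
c(avg_diff = avg_diff, analytic = analytic)
```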

If you want to avoid typing at all costs, the package lets you write the same thing as just…

avg_comparisons(m, variables = list(hp = 20))

Returns…

     Term  Contrast Estimate Std. Error     z Pr(>|z|)   S  2.5 % 97.5 %
       hp mean(+20)   -0.525      0.217 -2.42   0.0155 6.0 -0.951   -0.1

    Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
    Type:  response 

But it is almost always better to be explicit about all the arguments you use; in science, this makes it immediately clear which method was used.

3. Slopes

Instead of comparing a sizable change in a continuous variable, like a 20-unit increase in horsepower for every observed unit, it is also useful to look at infinitesimally small changes, that is, the rate of change at specific points (partial derivatives).
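For the model used here, the partial derivatives have a closed form. Writing the fitted model with the coefficient names from summary(m):

$$
\widehat{\text{mpg}} = \beta_0 + \beta_{\text{cyl6}}\,\text{cyl6} + \beta_{\text{cyl8}}\,\text{cyl8} + \beta_{hp}\,\text{hp} + \beta_{wt}\,\text{wt} + \beta_{hp:wt}\,\text{hp}\cdot\text{wt}
$$

$$
\frac{\partial\,\widehat{\text{mpg}}}{\partial\,\text{hp}} = \beta_{hp} + \beta_{hp:wt}\,\text{wt},
\qquad
\frac{\partial\,\widehat{\text{mpg}}}{\partial\,\text{wt}} = \beta_{wt} + \beta_{hp:wt}\,\text{hp}
$$

For the first car in d (wt = 2.32), the hp slope is about -0.1033 + 0.02395 × 2.32 ≈ -0.0478, which matches the first hp row of the slopes() output.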

Again, we need to think about:

  1. Quantity – which kind of slope measure
  2. Grid – on which input points
  3. Aggregation – average slope in subgroups
  4. Uncertainty – how to measure uncertainty of slopes
  5. Testing – which statement about the slopes to assess statistically

slopes(
  model = m,

  # 1. Quantity
  type = "response",

  # Use regular partial derivative.
  slope = "dydx",
  ## Or you could use other kinds of slopes, such as
  # slope = "eyex" # Elasticity: dY/dX * X / Y
  # slope = "dyex" # dY/dX * X

  # Estimate slope for the two continuous variables.
  variables = c("hp", "wt"),

  # 2. Grid
  # Estimate slopes conditional on observed values.
  newdata = NULL,

  # 3. Aggregation
  # Don't aggregate unit-level slopes.
  by = FALSE,
  # No weighted averaging.
  wts = FALSE,

  # 4. Uncertainty
  # Variance-covariance matrix from the model.
  vcov = TRUE,
  conf_level = 0.95,
  # Forward finite differences for the delta-method Jacobian (default).
  numderiv = "fdforward",
  # Degrees of freedom.
  df = Inf,
  # Default step size in differentiation.
  eps = NULL,

  # 5. Testing
  # Test whether the slope is zero.
  hypothesis = 0,
  # No interval-based testing.
  equivalence = NULL,
  # No multiple testing adjustments.
  p_adjust = NULL
)

Returns…

     Term Estimate Std. Error      z Pr(>|z|)    S   2.5 %   97.5 %
       hp  -0.0478     0.0142 -3.364   <0.001 10.3 -0.0756 -0.01994
       hp  -0.0269     0.0109 -2.474   0.0134  6.2 -0.0483 -0.00559
       hp  -0.0279     0.0109 -2.549   0.0108  6.5 -0.0493 -0.00644
       hp  -0.0506     0.0149 -3.394   <0.001 10.5 -0.0799 -0.02140
       hp  -0.0646     0.0189 -3.416   <0.001 10.6 -0.1017 -0.02756
    --- 54 rows omitted. See ?avg_slopes and ?print.marginaleffects --- 
       wt  -3.7137     0.6795 -5.465   <0.001 24.4 -5.0455 -2.38179
       wt  -1.4383     0.9207 -1.562   0.1182  3.1 -3.2428  0.36618
       wt  -3.1149     0.6500 -4.792   <0.001 19.2 -4.3889 -1.84084
       wt  -0.9832     1.0485 -0.938   0.3484  1.5 -3.0383  1.07183
       wt   0.7173     1.5976  0.449   0.6534  0.6 -2.4139  3.84851
    Columns: rowid, term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted, mpg, cyl, hp, wt 
    Type:  response 
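slopes() estimates these derivatives numerically. A minimal base-R sketch of a forward finite difference for hp, assuming the same data and model; the step size h is a hypothetical choice, not necessarily the package's default. Because the model is linear in hp, the numerical slopes agree with the analytic β_hp + β_hp:wt × wt up to rounding error.

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)
b <- coef(m)

# Forward finite difference in hp for every car: (f(x + h) - f(x)) / h.
h <- 1e-4  # hypothetical step size
fd_hp <- (predict(m, transform(d, hp = hp + h)) - predict(m, d)) / h

# Analytic partial derivative for this model.
exact_hp <- b[["hp"]] + b[["hp:wt"]] * d$wt
max(abs(fd_hp - exact_hp))
head(round(fd_hp, 4))
```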

You could, for example, look at the mean slopes of horsepower and weight within each cylinder group, computed on the observed values.

slopes(
  model = m,

  # 1. Quantity
  type = "response",
  slope = "dydx",
  variables = c("hp", "wt"),

  # 2. Grid
  newdata = NULL,

  # 3. Aggregation
  by = "cyl"

  # 4. Uncertainty
  # Defaults...
  # 5. Testing
  # Defaults...
)

Returns…

     Term    Contrast cyl Estimate Std. Error      z Pr(>|z|)    S   2.5 %   97.5 %
       hp mean(dY/dX)   4 -0.04859     0.0144 -3.375  < 0.001 10.4 -0.0768 -0.02037
       hp mean(dY/dX)   6 -0.02867     0.0110 -2.608  0.00910  6.8 -0.0502 -0.00713
       hp mean(dY/dX)   8 -0.00755     0.0123 -0.615  0.53839  0.9 -0.0316  0.01649
       wt mean(dY/dX)   4 -5.32710     1.0331 -5.157  < 0.001 21.9 -7.3519 -3.30230
       wt mean(dY/dX)   6 -4.37745     0.7889 -5.549  < 0.001 25.1 -5.9236 -2.83126
       wt mean(dY/dX)   8 -2.29540     0.7294 -3.147  0.00165  9.2 -3.7249 -0.86588

    Columns: term, contrast, cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
    Type:  response
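A base-R sketch of the same by-cyl averages, assuming the same data and model: compute each car's analytic slopes and average them within each cylinder group with aggregate(). The point estimates should match the table above; the standard errors are what the package adds on top.

```r
# Same data and model as in the post, base R only.
d <- mtcars[order(mtcars$cyl), c("mpg", "cyl", "wt", "hp")]
d$cyl <- factor(d$cyl)
m <- lm(mpg ~ cyl + hp + wt + hp:wt, data = d)
b <- coef(m)

# Unit-level analytic slopes (columns named after the variable differentiated).
unit_slopes <- data.frame(
  cyl = d$cyl,
  hp  = b[["hp"]] + b[["hp:wt"]] * d$wt,  # d mpg / d hp for each car
  wt  = b[["wt"]] + b[["hp:wt"]] * d$hp   # d mpg / d wt for each car
)

# Mean slope within each cylinder group.
by_cyl <- aggregate(cbind(hp, wt) ~ cyl, data = unit_slopes, FUN = mean)
by_cyl
```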