The Basic Idea of Causal Inference in Julia GLM

Published: 2023-08-09
Updated: 2023-08-09
using GLM, Distributions, StatsFuns, DataFrames, Gadfly, Roots, Statistics
import Cairo, Fontconfig # Needed by Gadfly to draw PNG files.

As always, we have some data generating system.

function simulate(n::Int)
	z = rand(Normal(0, 3), n)
	x = @. rand(Normal(-3 + -1.5*z, 4))
	y = @. rand(Binomial(1, logistic(1.5 + 3*x + -2*z)))
	return DataFrame(x = x, z = z, y = y)
end

We want to estimate the effect of x on y. More precisely, we want to know how y would change if we went and set x to a certain value, compared to if we set x to some other value. This difference is called a causal effect.

It’s crucial to see that this is a hypothetical question (about counterfactuals) that could never be directly observed. But using the theory of causal inference, we can make a set of assumptions that allows us to use observed data to estimate this hypothetical effect. We also need some models because we don’t have infinite data but this is secondary.

Here are some common assumptions you can use (in plain English):

  • Consistency: The counterfactual values of the cause (interventions) are sufficiently well-defined and these values are consistent with the observed values. Or you could assume that different versions of the value have the same effect on the outcome.
  • Positivity: All values of the cause are possible for everyone within the analyzed subgroups.
  • Exchangeability: If we flipped/shuffled the values of the cause within the analyzed subgroups, we would observe the same outcomes within the subgroups. Variables that let you achieve conditional (within-subgroup) exchangeability are often called confounders.

To estimate causal effects, you need to make causal assumptions. You can represent them as a directed acyclic graph or less commonly as a single-world intervention graph. Our graph has just three variables so doing it inline is fine. In a DAG, we would say that the nodes x, y, and z have the edges x <- z -> y and x -> y.

In a SWIG, we could say that the nodes X|x, Y(x), and Z have the edges Z -> X | x, X|x -> Y(x), and Z -> Y(x). Y(x) is the counterfactual outcome under the intervention x. The capital X is the observed value. | actually divides the node X|x into two unconnected nodes X (observed) and x (counterfactual intervention) – it is just convenient to keep them close.

Now that you have the graph representation, you can use the so-called backdoor criterion to find the confounders that give you exchangeability. You check each path from the cause to the outcome and see if the path is open. (This is automated. See for example dagitty.) These are the rules:

  • A cause has an open path to its outcome. (x -> y and z -> y)
  • Common causes of variables open a path between them. Conditioning on the common cause closes the open path. (x <- z -> y)
  • Conditioning on a common outcome of variables opens a path between them. Not conditioning on it keeps the path closed. (x -> c <- y)
  • Conditioning on a mediator of variables closes the path between them. (x -> m -> y)
  • Conditioning on outcomes of variables conditions partly on them. Partly conditioning on variables may partly open or close a path through those variables. (x <- z -> y and z -> z*)

The variables that can be conditioned on to get exchangeability must close all the backdoor paths from the cause to the outcome. The regular causal path is called a frontdoor path and should be open (DAG x -> y or SWIG x -> Y(x)).

In our example, we can see that z is a confounder. Three general methods exist to estimate the effect of x on y conditioning on z.

  • Standardization (parametric/plug-in g-formula)
  • Inverse-probability weighting (IPW)
  • G-estimation

Standardization and IPW are actually equivalent in theory: both are estimators of the so-called g-formula. But they are calculated quite differently, and this is relevant in practice. You can also combine these methods to get so-called doubly robust estimators that give valid results even if one of the models is misspecified. It’s best to always use a robust method in practice, but it’s easier to understand the methods one by one.
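
For a discrete cause and confounder, the g-formula can be written in two algebraically equivalent ways, one for each estimator. This is the usual textbook form, given here only as a sketch: Y(x) is the counterfactual outcome as in the SWIG notation above, and the discrete setting is an assumption made for readability (the simulated x and z are continuous). Both lines rely on consistency, positivity, and exchangeability within levels of Z.

```latex
\begin{align}
E[Y(x)] &= \sum_z E[Y \mid X = x, Z = z] \, P(Z = z) && \text{(standardization)} \\
        &= E\!\left[ \frac{\mathbf{1}\{X = x\} \, Y}{P(X = x \mid Z)} \right] && \text{(inverse-probability weighting)}
\end{align}
```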

df = simulate(100)
zscore(x) = (x .- mean(x)) ./ std(x)
dfs = transform(df, [:x, :z] .=> zscore; renamecols = false)
dp = hstack(
	plot(dfs, x = :x, y = :y, Stat.y_jitter(range = 0.5)),
	plot(dfs, x = :x, y = :z),
	plot(dfs, x = :z, y = :y, Stat.y_jitter(range = 0.5))
)
draw(PNG("sldfyos87dfy.png", 15cm, 10cm), dp)

Standardization is the most straightforward method. We first fit a model that predicts y based on x and z. Then we predict new y values given the observed values of z and a value of x we are interested in as an intervention. Then we compare the average predicted values of y under the alternative interventions on x. This effectively takes a weighted average over z, with weights given by the observed distribution of z.

For simplicity, we use the framework of generalized linear models from the GLM package.

# Confounder-adjusted outcome model.
model = glm(
	@formula(y ~ x + z),
	dfs,
	Binomial(),
	LogitLink()
)
y_x0_std = predict(model, DataFrame(z = dfs.z, x = 0))
y_x1_std = predict(model, DataFrame(z = dfs.z, x = 1))
# Compare predictions.
mean(y_x1_std) - mean(y_x0_std)
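
The "weighted average over z" can be made explicit without any model when everything is binary. The following is a minimal sketch on hypothetical toy data (the names `x_toy`, `y_toy`, `z_toy` and all the probabilities are made up for illustration and are not part of the simulation above). It computes the mean outcome in each (x, z) subgroup and then averages those means using the overall distribution of z as weights.

```julia
using Random, Statistics

# Hypothetical toy data: binary confounder, cause, and outcome.
rng = MersenneTwister(1)
n_toy = 10_000
z_toy = rand(rng, n_toy) .< 0.5
# z makes both x and y more likely, which confounds the x-y association.
x_toy = rand(rng, n_toy) .< ifelse.(z_toy, 0.8, 0.2)
# By construction, setting x from 0 to 1 raises P(y = 1) by 0.2.
y_toy = rand(rng, n_toy) .< (0.1 .+ 0.2 .* x_toy .+ 0.5 .* z_toy)

# Mean outcome within the subgroup x = xv, z = zv.
stratum_mean(xv, zv) = mean(y_toy[(x_toy .== xv) .& (z_toy .== zv)])

# Standardized mean: average the subgroup means over the
# overall (marginal) distribution of z.
standardized_mean(xv) =
	sum(stratum_mean(xv, zv) * mean(z_toy .== zv) for zv in (false, true))

# Unadjusted comparison, which mixes the effect of x with that of z.
naive = mean(y_toy[x_toy]) - mean(y_toy[.!x_toy])
# Confounder-adjusted comparison.
standardized = standardized_mean(true) - standardized_mean(false)

println("naive: ", naive, "  standardized: ", standardized)
```

With these made-up probabilities the unadjusted difference is pulled well away from 0.2, while the standardized difference should land close to it. The GLM version above does the same thing, except that the subgroup means come from a fitted model instead of raw subgroup averages.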

Inverse-probability weighting is a little different. We first fit a model that predicts the probability (here a probability density, since x is continuous) that x (the cause) takes its observed value, based on z (the confounders). Then we fit a model that predicts y from x alone, using the inverse of those probabilities as weights in the fitting process. A weight like this effectively says how many copies of an observation to count, creating a so-called pseudopopulation in which x no longer depends on z.

xz_model = lm(@formula(x ~ z), dfs)

# Get probability density implied by the model for each observed value of cause.
xz_means = predict(xz_model)
xz_std = std(residuals(xz_model))
xz_prob = pdf.(Normal.(xz_means, xz_std), dfs.x)

# We repeat the same for just x. Using this as the numerator "stabilizes"
# the weights and still gives valid results.
x_prob = pdf.(Normal(mean(dfs.x), std(dfs.x)), dfs.x)

# The stabilized weight looks like this.
x_sipw = x_prob ./ xz_prob

# Now we can fit the weighted outcome model.
y_model = glm(
	@formula(y ~ x),
	dfs,
	Binomial(),
	LogitLink(),
	wts = x_sipw
)
# And compare predictions.
y_x0_ipw = predict(y_model, DataFrame(x = 0))
y_x1_ipw = predict(y_model, DataFrame(x = 1))
y_x1_ipw - y_x0_ipw
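
The pseudopopulation idea is easier to see with a binary cause and confounder. The following is a minimal sketch on hypothetical toy data (all names ending in `_toy` and all probabilities are made up for illustration). Unlike the code above, it uses plain unstabilized weights, 1 / P(x | z), and weighted means instead of a weighted GLM.

```julia
using Random, Statistics

# Hypothetical toy data: binary confounder, cause, and outcome.
rng = MersenneTwister(2)
n_toy = 10_000
z_toy = rand(rng, n_toy) .< 0.5
x_toy = rand(rng, n_toy) .< ifelse.(z_toy, 0.8, 0.2)
# By construction, setting x from 0 to 1 raises P(y = 1) by 0.2.
y_toy = rand(rng, n_toy) .< (0.1 .+ 0.2 .* x_toy .+ 0.5 .* z_toy)

# Estimated probability of x = 1 within each level of z.
p_x1 = Dict(zv => mean(x_toy[z_toy .== zv]) for zv in (false, true))

# Probability of the value of x that was actually observed, given z.
p_obs = [xi ? p_x1[zi] : 1 - p_x1[zi] for (xi, zi) in zip(x_toy, z_toy)]

# Inverse-probability weights (unstabilized).
w_toy = 1 ./ p_obs

weighted_mean(v, w) = sum(v .* w) / sum(w)

treated = x_toy
untreated = .!x_toy

# In the weighted pseudopopulation, z has the same distribution
# in both groups, so it can no longer confound the comparison.
z_balance = weighted_mean(z_toy[treated], w_toy[treated]) -
	weighted_mean(z_toy[untreated], w_toy[untreated])

# Weighted difference in mean outcomes.
ipw_effect = weighted_mean(y_toy[treated], w_toy[treated]) -
	weighted_mean(y_toy[untreated], w_toy[untreated])

println("z imbalance after weighting: ", z_balance)
println("IPW effect estimate: ", ipw_effect)
```

Because the probabilities here are estimated separately within each level of z, the weighted imbalance in z is zero up to rounding error, and the weighted outcome difference should be close to the 0.2 built into the toy data.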

G-estimation is very different. We first express the counterfactual outcome Yx as a function of the observed outcome y, the observed cause x, and an unknown effect parameter e, for example Yx = y - e*x. In this form, Yx is the outcome under x = 0. This magical-seeming step is valid when we assume consistency. Then we fit a model that predicts the cause x from the candidate counterfactual outcome Yx and the confounders z. Assuming exchangeability within levels of z, the counterfactual outcome must be independent of the observed cause; this is in fact the definition of conditional exchangeability. So all this gives us a way to estimate the effect: we search for the value of e that removes the association between the candidate counterfactual outcomes and the observed cause, conditional on (within levels of) the confounders z.

This is quite complicated so an example can really reveal what’s happening.

function coefficient(effect)
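	# Candidate counterfactual outcome under x = 0, assuming consistency.
	# Note that this uses a multiplicative form, Yx = y * exp(-effect * x),
	# rather than the additive example Yx = y - effect * x.
	# @. makes the whole expression work for collections, not just single values.
	Yx = @. dfs.y * exp(-effect * dfs.x)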
	# Predict observed cause from predicted counterfactual outcome 
	# and observed confounders.
	xyz_model = lm(
		@formula(x ~ 1 + Yx + z), 
		hcat(dfs, DataFrame(Yx = Yx))
	)
	# Assuming exchangeability, the coefficient for the predicted
	# counterfactual outcome should be zero. So this coefficient is 
	# like a measure of non-exchangeability.
	return coef(xyz_model)[2]
end
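# Search the interval [-10, 10] for the effect value that makes the coefficient zero.
# With a bracketing interval, the coefficient must have opposite signs at the two endpoints.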
find_zero(coefficient, (-10, 10))
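
Written out, the search above solves the following problem. This only restates the code; the symbols H(e) and the alphas are names introduced here for illustration.

```latex
\begin{align}
H(e) &= Y \exp(-e X) \\
E[X \mid H(e), Z] &= \alpha_0 + \alpha_1(e) \, H(e) + \alpha_2 Z \\
\hat{e} &: \; \hat{\alpha}_1(\hat{e}) = 0
\end{align}
```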

In practice, you want to use trustworthy, data-type-general implementations of the robust versions of these methods. Still, hopefully these examples were enough to convey the basic ideas of causal inference.