Influence Functions from Scratch

In this post we study a deceptively simple question: how does removing a datapoint from the training set affect the model’s behavior?

Derivation

Suppose we have a dataset $X$ containing $N$ datapoints $x_i$, on which we train a parametric model with parameters $\Theta$ by minimizing the average loss:

$$L(X, \Theta) = \frac{1}{N} \sum_i L(x_i, \Theta)$$

such that it converges to $\Theta_0 = \argmin_\Theta L(X, \Theta)$.

If we remove a datapoint $x$ from the dataset, we can update the parameters by minimizing:

$$f(\theta) = L(X, \Theta_0 + \theta) - \frac{1}{N} L(x, \Theta_0 + \theta)$$

Taking a second-order Taylor expansion around $\theta = 0$, where $H_{\Theta_0}$ and $H_{x,\Theta_0}$ denote the Hessians of $L(X, \cdot)$ and $L(x, \cdot)$ at $\Theta_0$:

$$ \begin{align*} f(\theta) ={}& L(X, \Theta_0) +\underbrace{\nabla_\Theta L(X, \Theta_0)^\mathsf{T}\theta}_{=\;0} + \frac{1}{2}\theta^\mathsf{T}H_{\Theta_0}\theta \\ &- \frac{1}{N}\left[ L(x, \Theta_0) + \nabla_\Theta L(x, \Theta_0)^\mathsf{T}\theta + \frac{1}{2}\underbrace{\theta^\mathsf{T}H_{x,\Theta_0}\theta}_{\mathcal{O}(1/N^2)} \right] \\ &+ \mathcal{O}(\Vert \theta\Vert^3) \end{align*} $$

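Then if we assume $\Vert \theta \Vert \in \mathcal{O}(1/N)$ and retain only the terms up to second order in $1/N$, the cubic remainder and the per-datapoint Hessian term (which is $\mathcal{O}(1/N^3)$ once the $\frac{1}{N}$ prefactor is included) drop out: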

$$ \begin{align*} f(\theta) \approx {}& L(X, \Theta_0) + \frac{1}{2}\theta^\mathsf{T}H_{\Theta_0}\theta \\ &- \frac{1}{N}\left[ L(x, \Theta_0) + \nabla_\Theta L(x, \Theta_0)^\mathsf{T}\theta \right] \end{align*} $$

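Minimizing with first-order optimality conditions, that is, setting the gradient $H_{\Theta_0}\theta - \frac{1}{N}\nabla_\Theta L(x, \Theta_0)$ of this quadratic to zero, gives:

$$\hat{\theta} = \frac{1}{N} H_{\Theta_0}^{-1} \nabla_\Theta L(x, \Theta_0)$$

The above is a basic derivation of influence¹. It generalizes to an arbitrary weight $\epsilon$ on the datapoint $x$ by minimizing $L(X, \Theta_0 + \theta) + \epsilon L(x, \Theta_0 + \theta)$ instead. The sign of $\epsilon$ then determines whether the datapoint is being added ($\epsilon > 0$) or removed ($\epsilon < 0$, with $\epsilon = -\frac{1}{N}$ recovering the case above)²:

$$\hat{\theta}(\epsilon) = -\epsilon \cdot H_{\Theta_0}^{-1} \cdot \nabla_\Theta L(x, \Theta_0)$$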
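As a sanity check, the removal estimate can be compared against actually retraining without the datapoint. The following is a minimal sketch, assuming an unregularized logistic-regression model fitted with plain Newton iterations on synthetic data. The model, the data, and all function names (`fit`, `per_example_grads`, `mean_hessian`) are illustrative choices, not part of the derivation. For removal it uses the stationary point of the quadratic approximation of $f$, namely $\frac{1}{N} H_{\Theta_0}^{-1} \nabla_\Theta L(x, \Theta_0)$.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def per_example_grads(X, y, w):
    """Gradient of each datapoint's logistic loss w.r.t. w, shape (N, D)."""
    return (sigmoid(X @ w) - y)[:, None] * X

def mean_hessian(X, w):
    """Hessian of the mean logistic loss L(X, w), shape (D, D)."""
    p = sigmoid(X @ w)
    return (X * (p * (1.0 - p))[:, None]).T @ X / len(X)

def fit(X, y, w_init=None, steps=50):
    """Minimize the mean logistic loss with plain Newton iterations."""
    w = np.zeros(X.shape[1]) if w_init is None else w_init.copy()
    for _ in range(steps):
        grad = per_example_grads(X, y, w).mean(axis=0)
        w = w - np.linalg.solve(mean_hessian(X, w), grad)
    return w

# Synthetic, non-separable data (hypothetical example values).
rng = np.random.default_rng(0)
N, D = 1000, 3
X = rng.normal(size=(N, D))
w_true = np.array([1.0, -2.0, 0.5])
y = (rng.random(N) < sigmoid(X @ w_true)).astype(float)

theta_0 = fit(X, y)  # converged parameters on the full dataset
k = 0                # index of the datapoint to remove

# Influence estimate of the parameter change caused by removing x_k.
H = mean_hessian(X, theta_0)
g_k = per_example_grads(X, y, theta_0)[k]
theta_hat = np.linalg.solve(H, g_k) / N

# Ground truth: retrain without x_k, warm-started at theta_0.
keep = np.arange(N) != k
theta_retrained = fit(X[keep], y[keep], w_init=theta_0) - theta_0

rel_err = (np.linalg.norm(theta_retrained - theta_hat)
           / np.linalg.norm(theta_retrained))
print(f"relative error of the influence estimate: {rel_err:.3f}")
```

Because the terms dropped in the approximation are of higher order in $1/N$, the relative error should shrink as `N` grows. Flipping the sign of `theta_hat` makes the estimate point away from the retrained parameters, which is a quick way to check the sign convention.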

Consider the basic case where we have a linear model fitted with an OLS loss (for a recap check out this post).

$$L_{\mathrm{OLS}}(X, \beta) = \frac{1}{2} (y - X\beta)^\mathsf{T}(y-X\beta)$$

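Here the Hessian is $X^\mathsf{T}X$ and the gradient of the per-datapoint loss is $-x_i^\mathsf{T} r_i$, where $x_i$ is the $i$-th row of $X$ and $r_i = y_i - x_i\beta$ is its scalar residual. So the influence of a datapoint $x_i$ on the model parameters $\beta$ is:

$$\hat{\theta}_{\mathrm{OLS}}(\epsilon) = \epsilon \cdot (X^\mathsf{T}X)^{-1} \cdot x_i^\mathsf{T} r_i $$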
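For OLS the estimate can be checked against an exact leave-one-out refit. The sketch below uses synthetic data and a hypothetical helper `ols_influence`. It assumes the residual is defined as $r_i = y_i - x_i\beta$ and that, because the OLS loss above is a sum rather than a mean, removing a datapoint corresponds to $\epsilon = -1$.

```python
import numpy as np

# Synthetic regression problem (hypothetical example values).
rng = np.random.default_rng(0)
N, D = 200, 4
X = rng.normal(size=(N, D))
y = X @ rng.normal(size=D) + 0.1 * rng.normal(size=N)

beta = np.linalg.solve(X.T @ X, X.T @ y)  # OLS fit on the full dataset
r = y - X @ beta                          # residuals r_i = y_i - x_i beta

def ols_influence(X, r, i, eps):
    """eps * (X^T X)^{-1} x_i^T r_i for row i of X."""
    return eps * np.linalg.solve(X.T @ X, X[i] * r[i])

i = 7
estimate = ols_influence(X, r, i, eps=-1.0)  # eps = -1: remove datapoint i

# Exact change in beta from refitting without datapoint i.
keep = np.arange(N) != i
beta_loo = np.linalg.solve(X[keep].T @ X[keep], X[keep].T @ y[keep])
exact = beta_loo - beta

print("influence estimate:", estimate)
print("exact refit change:", exact)
```

The two vectors differ only by a scalar factor: by the Sherman-Morrison identity, the exact change equals the estimate divided by $1 - x_i (X^\mathsf{T}X)^{-1} x_i^\mathsf{T}$, so the approximation is accurate whenever that quantity is small.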

The inverse Hessian orients and scales the direction in which the datapoint pushes the parameters, which for a linear model is the product of the covariate vector with the (scalar) residual. Influence values can be difficult to interpret directly, so we often convert them into a measurable quantity by taking a dot product with another gradient.

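We can compute the influence of a datapoint on the model’s loss at another point $z$ using $\nabla_\Theta L(z, \Theta_0)^\mathsf{T} \cdot \hat{\theta}(\epsilon)$, which is the first-order change in the loss at $z$ under the parameter update. A larger magnitude indicates that the two datapoints pull the parameters along similar directions during training (or opposing ones, depending on the sign). If we set $z=x_i$ in the OLS case, we get a value proportional to the datapoint’s leverage, with a factor that depends on $\epsilon$ and $r_i^2$. The leverage is defined as $P_{ii} = x_i (X^\mathsf{T}X)^{-1} x_i^\mathsf{T}$ (with $x_i$ a row of $X$) and indicates how far $x_i$’s covariates lie from the bulk of the data, and hence whether $x_i$ is a potential outlier.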
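The relation between self-influence and leverage can be checked numerically. This is a minimal sketch on synthetic OLS data, assuming $r_i = y_i - x_i\beta$ and $\epsilon = -1$ (removal under the summed OLS loss). The variable names are illustrative.

```python
import numpy as np

# Synthetic regression problem (hypothetical example values).
rng = np.random.default_rng(0)
N, D = 200, 4
X = rng.normal(size=(N, D))
y = X @ rng.normal(size=D) + 0.1 * rng.normal(size=N)

XtX = X.T @ X
beta = np.linalg.solve(XtX, X.T @ y)
r = y - X @ beta
eps = -1.0

# Row i holds theta_hat for datapoint i: eps * (X^T X)^{-1} x_i^T r_i.
theta_hat = eps * np.linalg.solve(XtX, (X * r[:, None]).T).T

# Row j holds the gradient of the loss at datapoint j: -x_j^T r_j.
loss_grads = -X * r[:, None]

# influence[j, i]: first-order change in the loss at z = x_j
# caused by reweighting x_i with weight eps.
influence = loss_grads @ theta_hat.T

# Leverage P_ii = x_i (X^T X)^{-1} x_i^T for every datapoint.
leverage = np.einsum("ij,ij->i", X, np.linalg.solve(XtX, X.T).T)
self_influence = np.diag(influence)

# Self-influence equals -eps * r_i^2 * P_ii.
print(np.allclose(self_influence, -eps * r**2 * leverage))
```

With $\epsilon = -1$ the self-influence is $r_i^2 P_{ii} \ge 0$: to first order, removing a datapoint never decreases the loss on that same datapoint.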

Approximation

For a model with parameters $\Theta \in \R^D$, inverting the Hessian takes $\mathcal{O}(D^3)$ time (and storing it takes $\mathcal{O}(D^2)$ memory), which presents a computational challenge when working with billion-parameter models like LLMs. Fortunately, there are several well-studied approximations.

First, note that the

  1. Gauss-Newton Hessian rather than full Hessian
  2. K-FAC version
  3. EK-FAC³

Note how the Hessian can be computed over the dataset once and then reused to compute the influence of many datapoints.
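For a small dense Hessian, this reuse can be sketched as a single batched linear solve. The helper name `inverse_hessian_vector_products` and the random test matrices are assumptions for illustration. The approximations listed above would replace the dense solve at scale.

```python
import numpy as np

def inverse_hessian_vector_products(H, grads):
    """Return H^{-1} g for every row g of grads, shape (M, D).

    np.linalg.solve factorizes H once and reuses that factorization for
    all M right-hand sides, instead of refactorizing per datapoint.
    """
    return np.linalg.solve(H, grads.T).T

# Hypothetical example: a random symmetric positive-definite "Hessian"
# and a batch of per-datapoint gradients.
rng = np.random.default_rng(0)
D, M = 5, 100
A = rng.normal(size=(D, D))
H = A @ A.T + D * np.eye(D)
grads = rng.normal(size=(M, D))

ihvp = inverse_hessian_vector_products(H, grads)

# Scaling by the weight (and sign convention) of choice then gives the
# influence of every datapoint from the same solve.
print(ihvp.shape)
```

The batched result matches solving for each datapoint separately. Only the cost differs, since the expensive factorization of `H` happens once.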

Any relation to Cook’s distance?

Connection between the Fisher and Hessian: https://arxiv.org/pdf/2003.11630